diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py index 92dabf0fee48..1cce2836974d 100755 --- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py +++ b/.ci/aarch64_linux/aarch64_wheel_ci_build.py @@ -136,6 +136,9 @@ def complete_wheel(folder: str) -> str: """ wheel_name = list_dir(f"/{folder}/dist")[0] + # Please note for cuda we don't run auditwheel since we use custom script to package + # the cuda dependencies to the wheel file using update_wheel() method. + # However we need to make sure filename reflects the correct Manylinux platform. if "pytorch" in folder and not enable_cuda: print("Repairing Wheel with AuditWheel") check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder) @@ -147,7 +150,14 @@ def complete_wheel(folder: str) -> str: f"/{folder}/dist/{repaired_wheel_name}", ) else: - repaired_wheel_name = wheel_name + repaired_wheel_name = wheel_name.replace( + "linux_aarch64", "manylinux_2_28_aarch64" + ) + print(f"Renaming {wheel_name} wheel to {repaired_wheel_name}") + os.rename( + f"/{folder}/dist/{wheel_name}", + f"/{folder}/dist/{repaired_wheel_name}", + ) print(f"Copying {repaired_wheel_name} to artifacts") shutil.copy2( diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index 5f17a6332dd1..7548bd28bcc0 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -44,6 +44,8 @@ FROM base as cuda ARG CUDA_VERSION=12.4 RUN rm -rf /usr/local/cuda-* ADD ./common/install_cuda.sh install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION} # Preserve CUDA_VERSION for the builds ENV CUDA_VERSION=${CUDA_VERSION} diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 07e991658b7a..1e1ec8b491ae 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -460,10 +460,18 @@ if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then fi fi +no_cache_flag="" +progress_flag="" +# Do not use cache and progress=plain when in CI +if [[ -n "${CI:-}" ]]; then + no_cache_flag="--no-cache" + progress_flag="--progress=plain" +fi + # Build image docker build \ - --no-cache \ - --progress=plain \ + ${no_cache_flag} \ + ${progress_flag} \ --build-arg "BUILD_ENVIRONMENT=${image}" \ --build-arg "PROTOBUF=${PROTOBUF:-}" \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index dc4f8b30fe87..39005b14ab7e 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -cedf52aa8e4df879886270a5920da6fe84cbaa67 +7e487c24e1c20c3f4606c2d8aca2778873b00b4c diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index 10f3c7733f4f..3959880b53c5 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -2,7 +2,6 @@ set -ex -NCCL_VERSION=v2.26.2-1 CUDNN_VERSION=9.5.1.17 function install_cusparselt_040 { @@ -40,8 +39,7 @@ function install_cusparselt_063 { function install_118 { CUDNN_VERSION=9.1.0.70 - NCCL_VERSION=v2.21.5-1 - echo "Installing CUDA 11.8 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.4.0" + echo "Installing CUDA 11.8 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.4.0" rm -rf /usr/local/cuda-11.8 /usr/local/cuda # install CUDA 11.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run @@ -59,14 +57,7 @@ function install_118 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=11.8 bash install_nccl.sh install_cusparselt_040 @@ -75,7 +66,7 @@ function install_118 { function install_124 { CUDNN_VERSION=9.1.0.70 - echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.2" + echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" rm -rf /usr/local/cuda-12.4 /usr/local/cuda # install CUDA 12.4.1 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run @@ -93,14 +84,7 @@ function install_124 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.4 bash install_nccl.sh install_cusparselt_062 @@ -108,7 +92,7 @@ function install_124 { } function install_126 { - echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.6 /usr/local/cuda # install CUDA 12.6.3 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run @@ -126,14 +110,7 @@ function install_126 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.6 bash install_nccl.sh install_cusparselt_063 @@ -241,7 +218,7 @@ function prune_126 { function install_128 { CUDNN_VERSION=9.8.0.87 - echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.8 /usr/local/cuda # install CUDA 12.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run @@ -259,14 +236,7 @@ function install_128 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.8 bash install_nccl.sh install_cusparselt_063 diff --git a/.ci/docker/common/install_cuda_aarch64.sh b/.ci/docker/common/install_cuda_aarch64.sh index 3f154a103aa7..ae4983712989 100644 --- a/.ci/docker/common/install_cuda_aarch64.sh +++ b/.ci/docker/common/install_cuda_aarch64.sh @@ -3,7 +3,6 @@ set -ex -NCCL_VERSION=v2.26.2-1 CUDNN_VERSION=9.8.0.87 function install_cusparselt_063 { @@ -18,7 +17,7 @@ function install_cusparselt_063 { } function install_128 { - echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.8 /usr/local/cuda # install CUDA 12.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux_sbsa.run @@ -36,14 +35,7 @@ function install_128 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b ${NCCL_VERSION} --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.8 bash install_nccl.sh install_cusparselt_063 diff --git a/.ci/docker/common/install_executorch.sh b/.ci/docker/common/install_executorch.sh index a9a558b86f99..e30e0a787bbe 100755 --- a/.ci/docker/common/install_executorch.sh +++ b/.ci/docker/common/install_executorch.sh @@ -50,8 +50,7 @@ setup_executorch() { pushd executorch export PYTHON_EXECUTABLE=python - export EXECUTORCH_BUILD_PYBIND=ON - export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" + export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" as_jenkins .ci/scripts/setup-linux.sh --build-tool cmake || true popd diff --git a/.ci/docker/common/install_halide.sh b/.ci/docker/common/install_halide.sh index 0cfcfbce107b..ed1d7d33649d 100644 --- a/.ci/docker/common/install_halide.sh +++ b/.ci/docker/common/install_halide.sh @@ -35,7 +35,9 @@ git clone https://github.com/halide/Halide.git pushd Halide git checkout ${COMMIT} && git submodule update --init --recursive pip_install -r requirements.txt -cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -S . -B build +# NOTE: pybind has a requirement for cmake > 3.5 so set the minimum cmake version here with a flag +# Context: https://github.com/pytorch/pytorch/issues/150420 +cmake -G Ninja -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_BUILD_TYPE=Release -S . -B build cmake --build build test -e ${CONDA_PREFIX}/lib/python3 || ln -s python${ANACONDA_PYTHON_VERSION} ${CONDA_PREFIX}/lib/python3 cmake --install build --prefix ${CONDA_PREFIX} diff --git a/.ci/docker/common/install_nccl.sh b/.ci/docker/common/install_nccl.sh new file mode 100644 index 000000000000..17d80ebe7d27 --- /dev/null +++ b/.ci/docker/common/install_nccl.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -ex + +NCCL_VERSION="" +if [[ ${CUDA_VERSION:0:2} == "11" ]]; then + NCCL_VERSION=$(cat ci_commit_pins/nccl-cu11.txt) +elif [[ ${CUDA_VERSION:0:2} == "12" ]]; then + NCCL_VERSION=$(cat ci_commit_pins/nccl-cu12.txt) +else + echo "Unexpected CUDA_VERSION ${CUDA_VERSION}" + exit 1 +fi + +if [[ -n "${NCCL_VERSION}" ]]; then + # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses + # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build + git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git + pushd nccl + make -j src.build + cp -a build/include/* /usr/local/cuda/include/ + cp -a build/lib/* /usr/local/cuda/lib64/ + popd + rm -rf nccl + ldconfig +fi diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile index f9ae32ad7f8e..e90306767ff6 100644 --- a/.ci/docker/libtorch/Dockerfile +++ b/.ci/docker/libtorch/Dockerfile @@ -49,6 +49,8 @@ RUN bash ./install_mkl.sh && rm install_mkl.sh FROM cpu as cuda ADD ./common/install_cuda.sh install_cuda.sh ADD ./common/install_magma.sh install_magma.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ ENV CUDA_HOME /usr/local/cuda FROM cuda as cuda11.8 diff --git a/.ci/docker/linter-cuda/Dockerfile b/.ci/docker/linter-cuda/Dockerfile index d93f69a149f2..ed8fc7eabba5 100644 --- a/.ci/docker/linter-cuda/Dockerfile +++ b/.ci/docker/linter-cuda/Dockerfile @@ -30,7 +30,9 @@ RUN bash ./install_python.sh && rm install_python.sh /opt/requirements-ci.txt # Install cuda and cudnn ARG CUDA_VERSION COPY ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* ENV DESIRED_CUDA ${CUDA_VERSION} ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/manywheel/Dockerfile b/.ci/docker/manywheel/Dockerfile index d7daf989b496..75f2ab9a5ce0 100644 --- a/.ci/docker/manywheel/Dockerfile +++ b/.ci/docker/manywheel/Dockerfile @@ -64,7 +64,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION=10.2 # Install CUDA ADD ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* FROM base as intel # MKL diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index e3ac65f5ca21..fbf74fb81c01 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -36,7 +36,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION=11.8 # Install CUDA ADD ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh install_nccl.sh ci_commit_pins/nccl-cu* FROM base as intel # MKL diff --git a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 index dfd766b4dd5a..fe2a04fd92db 100644 --- a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 @@ -67,7 +67,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION # Install CUDA ADD ./common/install_cuda_aarch64.sh install_cuda_aarch64.sh -RUN bash ./install_cuda_aarch64.sh ${BASE_CUDA_VERSION} && rm install_cuda_aarch64.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda_aarch64.sh ${BASE_CUDA_VERSION} && rm install_cuda_aarch64.sh install_nccl.sh ci_commit_pins/nccl-cu* FROM base as magma ARG BASE_CUDA_VERSION diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index c9579950e0ac..4739271899c3 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -158,6 +158,16 @@ COPY ./common/install_cusparselt.sh install_cusparselt.sh RUN bash install_cusparselt.sh RUN rm install_cusparselt.sh +# Install NCCL +ARG CUDA_VERSION +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash install_nccl.sh +RUN rm install_nccl.sh /ci_commit_pins/nccl-cu* +ENV USE_SYSTEM_NCCL=1 +ENV NCCL_INCLUDE_DIR="/usr/local/cuda/include/" +ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" + # Install CUDSS ARG CUDA_VERSION COPY ./common/install_cudss.sh install_cudss.sh diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 11888f37bff2..c33abda4aaa7 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -52,9 +52,16 @@ RUN bash ./install_lcov.sh && rm install_lcov.sh # Install cuda and cudnn ARG CUDA_VERSION COPY ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* ENV DESIRED_CUDA ${CUDA_VERSION} ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH +# No effect if cuda not installed +ENV USE_SYSTEM_NCCL=1 +ENV NCCL_INCLUDE_DIR="/usr/local/cuda/include/" +ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" + # (optional) Install UCC ARG UCX_COMMIT diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index 4a2f63a2ed10..d538581c09a6 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -33,55 +33,11 @@ if which sccache > /dev/null; then export PATH="${tmp_dir}:$PATH" fi -cross_compile_arm64() { - # Cross compilation for arm64 - # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests - # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 - USE_DISTRIBUTED=0 CMAKE_OSX_ARCHITECTURES=arm64 MACOSX_DEPLOYMENT_TARGET=11.0 USE_MKLDNN=OFF USE_QNNPACK=OFF WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel -} - -compile_arm64() { - # Compilation for arm64 - # TODO: Compile with OpenMP support (but this causes CI regressions as cross-compilation were done with OpenMP disabled) - USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel -} - -compile_x86_64() { - USE_DISTRIBUTED=0 WERROR=1 python setup.py bdist_wheel --plat-name=macosx_10_9_x86_64 -} - -build_lite_interpreter() { - echo "Testing libtorch (lite interpreter)." - - CPP_BUILD="$(pwd)/../cpp_build" - # Ensure the removal of the tmp directory - trap 'rm -rfv ${CPP_BUILD}' EXIT - rm -rf "${CPP_BUILD}" - mkdir -p "${CPP_BUILD}/caffe2" - - # It looks libtorch need to be built in "${CPP_BUILD}/caffe2 folder. - BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py - pushd "${CPP_BUILD}/caffe2" || exit - VERBOSE=1 DEBUG=1 python "${BUILD_LIBTORCH_PY}" - popd || exit - - "${CPP_BUILD}/caffe2/build/bin/test_lite_interpreter_runtime" -} - print_cmake_info -if [[ ${BUILD_ENVIRONMENT} = *arm64* ]]; then - if [[ $(uname -m) == "arm64" ]]; then - compile_arm64 - else - cross_compile_arm64 - fi -elif [[ ${BUILD_ENVIRONMENT} = *lite-interpreter* ]]; then - export BUILD_LITE_INTERPRETER=1 - build_lite_interpreter -else - compile_x86_64 -fi +# Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests +# that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 +USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel if which sccache > /dev/null; then print_sccache_stats diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index c4f41a874774..24d1d64dd205 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -227,7 +227,10 @@ def compare_pypi_to_torch_versions( def smoke_test_cuda( - package: str, runtime_error_check: str, torch_compile_check: str + package: str, + runtime_error_check: str, + torch_compile_check: str, + pypi_pkg_check: str, ) -> None: if not torch.cuda.is_available() and is_cuda_system: raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.") @@ -269,12 +272,15 @@ def smoke_test_cuda( torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") - # Pypi dependencies are installed on linux ony and nccl is availbale only on Linux. if sys.platform in ["linux", "linux2"]: + torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) + print(f"Torch nccl; version: {torch_nccl_version}") + + # Pypi dependencies are installed on linux ony and nccl is availbale only on Linux. + if pypi_pkg_check == "enabled" and sys.platform in ["linux", "linux2"]: compare_pypi_to_torch_versions( "cudnn", find_pypi_package_version("nvidia-cudnn"), torch_cudnn_version ) - torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) compare_pypi_to_torch_versions( "nccl", find_pypi_package_version("nvidia-nccl"), torch_nccl_version ) @@ -436,6 +442,13 @@ def parse_args(): choices=["enabled", "disabled"], default="enabled", ) + parser.add_argument( + "--pypi-pkg-check", + help="Check pypi package versions cudnn and nccl", + type=str, + choices=["enabled", "disabled"], + default="enabled", + ) return parser.parse_args() @@ -460,7 +473,10 @@ def main() -> None: smoke_test_modules() smoke_test_cuda( - options.package, options.runtime_error_check, options.torch_compile_check + options.package, + options.runtime_error_check, + options.torch_compile_check, + options.pypi_pkg_check, ) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 96a160cf618d..1e6b50f04f26 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1175,7 +1175,6 @@ build_xla() { # These functions are defined in .circleci/common.sh in pytorch/xla repo retry install_pre_deps_pytorch_xla $XLA_DIR $USE_CACHE CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch:${CMAKE_PREFIX_PATH}" XLA_SANDBOX_BUILD=1 build_torch_xla $XLA_DIR - retry install_post_deps_pytorch_xla assert_git_not_dirty } @@ -1475,8 +1474,7 @@ test_executorch() { pushd /executorch export PYTHON_EXECUTABLE=python - export EXECUTORCH_BUILD_PYBIND=ON - export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" + export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" # For llama3 bash examples/models/llama3_2_vision/install_requirements.sh diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 3ee84f46d8fa..051b4f16f27a 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -90,8 +90,17 @@ fi /pytorch/.ci/pytorch/check_binary.sh if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_TYPE" != *rocm* && "$PACKAGE_TYPE" != libtorch ]]; then - # Exclude s390, xpu, rocm and libtorch builds from smoke testing - python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled + + torch_pkg_size="$(ls -1 /final_pkgs/torch-* | sort |tail -1 |xargs wc -c |cut -d ' ' -f1)" + # todo: implement check for large binaries + # if the package is larger than 1.5GB, we disable the pypi check. + # this package contains all libraries packaged in torch libs folder + # example of such package is https://download.pytorch.org/whl/cu126_full/torch + if [[ "\$torch_pkg_size" -gt 1500000000 ]]; then + python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled --pypi-pkg-check disabled + else + python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled $extra_parameters + fi fi # Clean temp files diff --git a/.clang-tidy b/.clang-tidy index df40a6df91c0..4b1548d646b2 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -52,7 +52,6 @@ modernize-*, -modernize-macro-to-enum, -modernize-return-braced-init-list, -modernize-use-auto, --modernize-use-default-member-init, -modernize-use-using, -modernize-use-trailing-return-type, -modernize-use-nodiscard, diff --git a/.github/ISSUE_TEMPLATE/pt2-bug-report.yml b/.github/ISSUE_TEMPLATE/pt2-bug-report.yml index be22b1446b4e..2f8ab54a2337 100644 --- a/.github/ISSUE_TEMPLATE/pt2-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/pt2-bug-report.yml @@ -20,7 +20,7 @@ body: - Don't compare indices of max/min etc, because that avoids the above requirement - - If comparing eager and torch.compile at fp16/bf16, you should use fp32 as baseline + - When comparing eager and torch.compile, use a higher precision result as a baseline. `torch._dynamo.utils.same` with fp64_ref will handle this comparison. - Ensure rng state used to compare results is equivalent. Use `torch._inductor.config.fallback_random=True` and reset the torch rng seed between comparisons diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index c33e09d37efc..1c44ba1f888a 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -3,9 +3,6 @@ self-hosted-runner: # GitHub hosted runner that actionlint doesn't recognize because actionlint version (1.6.21) is too old - ubuntu-24.04 # GitHub hosted x86 Linux runners - # TODO: Cleanup mentions of linux.20_04 when upgrade to linux.24_04 is complete - - linux.20_04.4x - - linux.20_04.16x - linux.24_04.4x - linux.24_04.16x # Organization-wide AWS Linux Runners diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 5effc5f3689a..fe949516402d 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -48,14 +48,8 @@ runs: run: | # Remove any previous usage logs if they exist rm -f logs-*.zip - # this workflow is also run in bazel build test, but we dont generate usage reports for it - # so check to see if the file exists first - if [ -f 'usage_log.txt' ]; then - zip "logs-${FILE_SUFFIX}.zip" 'usage_log.txt' - fi - if find "test/test-reports" -name "*.log" 2>/dev/null | stdbuf -o0 grep -q .; then - zip -r "logs-${FILE_SUFFIX}.zip" test/test-reports -i '*.log' - fi + zip "logs-${FILE_SUFFIX}.zip" 'usage_log.txt' || true + zip -r "logs-${FILE_SUFFIX}.zip" test/test-reports -i '*.log' || true - name: Zip debugging artifacts for upload if: runner.os != 'Windows' && !inputs.use-gha diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 71fe6e9fb351..d585cc27cdab 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -318bace01aebc1f82ae13d0d133fcf9fede73383 +bccaa454a54c3c648697cc2f46a4fb0500b1f01b diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 2925b494d999..96bf43f4c0e2 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -760675ad9aa8e7202d4f9f51fe862e8a9bedb713 +ac9a39f4b768cef09b9d2be8e074be496d7783b6 diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 0a091ecadbe5..bae188d2a335 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -540,6 +540,7 @@ - bdhirsh - zou3519 - isuruf + - Chillee mandatory_checks_name: - EasyCLA - Lint diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 306061787d58..4f29628373e4 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -294,6 +294,7 @@ class OperatingSystem: BinaryBuildWorkflow( os=OperatingSystem.WINDOWS_ARM64, package_type="libtorch", + build_variant=generate_binary_build_matrix.DEBUG, build_configs=generate_binary_build_matrix.generate_libtorch_matrix( OperatingSystem.WINDOWS_ARM64, generate_binary_build_matrix.DEBUG, diff --git a/.github/scripts/windows/build_magma.bat b/.github/scripts/windows/build_magma.bat index beabb0070554..b8701ddde3fc 100644 --- a/.github/scripts/windows/build_magma.bat +++ b/.github/scripts/windows/build_magma.bat @@ -54,7 +54,8 @@ cmake .. -DGPU_TARGET="%GPU_TARGET%" ^ -DCMAKE_BUILD_TYPE=%CONFIG% ^ -DCMAKE_GENERATOR=Ninja ^ -DCMAKE_INSTALL_PREFIX=..\install\ ^ - -DCUDA_ARCH_LIST="%CUDA_ARCH_LIST%" + -DCUDA_ARCH_LIST="%CUDA_ARCH_LIST%" ^ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 if errorlevel 1 exit /b 1 cmake --build . --target install --config %CONFIG% -- -j%NUMBER_OF_PROCESSORS% diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index 245740c66cdb..97cd535a4988 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -9,7 +9,8 @@ if "%PY_VERS%" == "3.13t" ( ) else ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake setuptools==72.1.0 ninja +:: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja dir "%VC_INSTALL_PATH%" diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index 57a66798468f..507d5419a042 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -23,7 +23,7 @@ on: description: Hardware to run this "build" job on, linux.12xlarge or linux.arm64.2xlarge. timeout-minutes: required: false - default: 210 + default: 240 type: number description: timeout for the job use_split_build: diff --git a/.github/workflows/build-magma-windows.yml b/.github/workflows/build-magma-windows.yml index 85f2884e5351..4a3fb9855a06 100644 --- a/.github/workflows/build-magma-windows.yml +++ b/.github/workflows/build-magma-windows.yml @@ -22,7 +22,7 @@ jobs: runs-on: windows-2019 strategy: matrix: - cuda_version: ["128", "126", "124", "118"] + cuda_version: ["128", "126", "118"] config: ["Release", "Debug"] env: CUDA_VERSION: ${{ matrix.cuda_version }} diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 9921b018fcc3..99d71c7082b7 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -12,6 +12,8 @@ on: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt + - .github/scripts/windows/install_vs2022.ps1 + - .github/scripts/windows/build_triton.bat - .ci/docker/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton-xpu.txt workflow_dispatch: @@ -20,6 +22,8 @@ on: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt + - .github/scripts/windows/install_vs2022.ps1 + - .github/scripts/windows/build_triton.bat - .ci/docker/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton-xpu.txt @@ -134,7 +138,7 @@ jobs: fi docker exec -t "${container_name}" yum install -y zlib-devel zip - docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==67.4.0 pybind11==2.13.1 auditwheel wheel + docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "rocm" || "${{ matrix.device }}" == "aarch64" ) ]]; then # With this install, it gets clang 16.0.6. @@ -244,7 +248,6 @@ jobs: .github/scripts/windows/build_triton.bat mkdir -p "${RUNNER_TEMP}/artifacts/" mv ./*.whl "${RUNNER_TEMP}/artifacts/" - - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }} diff --git a/.github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml similarity index 98% rename from .github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml rename to .github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml index a70e26c114cc..42e1e18d5dc7 100644 --- a/.github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml @@ -2,7 +2,7 @@ # Template is at: .github/templates/windows_arm64_binary_build_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: windows-arm64-binary-libtorch +name: windows-arm64-binary-libtorch-debug on: push: @@ -17,7 +17,7 @@ on: workflow_dispatch: env: - BUILD_ENVIRONMENT: windows-arm64-binary-libtorch + BUILD_ENVIRONMENT: windows-arm64-binary-libtorch-debug GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 076e67b08ed2..55a37f031fdf 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -4,6 +4,9 @@ on: pull_request: paths: - .github/workflows/inductor-nightly.yml + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_timm_inference.csv + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv workflow_dispatch: schedule: # Run every day at 7:00 AM UTC diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index df84b158f1c0..f1ff593161db 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -78,12 +78,12 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-benchmark-build: + linux-focal-rocm-py3_10-inductor-benchmark-build: if: github.repository_owner == 'pytorch' - name: rocm6_3-py3_10-inductor-benchmark-build + name: rocm-py3_10-inductor-benchmark-build uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -102,18 +102,18 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-benchmark-test: + linux-focal-rocm-py3_10-inductor-benchmark-test: permissions: id-token: write contents: read - name: rocm6_3-py3_10-inductor-benchmark-test + name: rocm-py3_10-inductor-benchmark-test uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-benchmark-build + needs: linux-focal-rocm-py3_10-inductor-benchmark-build with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-benchmark-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-benchmark-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }} timeout-minutes: 720 # Disable monitor in perf tests for more investigation disable-monitor: true diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 3d7a1c7da941..6d08179df512 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -67,12 +67,12 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build: + linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build: if: github.repository_owner == 'pytorch' - name: rocm6_3-py3_10-periodic-dynamo-benchmarks + name: rocm-py3_10-periodic-dynamo-benchmarks uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -95,17 +95,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-test: + linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-test: permissions: id-token: write contents: read - name: rocm6_3-py3_10-periodic-dynamo-benchmarks + name: rocm-py3_10-periodic-dynamo-benchmarks uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build + needs: linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build with: - build-environment: linux-focal-rocm6_3-py3_10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3_10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp: diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index bddb625cbc90..753c30e6427a 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -36,13 +36,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-build: - name: rocm6.3-py3.10-inductor + linux-focal-rocm-py3_10-inductor-build: + name: rocm-py3.10-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -51,15 +51,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-test: + linux-focal-rocm-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.3-py3.10-inductor + name: rocm-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-build + needs: linux-focal-rocm-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index bcbbe0dd85bc..0d21b4570c2e 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -29,13 +29,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-build: - name: rocm6.3-py3.10-inductor + linux-focal-rocm-py3_10-inductor-build: + name: rocm-py3.10-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -44,15 +44,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-test: + linux-focal-rocm-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.3-py3.10-inductor + name: rocm-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-build + needs: linux-focal-rocm-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9d72b0c5bbb6..db00515a79b6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -233,10 +233,8 @@ jobs: runner: linux.24_04.4x - test_type: without_torch runner: linux.24_04.4x - # NOTE: The oldest supported version of python for 24.04 is 3.8 - # so this cannot be updated if we want to keep this test at 3.6 - test_type: older_python_version - runner: linux.20_04.4x + runner: linux.24_04.4x steps: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required, to allow us to use git log @@ -256,7 +254,7 @@ jobs: if: matrix.test_type == 'older_python_version' uses: actions/setup-python@v5 with: - python-version: 3.6 + python-version: 3.8 architecture: x64 check-latest: false cache: pip diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 7da1b438c7e9..805a7d328575 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -11,6 +11,9 @@ on: type: string default: 'short' description: tag filter for operator benchmarks, options from long, short, all + schedule: + # Run at 07:00 UTC every Sunday + - cron: 0 7 * * 0 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 8aadcd548b7e..686f5b83a92e 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -140,13 +140,13 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-build: - name: linux-focal-rocm6.3-py3.10 + linux-focal-rocm-py3_10-build: + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -156,19 +156,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index bb967e2f3e82..e4ee18664ba5 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -411,15 +411,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: # don't run build twice on main if: github.event_name == 'pull_request' - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index f1a16ddea234..cce7ff72cdc4 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -36,14 +36,14 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -57,17 +57,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 6ff8667a9d94..063daaf4fe67 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -26,12 +26,12 @@ jobs: id-token: write contents: read - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -45,17 +45,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 1d1b8d5eb567..0a8cf3721e70 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -103,13 +103,13 @@ jobs: test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-build: - name: linux-focal-rocm6.3-py3.10 + linux-focal-rocm-py3_10-build: + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -118,19 +118,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit linux-jammy-py3_10-clang15-asan-build: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index ec98f4faf3c6..15d2b53ed81b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -165,14 +165,14 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" secrets: inherit - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -183,20 +183,20 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" secrets: inherit diff --git a/.lintrunner.toml b/.lintrunner.toml index e7541e6dabe5..ec62529d1f49 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -55,6 +55,7 @@ init_command = [ code = 'CLANGFORMAT' include_patterns = [ 'aten/src/ATen/*.h', + 'aten/src/ATen/cpu/vec/*.h', 'aten/src/ATen/mps/**/*.mm', 'aten/src/ATen/mps/**/*.h', 'aten/src/ATen/xpu/**/*.h', @@ -271,6 +272,7 @@ exclude_patterns = [ 'torch/csrc/utils/generated_serialization_types.h', 'torch/csrc/utils/pythoncapi_compat.h', 'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h', + 'aten/src/ATen/ExpandBase.h', ] init_command = [ 'python3', diff --git a/README.md b/README.md index 299cca8e34fd..5085abc87b7d 100644 --- a/README.md +++ b/README.md @@ -169,8 +169,6 @@ Professional, or Community Editions. You can also install the build tools from https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* come with Visual Studio Code by default. -\* We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. - An example of environment setup is shown below: * Linux: @@ -223,7 +221,7 @@ Other potentially useful environment variables may be found in `setup.py`. #### Get the PyTorch Source ```bash -git clone --recursive https://github.com/pytorch/pytorch +git clone https://github.com/pytorch/pytorch cd pytorch # if you are updating an existing checkout git submodule sync @@ -355,6 +353,16 @@ Please make sure [the common prerequisites](#prerequisites) as well as [the prer Then PyTorch can be built with the command: ```cmd +:: CMD Commands: +:: Set the CMAKE_PREFIX_PATH to help find corresponding packages +:: %CONDA_PREFIX% only works after `conda activate custom_env` + +if defined CMAKE_PREFIX_PATH ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library;%CMAKE_PREFIX_PATH%" +) else ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library" +) + python setup.py develop ``` diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index e10fdb7e88ee..d939b7b7b084 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -384,12 +384,11 @@ endif() ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) - if(WIN32) # Windows doesn't support Composable Kernels and Triton + if(WIN32) # Windows doesn't support Composable Kernels file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" - ${native_hip_bgemm} ${native_hip_ck} - ${native_transformers_hip_hip} ${native_transformers_hip_cpp}) + ${native_hip_bgemm} ${native_hip_ck}) endif() # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) list(APPEND all_hip_cpp @@ -408,9 +407,6 @@ endif() ${miopen_cpp} ${all_hip_cpp} ) - if(WIN32) # Windows doesn't support Triton - exclude(all_hip_cpp "${all_hip_cpp}" ${native_transformers_hip_cpp}) - endif() endif() if(USE_XPU) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 2b6cbfa6e7bf..b5ce540b52ab 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -340,7 +340,7 @@ at::BlasBackend Context::blasPreferredBackend() { #endif }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { return false; } } @@ -359,14 +359,14 @@ at::BlasBackend Context::blasPreferredBackend() { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201" + "gfx1100", "gfx1101", "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" #endif }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { TORCH_WARN_ONCE( "Attempting to use hipBLASLt on an unsupported architecture! " "Overriding blas backend to hipblas"); @@ -419,7 +419,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { "gfx90a", "gfx942" }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { TORCH_WARN_ONCE( "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); return true; diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h index d187d7b7c116..96cd25fec10b 100644 --- a/aten/src/ATen/core/Dict.h +++ b/aten/src/ATen/core/Dict.h @@ -116,10 +116,7 @@ class DictIterator final { DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} - DictIterator& operator=(const DictIterator& rhs) { - entryRef_ = rhs.entryRef_; - return *this; - } + DictIterator& operator=(const DictIterator& rhs) = default; DictIterator& operator=(DictIterator&& rhs) noexcept { entryRef_ = std::move(rhs.entryRef_); return *this; diff --git a/aten/src/ATen/core/alias_info.h b/aten/src/ATen/core/alias_info.h index a8a55bb782c4..bf0ff6ee72d3 100644 --- a/aten/src/ATen/core/alias_info.h +++ b/aten/src/ATen/core/alias_info.h @@ -1,4 +1,6 @@ #pragma once +#include +#include #include #include #include @@ -18,6 +20,15 @@ namespace c10 { */ class AliasInfo { public: + AliasInfo() = default; + AliasInfo(bool is_write, const std::set& before_qual_strings, const std::set& after_qual_strings) : isWrite_(is_write) { + for (const auto& s: before_qual_strings) { + beforeSets_.insert(Symbol::fromQualString(s)); + } + for (const auto& s : after_qual_strings) { + afterSets_.insert(Symbol::fromQualString(s)); + } + } // Symbol for the set that can alias anything static Symbol wildcardSet() { static const Symbol wc = Symbol::fromQualString("alias::*"); diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 27438b926db5..61a3c1801294 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -225,8 +225,7 @@ struct TORCH_API DispatchKeyExtractor final { explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse), - nonFallthroughKeys_(DispatchKeySet::FULL), - requiresBitsetPerBackend_(false) { + nonFallthroughKeys_(DispatchKeySet::FULL) { for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; } @@ -252,7 +251,7 @@ struct TORCH_API DispatchKeyExtractor final { // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast // path), or if we need to fall back to the slower path and check // nonFallthroughKeysPerBackend_ - bool requiresBitsetPerBackend_; + bool requiresBitsetPerBackend_{false}; }; } // namespace c10 diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp index b8a5b418bbc0..5dcac2b0e2fb 100644 --- a/aten/src/ATen/core/library.cpp +++ b/aten/src/ATen/core/library.cpp @@ -58,6 +58,18 @@ void Library::reset() { #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")" +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) +namespace detail { + std::vector torch_library_initializers; +} // namespace detail +void initialize_torch_libraries() { + for (auto* initializer : detail::torch_library_initializers) { + initializer->initialize(); + } + detail::torch_library_initializers.clear(); +} +#endif + Library::Library(Kind kind, std::string ns, std::optional k, const char* file, uint32_t line) : kind_(kind) , ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns))) diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 4d1d05ea8d32..e7429d18712d 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all( template struct VecReduceAllSIMD { - static inline scalar_t apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); } }; -#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) #if defined(CPU_CAPABILITY_AVX2) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -57,7 +62,9 @@ struct VecReduceAllSIMD { #if defined(CPU_CAPABILITY_AVX512) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 256-bit shuffle @@ -76,25 +83,33 @@ struct VecReduceAllSIMD { } }; #endif // defined(CPU_CAPABILITY_AVX512) -#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; - // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -] + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] float32x4_t v1_1 = vextq_f32(v, v, 2); Vec v1 = v1_1; // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] v = vec_fun(v, v1); - // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -] + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] v1_1 = vrev64q_f32(v); v1 = v1_1; - // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] v = vec_fun(v, v1); return v[0]; @@ -102,10 +117,13 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -125,15 +143,21 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) - template -inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized& acc_vec) { +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { return VecReduceAllSIMD::apply(vec_fun, acc_vec); } -template , int> = 0> -inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); @@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size } // similar to reduce_all, but reduces into two outputs -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) { auto loaded_data = Vec::loadu(data, size); return std::pair( - vec_reduce_all(vec_fun1, loaded_data, size), - vec_reduce_all(vec_fun2, loaded_data, size)); + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); } int64_t d = Vec::size(); Vec acc_vec1 = Vec::loadu(data); @@ -176,12 +206,14 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); } return std::pair( - vec_reduce_all(vec_fun1, acc_vec1), - vec_reduce_all(vec_fun2, acc_vec2)); + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -205,8 +237,11 @@ inline scalar_t map_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -293,8 +333,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -317,8 +359,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -344,8 +388,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, diff --git a/aten/src/ATen/cpu/vec/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h index 3bd22b3820f0..d4a40acaeefd 100644 --- a/aten/src/ATen/cpu/vec/functional_bfloat16.h +++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h @@ -8,86 +8,120 @@ namespace at::vec { // BFloat16 specification -template struct VecScalarType { using type = scalar_t; }; -template <> struct VecScalarType { using type = float; }; -template <> struct VecScalarType { using type = float; }; +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; // This is different from at::acc_type since we only need to specialize BFloat16 template using vec_scalar_t = typename VecScalarType::type; // Vector conversion between float and bfloat16/half -template , int> = 0> -inline std::tuple, Vectorized> convert_to_float(const Vectorized&); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { return convert_bfloat16_float(a); } template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { - return convert_half_float(a); +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); } -template , int> = 0> -inline Vectorized convert_from_float(const Vectorized&, const Vectorized&); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_bfloat16(a, b); } template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_half(a, b); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized &out1, Vectorized &out2); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_bf16(data, out1, out2); } template <> -inline void load_to_float (const Half *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_fp16(data, out1, out2); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized &out); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out) { load_fp32_from_bf16(data, out); } template <> -inline void load_to_float (const Half *data, Vectorized &out) { +inline void load_to_float(const Half* data, Vectorized& out) { load_fp32_from_fp16(data, out); } -// Note that we already have specialized member of Vectorized for BFloat16 -// so the following functions would run smoothly: +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: // using Vec = Vectorized; // Vec one = Vec(BFloat16(1)); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // // Then why we still need to specialize "functional"? -// If we do specialization at Vectorized<> level, the above example would need 3 pairs of -// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". -// If we do specialization at vec::map<>() level, we have only 1 pair of conversion -// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. // -// The following BFloat16 functionality will only do data type conversion for input -// and output vector (reduce functionality will only convert the final scalar back to bf16). -// Compared to Vectorized<> specialization, +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, // 1. better performance since we have less data type conversion; // 2. less rounding error since immediate results are kept in fp32; // 3. accumulation done on data type of fp32. @@ -95,8 +129,10 @@ inline void load_to_float (const Half *data, Vectorized &out) { // If you plan to extend this file, please ensure adding unit tests at // aten/src/ATen/test/vec_test_all_types.cpp // -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; @@ -104,7 +140,8 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); } else { return vec_reduce_all(vec_fun, data_fvec0, size); @@ -124,27 +161,37 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { - acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(vec_fun, acc_fvec0); } -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; if (size < bVec::size()) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); - fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); return std::pair( vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); @@ -171,12 +218,20 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); - acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size()); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); - acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size()); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, data_fvec1), + size - d - fVec::size()); } else { - acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); - acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); } } acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); @@ -186,8 +241,11 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f vec_reduce_all(vec_fun2, acc2_fvec0)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -201,7 +259,8 @@ inline float map_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); @@ -228,18 +287,23 @@ inline float map_reduce_all( data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -256,7 +320,8 @@ inline float map2_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); @@ -289,18 +354,23 @@ inline float map2_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -320,7 +390,8 @@ inline float map3_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); @@ -359,18 +430,22 @@ inline float map3_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -397,8 +472,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -419,7 +496,8 @@ inline void map( fVec data_fvec0, data_fvec1; if (size - d > fVec::size()) { data_fvec0 = fVec::loadu(input_data + d); - data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); } else { // choose to align with behaviour of bVec::loadu(ptr, size), // which leaves data_fvec1 uninitialized @@ -432,8 +510,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -465,8 +545,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -503,8 +585,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, @@ -525,8 +609,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d); } @@ -539,8 +625,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d, size - d); } diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index 48b18793b079..f9086f7d3d0b 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -13,10 +13,14 @@ /* Microsoft C/C++-compatible compiler */ #include #if _MSC_VER <= 1900 -#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) -#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) -#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) -#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #endif #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ @@ -25,9 +29,9 @@ /* GCC-compatible compiler, targeting ARM with SVE */ #include #endif -#if defined (MISSING_ARM_VLD1) +#if defined(MISSING_ARM_VLD1) #include -#elif defined (MISSING_ARM_VST1) +#elif defined(MISSING_ARM_VST1) #include #endif #elif defined(__GNUC__) && defined(__IWMMXT__) @@ -36,8 +40,8 @@ #elif defined(__s390x__) // targets Z/architecture // we will include vecintrin later -#elif (defined(__GNUC__) || defined(__xlC__)) && \ - (defined(__VEC__) || defined(__ALTIVEC__)) +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ #include /* We need to undef those tokens defined by to avoid conflicts diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h index 6a3dc2bc1c10..dd35787dfb5b 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_float.h +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -85,6 +85,58 @@ template <> class Vectorized { } return b; } + //Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105 + inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const { + const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f + const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f + const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f + const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f + const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f + const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + svreinterpret_f32_u32(svdup_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = svmla_f32_z(pg, shift, x, inv_ln2); + const auto n = svsub_f32_z(pg, z, shift); + const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = svmul_f32_z(pg, r, r); + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + auto poly = svmla_f32_z(pg, scale, p12345, scale); + // Handle underflow and overflow. + poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); + return poly; + } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return svld1_f32(ptrue, reinterpret_cast(ptr)); @@ -333,8 +385,34 @@ template <> class Vectorized { Vectorized tan() const { return USE_SLEEF(Vectorized(Sleef_tanfx_u10sve(values)),map(std::tan)); } + //Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179 Vectorized tanh() const { - return USE_SLEEF(Vectorized(Sleef_tanhfx_u10sve(values)),map(std::tanh)); + // Constants used for the tanh calculation. + const svfloat32_t CONST_1 = svdup_n_f32(1.f); // Constant 1.0f for the tanh formula. + const svfloat32_t CONST_2 = svdup_n_f32(2.f); // Constant 2.0f for the tanh formula (used in exp(2x)). + const svfloat32_t CONST_MIN_TANH = svdup_n_f32(-10.f); // Minimum threshold for input values to prevent overflow. + const svfloat32_t CONST_MAX_TANH = svdup_n_f32(10.f); // Maximum threshold for input values to prevent overflow. + + // Step 1: Clamp the values within the range [-10, 10] to prevent overflow during exponentiation. + // The tanh function approaches ±1 rapidly as the input grows large, so we limit the input range to avoid numerical instability. + // svmax_f32_z ensures values are greater than -10, and svmin_f32_z ensures they are less than 10. + svfloat32_t x = svmin_f32_z(ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); + + // Step 2: Calculate exp(2 * x), where x is the clamped value. + // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of the result. + svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x)); + + // Step 3: Calculate the numerator of the tanh function, which is exp(2x) - 1. + svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1); + + // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + 1. + svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1); + + // Step 5: Calculate the tanh function as the ratio of the numerator and denominator: num / den. + svfloat32_t tanh = svdiv_f32_z(ptrue, num, den); + + // Return the calculated tanh values. + return tanh; } Vectorized trunc() const { return svrintz_f32_x(ptrue, values); diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index e4b0c4b95d84..0bfe65cd1959 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -28,21 +28,30 @@ inline Vectorized Vectorized::loadu(const void* ptr) { } template <> -inline Vectorized Vectorized::loadu(const void* ptr, int64_t count) { +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { // See NOTE [Loading boolean values] return convert_to_bool(Vectorized::loadu(ptr, count)); } template -struct VecHoldType { using hold_type = typename VT::value_type; }; +struct VecHoldType { + using hold_type = typename VT::value_type; +}; template <> -struct VecHoldType> { using hold_type = BFloat16; }; +struct VecHoldType> { + using hold_type = BFloat16; +}; template <> -struct VecHoldType> {using hold_type = Half; }; +struct VecHoldType> { + using hold_type = Half; +}; template using vechold_type = typename VecHoldType::hold_type; -}} // namespace at::vec::CPU_CAPABILITY +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 2591338881ae..0f24ccf385df 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -1,5 +1,6 @@ #pragma once -#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && defined(__ARM_FEATURE_SVE) +#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 #pragma GCC optimize("no-tree-vectorize") #endif @@ -18,27 +19,27 @@ // See https://github.com/pytorch/pytorch/issues/37577 for an instance // of this bug in the past. -#include #include +#include #include +#include +#include #include #include -#include #include -#include +#include #include #include -#include -#include -#include -#include -#include #include -#include #include -#include +#include +#include +#include #include +#include +#include +#include #if defined(__GNUC__) #define __FORCE_INLINE __attribute__((always_inline)) inline @@ -66,7 +67,8 @@ Windows llvm will not have this definition. #endif #define VECTOR_WIDTH 64 #define int_vector __m512i -#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 // SVE code expects 256-vectors; leave that set for SVE? #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(16))) @@ -93,40 +95,43 @@ namespace at::vec { inline namespace CPU_CAPABILITY { // at::Half and at::BFloat16 should be treated as floating point template -struct is_floating_point: - std::integral_constant || - std::is_same_v || - std::is_same_v> { -}; +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; -template +template constexpr bool is_floating_point_v = is_floating_point::value; template -struct is_reduced_floating_point: - std::integral_constant || - std::is_same_v> { -}; +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; template -constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; template -struct is_8bit_integer: - std::integral_constant || - std::is_same_v> { +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { }; template constexpr bool is_8bit_integer_v = is_8bit_integer::value; -template struct int_of_size; +template +struct int_of_size; -#define DEFINE_INT_OF_SIZE(int_t) \ -template<> struct int_of_size { using type = int_t; } +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } DEFINE_INT_OF_SIZE(int64_t); DEFINE_INT_OF_SIZE(int32_t); @@ -142,14 +147,15 @@ using int_same_size_t = typename int_of_size::type; // emulates Vectorized types #if defined(__s390x__) -template +template #else template #endif struct Vectorized { -private: + private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; -public: + + public: using value_type = T; using size_type = int; @@ -163,11 +169,11 @@ struct Vectorized { values[i] = val; } } - template> - Vectorized(Args... vals) : values{vals...}{ - } - Vectorized(const T(&arr)[kSize]) { + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { std::memcpy(values, arr, sizeof(values)); } // This also implies const T& operator[](int idx) const @@ -198,20 +204,23 @@ struct Vectorized { } // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 #if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) - static Vectorized __attribute__ ((optimize("-fno-tree-loop-vectorize"))) blendv(const Vectorized& a, + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, #else - static Vectorized blendv(const Vectorized& a, + static Vectorized blendv( + const Vectorized& a, #endif - const Vectorized& b, const Vectorized& mask) { + const Vectorized& b, + const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); #if defined(__clang__) && __ARM_FEATURE_SVE - #pragma clang loop vectorize(disable) +#pragma clang loop vectorize(disable) #endif for (const auto i : c10::irange(size())) { - if (buffer[i] & 0x01) - { + if (buffer[i] & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; @@ -219,15 +228,21 @@ struct Vectorized { } return vector; } - template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) - static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { Vectorized vector; for (const auto i : c10::irange(size())) { vector.values[i] = base + i * step; } return vector; } - static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { Vectorized vector; for (const auto i : c10::irange(size())) { if (i < count) { @@ -249,7 +264,9 @@ struct Vectorized { return vector; } static Vectorized loadu_one_fourth(const void* ptr) { - static_assert(std::is_same_v || std::is_same_v, "For byte types only"); + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); return Vectorized::loadu(ptr, 8); } @@ -257,9 +274,10 @@ struct Vectorized { std::memcpy(ptr, values, count * sizeof(T)); } int zero_mask() const { - // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit int mask = 0; - for (int i = 0; i < size(); ++ i) { + for (int i = 0; i < size(); ++i) { if (values[i] == static_cast(0)) { mask |= (1 << i); } @@ -279,15 +297,18 @@ struct Vectorized { } bool has_inf_nan() const { for (int64_t i = 0; i != size(); i++) { - if(_isnan(values[i]) || _isinf(values[i])) { + if (_isnan(values[i]) || _isinf(values[i])) { return true; } } return false; } -// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows Arm64 -// See https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 -#if defined(_WIN32) && defined(__aarch64__) && ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i < size(); i++) { @@ -322,38 +343,44 @@ struct Vectorized { return ret; } #endif - Vectorized map(T (*const f)(const T &)) const { + Vectorized map(T (*const f)(const T&)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } - T reduce(T (*const f)(const T &)) const { + T reduce(T (*const f)(const T&)) const { T ret = 0; for (int64_t i = 0; i != size(); i++) { ret = f(ret, values[i]); } return ret; } - template && !c10::is_complex::value, int> = 0> + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> Vectorized abs() const { // other_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_abs must be T"); return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } - template , int> = 0> + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> Vectorized abs() const { // float_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "float_t_abs must be T"); - // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in - // 0.0) properly. + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. return map([](T x) -> T { return std::abs(x); }); } - template ::value, int> = 0> + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> Vectorized abs() const { // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "complex_t_abs must be T"); @@ -361,66 +388,85 @@ struct Vectorized { return map([](T x) { return static_cast(std::abs(x)); }); } - template ::value, int> = 0> + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> Vectorized sgn() const { return map(at::native::sgn_impl); } - template ::value, int> = 0> + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_angle must be T"); - return map(at::native::angle_impl); // compiler is unable to resolve the overload without + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without } - template ::value, int> = 0> + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_angle must be T"); + static_assert( + std::is_same_v, "complex_t_angle must be T"); return map([](T x) { return static_cast(std::arg(x)); }); } - template ::value, int> = 0> + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> Vectorized real() const { // other_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_real must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> Vectorized real() const { // complex_t_real is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_real must be T"); + static_assert( + std::is_same_v, "complex_t_real must be T"); return map([](T x) { return static_cast(x.real()); }); } - template ::value, int> = 0> + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> Vectorized imag() const { // other_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_imag must be T"); return Vectorized(0); } - template ::value, int> = 0> + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> Vectorized imag() const { // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_imag must be T"); + static_assert( + std::is_same_v, "complex_t_imag must be T"); return map([](T x) { return static_cast(x.imag()); }); } - template ::value, int> = 0> + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> Vectorized conj() const { // other_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_conj must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> Vectorized conj() const { // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_conj must be T"); + static_assert( + std::is_same_v, "complex_t_conj must be T"); return map([](T x) { return static_cast(std::conj(x)); }); } Vectorized acos() const { @@ -441,7 +487,7 @@ struct Vectorized { Vectorized atanh() const { return map(std::atanh); } - Vectorized atan2(const Vectorized &exp) const { + Vectorized atan2(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::atan2(values[i], exp[i]); @@ -449,9 +495,9 @@ struct Vectorized { return ret; } template < - typename U = T, - typename std::enable_if_t, int> = 0> - Vectorized copysign(const Vectorized &sign) const { + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { Vectorized ret; for (size_type i = 0; i < size(); i++) { ret[i] = c10::copysign(values[i], sign[i]); @@ -483,8 +529,8 @@ struct Vectorized { return *this - this->trunc(); } template < - typename U = T, - typename std::enable_if_t, int> = 0> + typename U = T, + typename std::enable_if_t, int> = 0> Vectorized fmod(const Vectorized& q) const { // U is for SFINAE purposes only. Make sure it is not changed. static_assert(std::is_same_v, "U must be T"); @@ -503,20 +549,24 @@ struct Vectorized { Vectorized log1p() const { return map(std::log1p); } - template ::value, int> = 0> + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> Vectorized log2() const { // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_log2 must be T"); return map(std::log2); } - template ::value, int> = 0> + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> Vectorized log2() const { // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_log2 must be T"); + static_assert( + std::is_same_v, "complex_t_log2 must be T"); const T log_2 = T(std::log(2.0)); - return Vectorized(map(std::log))/Vectorized(log_2); + return Vectorized(map(std::log)) / Vectorized(log_2); } Vectorized ceil() const { return map(at::native::ceil_impl); @@ -530,7 +580,7 @@ struct Vectorized { Vectorized floor() const { return map(at::native::floor_impl); } - Vectorized hypot(const Vectorized &b) const { + Vectorized hypot(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::hypot(values[i], b[i]); @@ -546,14 +596,14 @@ struct Vectorized { Vectorized digamma() const { return map(calc_digamma); } - Vectorized igamma(const Vectorized &x) const { + Vectorized igamma(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igamma(values[i], x[i]); } return ret; } - Vectorized igammac(const Vectorized &x) const { + Vectorized igammac(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igammac(values[i], x[i]); @@ -566,7 +616,7 @@ struct Vectorized { // promotion return map([](T x) -> T { return -x; }); } - Vectorized nextafter(const Vectorized &b) const { + Vectorized nextafter(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::nextafter(values[i], b[i]); @@ -574,7 +624,8 @@ struct Vectorized { return ret; } Vectorized round() const { - // We do not use std::round because we would like to round midway numbers to the nearest even integer. + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. return map(at::native::round_impl); } Vectorized sin() const { @@ -604,20 +655,21 @@ struct Vectorized { Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); } - Vectorized pow(const Vectorized &exp) const { + Vectorized pow(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::pow(values[i], exp[i]); } return ret; } - T reduce_add() const { + T reduce_add() const { return reduce([](T x, T y) -> T { return x + y; }); } T reduce_max() const { return reduce(std::max); } -private: + + private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. @@ -632,35 +684,85 @@ struct Vectorized { return vector; } -public: - Vectorized operator==(const Vectorized& other) const { return binary_pred(other, std::equal_to()); } - Vectorized operator!=(const Vectorized& other) const { return binary_pred(other, std::not_equal_to()); } - Vectorized operator>=(const Vectorized& other) const { return binary_pred(other, std::greater_equal()); } - Vectorized operator<=(const Vectorized& other) const { return binary_pred(other, std::less_equal()); } - Vectorized operator>(const Vectorized& other) const { return binary_pred(other, std::greater()); } - Vectorized operator<(const Vectorized& other) const { return binary_pred(other, std::less()); } + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } -private: + private: template - inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { // 1 if the pred is true, otherwise 0. Vectorized vector; - for (int i = 0; i != size(); ++ i) { + for (int i = 0; i != size(); ++i) { vector[i] = static_cast(op(values[i], other.values[i])); } return vector; } -public: - Vectorized eq(const Vectorized& other) const { return binary_pred_bool(other, std::equal_to()); } - Vectorized ne(const Vectorized& other) const { return binary_pred_bool(other, std::not_equal_to()); } - Vectorized gt(const Vectorized& other) const { return binary_pred_bool(other, std::greater()); } - Vectorized ge(const Vectorized& other) const { return binary_pred_bool(other, std::greater_equal()); } - Vectorized lt(const Vectorized& other) const { return binary_pred_bool(other, std::less()); } - Vectorized le(const Vectorized& other) const { return binary_pred_bool(other, std::less_equal()); } + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } }; -template Vectorized inline operator+(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator-(const Vectorized& a) { + return a.neg(); +} + +// There is an implicit conversion that would make this work if +// these operators weren't template functions, but they are template +// functions (and can't be moved to be non-member friends defined in +// the class body as suggested in +// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 +// because we have a lot of disparate specializations of +// Vectorized). So, just explicitly make scalars work. +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ + template \ + Vectorized inline name(const Vectorized& a, T b) { \ + return name(a, Vectorized(b)); \ + } \ + template \ + Vectorized inline name(T a, const Vectorized& b) { \ + return name(Vectorized(a), b); \ + } +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ + VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) + +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] + b[i]; @@ -668,7 +770,10 @@ template Vectorized inline operator+(const Vectorized &a, const return c; } -template Vectorized inline operator-(const Vectorized &a, const Vectorized &b) { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) + +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] - b[i]; @@ -676,7 +781,10 @@ template Vectorized inline operator-(const Vectorized &a, const return c; } -template Vectorized inline operator*(const Vectorized &a, const Vectorized &b) { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) + +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] * b[i]; @@ -684,7 +792,11 @@ template Vectorized inline operator*(const Vectorized &a, const return c; } -template Vectorized inline operator/(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) + +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] / b[i]; @@ -692,14 +804,20 @@ template Vectorized inline operator/(const Vectorized &a, const return c; } -template , int> = 0> -Vectorized inline operator%(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) + +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { return a - a / b * b; } -template Vectorized inline operator||( - const Vectorized &a, const Vectorized &b) { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) + +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] || b[i]; @@ -707,11 +825,14 @@ template Vectorized inline operator||( return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) + // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; @@ -725,9 +846,10 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; @@ -741,11 +863,14 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) + // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; @@ -759,9 +884,10 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; @@ -775,9 +901,15 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, const Vectorized &max_vec) { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); @@ -785,9 +917,48 @@ Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, return c; } -template ::value, int> = 0> -Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_vec) { +#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ + template \ + Vectorized inline name( \ + const Vectorized& a, const Vectorized& b, T c) { \ + return name(a, b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + const Vectorized& a, T b, const Vectorized& c) { \ + return name(a, Vectorized(b), c); \ + } \ + \ + template \ + Vectorized inline name(const Vectorized& a, T b, T c) { \ + return name(a, Vectorized(b), Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + T a, const Vectorized& b, const Vectorized& c) { \ + return name(Vectorized(a), b, c); \ + } \ + \ + template \ + Vectorized inline name(T a, const Vectorized& b, T c) { \ + return name(Vectorized(a), b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name(T a, T b, const Vectorized& c) { \ + return name(Vectorized(a), Vectorized(b), c); \ + } + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; @@ -795,9 +966,14 @@ Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_ return c; } -template ::value, int> = 0> -Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_vec) { +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) + +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; @@ -805,18 +981,27 @@ Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_ return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) + struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { int_vector buffer; #if defined(CPU_CAPABILITY_AVX2) - int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); #elif defined(CPU_CAPABILITY_AVX512) - int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); #endif buffer = op(a_buffer, b_buffer); __at_align__ T results[Vectorized::size()]; @@ -829,31 +1014,52 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(results); } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); #endif } @@ -866,12 +1072,19 @@ auto load(char const* data) -> T { return ret; } -template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); __at_align__ intmax_t buffer[element_no]; - static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); - static_assert(sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)"); + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); // We should be using memcpy in order to respect the strict aliasing rule // see: https://github.com/pytorch/pytorch/issues/66119 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 @@ -889,34 +1102,54 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(buffer); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_and()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_or()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_xor()); } #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) -template>, int> = 0> +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) + +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator~(const Vectorized& a) { using int_t = int_same_size_t; - Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 return a ^ ones; } -template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { constexpr T max_shift = sizeof(T) * CHAR_BIT; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = 0; } else { c[i] = static_cast>(a[i]) << shift; @@ -925,13 +1158,17 @@ template Vectorized inline operator<<(const Vectorized &a, const return c; } -template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { // right shift value to retain sign bit for signed and no bits for unsigned constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = a[i] >> max_shift; } else { c[i] = a[i] >> shift; @@ -941,53 +1178,63 @@ template Vectorized inline operator>>(const Vectorized &a, const } template -inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { a = a + b; return a; } template -inline Vectorized& operator -= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { a = a - b; return a; } template -inline Vectorized& operator /= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { a = a / b; return a; } template -inline Vectorized& operator %= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { a = a % b; return a; } template -inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { a = a * b; return a; } template -inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { a = a << b; return a; } template -inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { a = a >> b; return a; } template -inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b + c; } +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) + template -inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b - c; } +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) + template Vectorized inline operator&&( const Vectorized& a, @@ -999,9 +1246,13 @@ Vectorized inline operator&&( return ret; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) + template -std::enable_if_t> -inline gather(T const* base_addr, const Vectorized>& vindex) { +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + Vectorized< + T>> inline gather(T const* base_addr, const Vectorized>& vindex) { static constexpr int size = Vectorized::size(); int_same_size_t index_arr[size]; vindex.store(static_cast(index_arr)); @@ -1013,36 +1264,39 @@ inline gather(T const* base_addr, const Vectorized>& vindex) } template -std::enable_if_t> -inline mask_gather(const Vectorized& src, T const* base_addr, - const Vectorized>& vindex, Vectorized& mask) { +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + T const* base_addr, + const Vectorized>& vindex, + Vectorized& mask) { static constexpr int size = Vectorized::size(); T src_arr[size]; - int_same_size_t mask_arr[size]; // use int type so we can logical and + int_same_size_t mask_arr[size]; // use int type so we can logical and int_same_size_t index_arr[size]; src.store(static_cast(src_arr)); mask.store(static_cast(mask_arr)); vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { - if (mask_arr[i] & 0x01) { // check highest bit + if (mask_arr[i] & 0x01) { // check highest bit buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } else { buffer[i] = src_arr[i]; } } - mask = Vectorized(static_cast(0)); // "zero out" mask + mask = Vectorized(static_cast(0)); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } // Cast a given vector to another type without changing the bits representation. // So a Vectorized of 512 bits containing all ones can be cast to a -// Vectorized of 512 bits containing all ones (i.e., eight negative 1s). -// A Vec of 256 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative +// 1s). A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). // There is a struct here because we don't have static_if and I can't // partially specialize a templated function. -template +template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { src_t src_arr[Vectorized::size()]; @@ -1051,44 +1305,51 @@ struct CastImpl { } }; -template +template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { return src; } }; -template +template inline Vectorized cast(const Vectorized& src) { return CastImpl::apply(src); } template > -inline Vectorized convert_to_int_of_same_size(const Vectorized& src) { +inline Vectorized convert_to_int_of_same_size( + const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr; src.store(static_cast(src_arr.data())); std::array buffer; - std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(), - [](const T& x) { return static_cast(x); }); + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { + return static_cast(x); + }); return Vectorized::loadu(static_cast(buffer.data())); } template > -inline Vectorized convert_to_fp_of_same_size(const Vectorized& src) { +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr; src.store(static_cast(src_arr.data())); std::array buffer; - std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(), - [](const IntType& x) { return static_cast(x); }); + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { + return static_cast(x); + }); return Vectorized::loadu(static_cast(buffer.data())); } +// clang-format off // Example inputs for AVX512: // a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} @@ -1099,8 +1360,11 @@ inline Vectorized convert_to_fp_of_same_size(const Vectorized& src) // b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} // returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// clang-format on template -inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; @@ -1116,10 +1380,14 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { buffer2[i] = a_arr[i * 2 + 1]; buffer2[half_size + i] = b_arr[i * 2 + 1]; } - return std::make_pair(Vectorized::loadu(static_cast(buffer1)), - Vectorized::loadu(static_cast(buffer2))); + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) + +// clang-format off // inverse operation of deinterleave2 // Example inputs for AVX512: // a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} @@ -1131,8 +1399,11 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// clang-format on template -inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> interleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; @@ -1148,14 +1419,21 @@ interleave2(const Vectorized& a, const Vectorized& b) { buffer2[i * 2] = a_arr[half_size + i]; buffer2[i * 2 + 1] = b_arr[half_size + i]; } - return std::make_pair(Vectorized::loadu(static_cast(buffer1)), - Vectorized::loadu(static_cast(buffer2))); + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) + +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP +#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC + template -inline void convert(const src_T *src, dst_T *dst, int64_t n) { +inline void convert(const src_T* src, dst_T* dst, int64_t n) { #ifndef _MSC_VER -# pragma unroll +#pragma unroll #endif for ([[maybe_unused]] const auto i : c10::irange(n)) { *dst = c10::convert(c10::load(src)); @@ -1165,7 +1443,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) { } template -inline Vectorized flip(const Vectorized & data) { +inline Vectorized flip(const Vectorized& data) { static constexpr int size = Vectorized::size(); T output[size]; T buffer[size]; @@ -1176,25 +1454,37 @@ inline Vectorized flip(const Vectorized & data) { return Vectorized::loadu(static_cast(output)); } -// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading -// dimension of `src` and `ld_dst` is the leading dimension of `dst`. +// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. +// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading +// dimension of `dst`. template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) { +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst, + int M, + int N) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { - dst[j*ld_dst + i] = src[i*ld_src + j]; + dst[j * ld_dst + i] = src[i * ld_src + j]; } } } template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst) { transpose_mxn(src, ld_src, dst, ld_dst, M, N); } -}} // namespace at::vec::CPU_CAPABILITY +} // namespace CPU_CAPABILITY +} // namespace at::vec // additional headers for more operations that depend on vec_base -#include -#include #include +#include +#include diff --git a/aten/src/ATen/cpu/vec/vec_convert.h b/aten/src/ATen/cpu/vec/vec_convert.h index a5cee03dabcf..f5e5177908c1 100644 --- a/aten/src/ATen/cpu/vec/vec_convert.h +++ b/aten/src/ATen/cpu/vec/vec_convert.h @@ -28,8 +28,8 @@ struct VecConvert { }; template -inline std::enable_if_t, Vectorized> -convert(const Vectorized& src) { +inline std::enable_if_t, Vectorized> convert( + const Vectorized& src) { return src; } diff --git a/aten/src/ATen/cpu/vec/vec_half.h b/aten/src/ATen/cpu/vec/vec_half.h index c7c90cc95b47..972d3ee3929b 100644 --- a/aten/src/ATen/cpu/vec/vec_half.h +++ b/aten/src/ATen/cpu/vec/vec_half.h @@ -103,7 +103,9 @@ static inline void transpose_pad_2x32_block( _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); } #else -TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported") + TORCH_CHECK( + false, + "transpose_pad_2x32_block is only supported when avx512 is supported") #endif } @@ -124,28 +126,31 @@ static inline void pack_vnni2( for (; bk < _K; bk += 2) { int64_t bn = 0; for (; bn < _N; bn += 32) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); } int64_t nrem = N - bn; if (nrem > 0) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); } } if (K % 2 == 1) { int64_t bn = 0; for (; bn < _N; bn += 32) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); } int64_t nrem = N - bn; if (nrem > 0) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); } } #else -TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") + TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") #endif } - } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index c547e5911ecb..e19d7f75388a 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -68,7 +68,12 @@ struct VecMaskTo { } }; -template +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> struct VecMaskCast { static inline VecMask apply( const VecMask& vec_mask) { @@ -88,15 +93,17 @@ struct VecMaskCheck { static inline bool all_zero(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); - return std::all_of( - mask, mask + VectorizedN::size(), [](T m) { return m == static_cast(0); }); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m == static_cast(0); + }); } static inline bool all_masked(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); - return std::all_of( - mask, mask + VectorizedN::size(), [](T m) { return m != static_cast(0); }); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m != static_cast(0); + }); } static inline bool is_masked(const VectorizedN& vec_mask, int i) { @@ -159,13 +166,11 @@ class VecMask { } static VecMask blendv( - const VecMask& c, - const VecMask& b, - const VecMask& a) { + const VecMask& c, + const VecMask& b, + const VecMask& a) { VectorizedN result = VectorizedN::blendv( - VectorizedN(c), - VectorizedN(b), - VectorizedN(a)); + VectorizedN(c), VectorizedN(b), VectorizedN(a)); return result; } @@ -174,14 +179,14 @@ class VecMask { const VecMask& b, int64_t count = size()) { VectorizedN result = VectorizedN::set( - VectorizedN(a), - VectorizedN(b), - count); + VectorizedN(a), VectorizedN(b), count); return result; } void store(bool* b, int count = size()) { - constexpr int L = (VectorizedN::size() + Vectorized::size() - 1)/ Vectorized::size(); + constexpr int L = + (VectorizedN::size() + Vectorized::size() - 1) / + Vectorized::size(); auto res = this->to(); res.store(b, count); return; diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d39fe4be31c9..4f5e511c33bc 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -3,7 +3,6 @@ */ #include -#include #include #include #include @@ -222,36 +221,6 @@ static size_t _getWorkspaceSize() { return workspace_size; } -void* _getWorkspaceWithoutHandle() { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - auto stream = c10::cuda::getCurrentCUDAStream(); - cudaStream_t _stream = stream; - auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); - TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); - return workspace_it->second.mutable_get(); -} - -void* _getWorkspace(size_t& workspaceSize) { -// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2)) - workspaceSize = _getWorkspaceSize(); - auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize(); - if (cublasWorkspaceSize < workspaceSize) { - TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize, - " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize, - " bytes. Please increase CUBLAS workspace size", - " via CUBLAS_WORKSPACE_CONFIG or decrease requested" - " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace" - " size will be limited to the CUBLAS workspace size."); - workspaceSize = cublasWorkspaceSize; - } -// #else -// workspaceSize = at::cuda::getChosenWorkspaceSize(); -// #endif - auto workspace_ptr = _getWorkspaceWithoutHandle(); - return workspace_ptr; -} - } // anonymous namespace namespace at::cuda::blas { @@ -441,8 +410,9 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { } CuBlasLtMatmulPreference preference; - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); #ifndef USE_ROCM @@ -454,6 +424,8 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment); #endif + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; @@ -486,7 +458,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { c, Cdesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream()); } @@ -1132,9 +1104,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) } #if defined(USE_ROCM) && !defined(_MSC_VER) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { //no CK GEMM version for gfx1100 + if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100 gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } else{ at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(float)); @@ -1404,8 +1374,9 @@ bool gemm_and_bias( CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld); CuBlasLtMatmulPreference preference; - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); #ifndef USE_ROCM @@ -1419,7 +1390,8 @@ bool gemm_and_bias( preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment); #endif - auto stream = c10::cuda::getCurrentCUDAStream(); + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); @@ -1453,9 +1425,9 @@ bool gemm_and_bias( result_ptr, Cdesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); } if (cublasStatus != CUBLAS_STATUS_SUCCESS) { TORCH_WARN( @@ -1646,9 +1618,9 @@ void scaled_gemm( #endif // if CUDA_VERSION >= 12080 } - auto stream = c10::cuda::getCurrentCUDAStream(); - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + size_t workspaceSize = _getWorkspaceSize(); + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + CuBlasLtMatmulPreference preference; preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); cublasLtMatmulHeuristicResult_t heuristicResult = {}; @@ -1731,9 +1703,9 @@ void scaled_gemm( result_ptr, Ddesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -1809,8 +1781,8 @@ void int8_gemm( CuBlasLtMatmulPreference preference; size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto workspace = allocator.allocate(workspaceSize); + auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -1848,7 +1820,7 @@ void int8_gemm( nullptr, // Heuristics don't seem to work for int8 #endif #ifdef USE_ROCM - workspace.mutable_get(), + workspace.mutable_data_ptr(), #else nullptr, // Non-zero workspace doesn't seem to work. #endif diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 65019bb6097c..dc33cb541370 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -2,7 +2,6 @@ // Light-weight version of CUDAContext.h with fewer transitive includes #include -#include #include #include @@ -88,8 +87,6 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); TORCH_CUDA_CPP_API void clearCublasWorkspaces(); -TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); -TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); #if defined(CUDART_VERSION) || defined(USE_ROCM) TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 8a039ea3bff9..ce1ef86d5091 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -9,6 +9,7 @@ #include #include +#include namespace at::cuda { namespace { @@ -71,6 +72,8 @@ using Block = HostBlock; struct CUDACachingHostAllocatorImpl : public CachingHostAllocatorImpl { private: + std::unordered_map use_host_register; + void allocate_host_memory(size_t size, void** ptr) override { // Pinned memory pointers allocated by any device can be directly used by // any other device, regardless of the current device at the time of @@ -89,13 +92,16 @@ struct CUDACachingHostAllocatorImpl } auto start = std::chrono::system_clock::now(); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - pinned_use_cuda_host_register()) { + bool use_register = c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_cuda_host_register(); + if (use_register) { allocWithCudaHostRegister(ptr, size); } else { // Use cudaHostAlloc for allocating pinned memory (global lock in driver) C10_CUDA_CHECK(cudaHostAlloc(ptr, size, cudaHostAllocDefault)); } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(use_host_register.count(*ptr) == 0); + use_host_register[*ptr] = use_register; + auto end = std::chrono::system_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -108,15 +114,19 @@ struct CUDACachingHostAllocatorImpl void free_block(Block* block) override { auto start = std::chrono::system_clock::now(); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - pinned_use_cuda_host_register()) { - void* ptr = block->ptr_; + // Users may change the allocator config at will. torch unit tests do this. + // However, allocations using cudaHostRegister should use corresonding + // cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost. + void* ptr = block->ptr_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(use_host_register.count(ptr) == 1); + if (use_host_register[ptr]) { AT_CUDA_CHECK(cudaHostUnregister(ptr)); // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) std::free(ptr); } else { - AT_CUDA_CHECK(cudaFreeHost(block->ptr_)); + AT_CUDA_CHECK(cudaFreeHost(ptr)); } + use_host_register.erase(ptr); auto end = std::chrono::system_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -185,21 +195,6 @@ struct CUDACachingHostAllocatorImpl } } - void registerPages(const void* ptr, size_t size) { - AT_CUDA_CHECK( - cudaHostRegister((void*)ptr, (size_t)size, cudaHostRegisterDefault)); - - // If host and device pointer don't match, give a warning and exit - void* devptr = nullptr; - AT_CUDA_CHECK(cudaHostGetDevicePointer(&devptr, (void*)ptr, 0)); - TORCH_CHECK( - (void*)devptr == (void*)ptr, - "Host and device pointer dont match with cudaHostRegister. " - "Please dont use this feature by setting " - "PYTORCH_CUDA_ALLOC_CONF=use_cuda_host_register:False (default)", - ""); - } - void allocWithCudaHostRegister(void** ptr, size_t roundSize) { // Here we do regular allocation, pre-fault/map the pages, and then do // cudaHostRegister with GPU mapping flags to lock the pages, so we @@ -249,7 +244,8 @@ struct CUDACachingHostAllocatorImpl } // Register the mapped pages using cudaHostRegister - registerPages(*ptr, roundSize); + AT_CUDA_CHECK( + cudaHostRegister(*ptr, roundSize, cudaHostRegisterDefault)); } }; diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 6f7f0536437c..06fa4f91abff 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -83,6 +83,11 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v #endif +std::map, at::DataPtr>& cublas_handle_stream_to_workspace() { + static auto& instance = *new std::map, at::DataPtr>; + return instance; +} + void createCublasHandle(cublasHandle_t *handle) { TORCH_CUDABLAS_CHECK(cublasCreate(handle)); } @@ -104,11 +109,6 @@ using CuBlasPoolType = DeviceThreadHandlePool, at::DataPtr>& cublas_handle_stream_to_workspace() { - static auto& instance = *new std::map, at::DataPtr>; - return instance; -} - void clearCublasWorkspaces() { cublas_handle_stream_to_workspace().clear(); } @@ -123,11 +123,9 @@ size_t parseChosenWorkspaceSize() { // for extra convenience val = getenv("ROCBLAS_WORKSPACE_CONFIG"); } - /* 32MiB default, 128MiB for MI300 */ - cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = properties->gcnArchName; - const bool gfx94 = device_arch.find("gfx94") != std::string::npos; - const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; + /* 32MiB default, 128MiB for gfx94x/gfx95x */ + const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"}); + const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; #else /* :4096:2:16:8 default, 32MiB for Hopper */ cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu new file mode 100644 index 000000000000..6c20daed2e02 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu new file mode 100644 index 000000000000..2adb6a519882 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu new file mode 100644 index 000000000000..39e29b7668c9 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu new file mode 100644 index 000000000000..3ad1ebd2a56a --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu new file mode 100644 index 000000000000..098615b68345 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu new file mode 100644 index 000000000000..d58e0c8d5ce7 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu new file mode 100644 index 000000000000..fe24f72151fb --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu new file mode 100644 index 000000000000..1373668316c2 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu new file mode 100644 index 000000000000..f52f97fe588a --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu new file mode 100644 index 000000000000..db28bb602acc --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu new file mode 100644 index 000000000000..7ad51b90b834 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs.cu b/aten/src/ATen/cuda/cub-RadixSortPairs.cuh similarity index 82% rename from aten/src/ATen/cuda/cub-RadixSortPairs.cu rename to aten/src/ATen/cuda/cub-RadixSortPairs.cuh index 0eefb0824e59..bd40deb4125b 100644 --- a/aten/src/ATen/cuda/cub-RadixSortPairs.cu +++ b/aten/src/ATen/cuda/cub-RadixSortPairs.cuh @@ -1,3 +1,5 @@ +#pragma once + #define TORCH_ASSERT_NO_OPERATORS #include #include @@ -66,20 +68,7 @@ void radix_sort_pairs_impl( int64_t begin_bit, \ int64_t end_bit); -AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) - #define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) -AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8) -AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8) -AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8) -AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) - } // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 9847386c3394..ac5c833070c1 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -448,8 +448,14 @@ DeviceIndex CUDAHooks::getCurrentDevice() const { } #ifdef USE_ROCM -bool CUDAHooks::isGPUArch(DeviceIndex device_index, const std::vector& archs) const { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device_index); +bool CUDAHooks::isGPUArch(const std::vector& archs, DeviceIndex device_index) const { + hipDeviceProp_t* prop; + if (device_index == -1){ + prop = at::cuda::getCurrentDeviceProperties(); + } else { + prop = at::cuda::getDeviceProperties(device_index); + } + std::string device_arch = prop->gcnArchName; for (std::string arch : archs) { size_t substring = device_arch.find(arch); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index d0be9d5f535c..2b4c11136321 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -57,7 +57,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { DeviceIndex getCurrentDevice() const override; #ifdef USE_ROCM - bool isGPUArch(DeviceIndex device_index, const std::vector& archs) const override; + bool isGPUArch(const std::vector& archs, DeviceIndex device_index = -1) const override; #endif void deviceSynchronize(DeviceIndex device_index) const override; }; diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index b8187b4254bf..5e885d4764d2 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -40,9 +40,9 @@ enum TORCH_CUDA_CPP_API TuningStatus { class TORCH_CUDA_CPP_API ResultEntry { public: explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} - explicit ResultEntry(std::string key, double time, const std::string& blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(blas_sig) {} - bool operator==(const ResultEntry& other) { return key_ == other.key_; } - bool operator!=(const ResultEntry& other) { return key_ != other.key_; } + explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {} + bool operator==(const ResultEntry& other) const { return key_ == other.key_; } + bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } operator std::string () { return key_; } std::string GetKey() const { return key_; } double GetTime() const { return time_; } diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 9b54a84dd68d..9bc30ba84ea5 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -196,7 +196,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { } #ifdef USE_ROCM - virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector& /*archs*/) const { + virtual bool isGPUArch(const std::vector& /*archs*/, DeviceIndex = -1 /*device_index*/) const { TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP); } #endif diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index e512efad59bb..14f03bd17f4d 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -773,6 +773,15 @@ std::tuple> scatter_add_batch_rule( self, self_bdim, dim, index, index_bdim, src, src_bdim); } +std::tuple> scatter_add__batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim) { + return scatter_batch_rule(ATEN_FN(scatter_add_), + self, self_bdim, dim, index, index_bdim, src, src_bdim); +} + std::tuple> scatter_reduce_batch_rule( const Tensor& self, std::optional self_bdim, int64_t dim, @@ -1278,6 +1287,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule); VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule); VMAP_SUPPORT(scatter_add, scatter_add_batch_rule); + VMAP_SUPPORT(scatter_add_, scatter_add__batch_rule); VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule); VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule); VMAP_SUPPORT2(scatter_reduce, two, scatter_reduce_two_batch_rule); diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 897e83890c79..d1947435d2bc 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2693,12 +2693,6 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions."); int64_t left_size_condition = left ? -2 : -1; - TORCH_CHECK( - other.size(left_size_condition) >= tau.size(-1), - "torch.ormqr: other.shape[", - left_size_condition, - "] must be greater than or equal to tau.shape[-1]"); - TORCH_CHECK( other.size(left_size_condition) == input.size(-2), "torch.ormqr: other.shape[", @@ -2706,8 +2700,10 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b "] must be equal to input.shape[-2]"); TORCH_CHECK( - tau.size(-1) <= input.size(-1), - "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]"); + std::min(other.size(left_size_condition), input.size(-1)) == tau.size(-1), + "torch.ormqr: tau.shape[-1] must be equal to min(other.shape[", + left_size_condition, + "], input.shape[-1])"); TORCH_CHECK( input.dim() - tau.dim() == 1, @@ -2716,6 +2712,7 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b tau.dim(), " and input.ndim is equal to ", input.dim()); + TORCH_CHECK( input.dim() == other.dim(), "torch.ormqr: ", diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp index 4019cf2ff9b1..415b8cab1364 100644 --- a/aten/src/ATen/native/ComparisonUtils.cpp +++ b/aten/src/ATen/native/ComparisonUtils.cpp @@ -30,7 +30,9 @@ void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalS _assert_match(tensor.sym_sizes(), sizes, "sizes"); _assert_match(tensor.sym_strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); - _assert_match(tensor.device(), device, "device"); + if (tensor.device().type() != DeviceType::Meta) { + _assert_match(tensor.device(), device, "device"); + } _assert_match(tensor.layout(), layout, "layout"); } @@ -38,7 +40,9 @@ void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef s _assert_match(tensor.sizes(), sizes, "sizes"); _assert_match(tensor.strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); - _assert_match(tensor.device(), device, "device"); + if (tensor.device().type() != DeviceType::Meta) { + _assert_match(tensor.device(), device, "device"); + } _assert_match(tensor.layout(), layout, "layout"); } diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 78cc6237451d..38bfdaa397f0 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -603,7 +603,7 @@ struct ConvParams { // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of // a depthwise multiplier) bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const { - return input.is_cuda() && + return (input.is_cuda() || input.is_xpu()) && !transposed && (input.ndimension() == 4 || input.ndimension() == 5) && at::symint::size(input, 1) == groups && @@ -1219,6 +1219,12 @@ ConvBackend _select_conv_backend( return ConvBackend::Cudnn; } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) { return ConvBackend::MiopenDepthwise; + } else if (params.use_mkldnn(input, weight)) { + if (params.transposed) { + return ConvBackend::MkldnnTranspose; + } else { + return ConvBackend::Mkldnn; + } } else { if (input.ndimension() == 4) { return ConvBackend::CudaDepthwise2d; diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 1be4ec37dfef..e1d329fbf30f 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( c10::DeviceType::MPS, c10::DeviceType::MTIA, c10::DeviceType::XPU, + c10::DeviceType::HPU, c10::DeviceType::PrivateUse1 ); // Check if the device type is supported. @@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; #endif + case DeviceType::HPU: + return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; + case DeviceType::PrivateUse1: return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 725d0d08bae1..cbe4b23c6711 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -44,6 +44,7 @@ // - MPS: Apple Silicon GPUs (Metal Performance Shaders) // - MTIA: Meta Training and Inference Devices // - XPU: Intel GPUs +// - HPU: Reserved for HPU (Intel Gaudi) device types // - PrivateUse1: Reserved for private/custom device types // // If you want to update the list of supported devices, add a new dispatch_ptr @@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr; #endif + void* hpu_dispatch_ptr; void* privateuse1_dispatch_ptr; #else std::atomic cpu_dispatch_ptr{nullptr}; @@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr = nullptr; #endif + void* hpu_dispatch_ptr = nullptr; void* privateuse1_dispatch_ptr = nullptr; #endif }; @@ -259,6 +262,10 @@ struct DispatchStub { } #endif + void set_hpu_dispatch_ptr(FnPtr fn_ptr) { + impl.hpu_dispatch_ptr = reinterpret_cast(fn_ptr); + } + void set_hip_dispatch_ptr(FnPtr fn_ptr) { impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); } @@ -337,6 +344,13 @@ struct RegisterXPUDispatch { } }; +template +struct RegisterHPUDispatch { + RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){ + stub.set_hpu_dispatch_ptr(value); + } +}; + template struct RegisterMPSDispatch { RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { @@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_XPU_DISPATCH(name, fn) \ static RegisterXPUDispatch name ## __register(name, fn); +#define REGISTER_HPU_DISPATCH(name, fn) \ + static RegisterHPUDispatch name ## __register(name, fn); + #define REGISTER_HIP_DISPATCH(name, fn) \ static RegisterHIPDispatch name ## __register(name, fn); diff --git a/aten/src/ATen/native/RangeUtils.h b/aten/src/ATen/native/RangeUtils.h index d1756db75016..d3ad1c6ab7df 100644 --- a/aten/src/ATen/native/RangeUtils.h +++ b/aten/src/ATen/native/RangeUtils.h @@ -2,9 +2,9 @@ #include #include -namespace at { -namespace native { + +namespace at::native { template int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { @@ -42,4 +42,4 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar return static_cast(size_d); } -}} // namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index d790a79de83e..de56c906d004 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -11,14 +11,14 @@ #include #endif -#include +#include namespace at::native { Scalar item(const Tensor& self) { auto numel = self.sym_numel(); TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar"); - if (torch::autograd::GradMode::is_enabled() && self.requires_grad()) { + if (at::GradMode::is_enabled() && self.requires_grad()) { TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n" "Consider using tensor.detach() first."); } diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 0658ed6f27bd..79aaac48034a 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -756,7 +756,7 @@ static DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim IntArrayRef dim_unwrapped = *dim_opt; dim.resize(dim_unwrapped.size()); for (const auto i : c10::irange(dim.size())) { - dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false); + dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalar=*/false); } } else { dim.resize(self.dim()); diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index ce0057909830..420a81767fba 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -887,7 +887,7 @@ static inline void mvlgamma_check(const Tensor& self, int64_t p) { Tensor mvlgamma(const Tensor& self, int64_t p) { mvlgamma_check(self, p); auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type()); - if (at::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { + if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) { // int -> float promotion dtype = c10::get_default_dtype(); } diff --git a/aten/src/ATen/native/WeightNorm.cpp b/aten/src/ATen/native/WeightNorm.cpp index 428669e6466b..bbd39809085a 100644 --- a/aten/src/ATen/native/WeightNorm.cpp +++ b/aten/src/ATen/native/WeightNorm.cpp @@ -53,8 +53,8 @@ std::tuple weight_norm_cpu( int64_t dim) { auto w = at::empty_like(v, at::MemoryFormat::Contiguous); - // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16' - const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ? + // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'/'Half' + const auto dtype = (g.scalar_type() == at::ScalarType::BFloat16 || g.scalar_type() == at::ScalarType::Half) ? at::ScalarType::Float : g.scalar_type(); auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype)); weight_norm_stub(kCPU, w, norm, v, g, dim); @@ -93,10 +93,7 @@ Tensor _weight_norm auto v = v_in.contiguous(); auto g = g_in.contiguous(); - auto has_half_dtype = v.scalar_type() == at::ScalarType::Half - || g.scalar_type() == at::ScalarType::Half; - - bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1)); + bool can_use_fused = (dim == 0) || (dim == v.dim() - 1); if (can_use_fused) { // weight_norm does not have a derivative defined for it, so this will route back through diff --git a/aten/src/ATen/native/cpu/WeightNormKernel.cpp b/aten/src/ATen/native/cpu/WeightNormKernel.cpp index 9ee5c97be8bc..5e866d538768 100644 --- a/aten/src/ATen/native/cpu/WeightNormKernel.cpp +++ b/aten/src/ATen/native/cpu/WeightNormKernel.cpp @@ -48,7 +48,8 @@ void weight_norm_first_dim_kernel( } template -inline void sum_norm_per_row( +inline std::enable_if_t, void> +sum_norm_per_row( scalar_t* out_ptr, const scalar_t* v_ptr, int64_t size) { @@ -61,16 +62,18 @@ inline void sum_norm_per_row( size); } -inline void sum_norm_per_row( +template +inline std::enable_if_t, void> +sum_norm_per_row( float* out_ptr, - const BFloat16* v_ptr, + const scalar_t* v_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0; fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1; @@ -84,7 +87,8 @@ inline void sum_norm_per_row( } template -inline void apply_norm_per_row( +inline std::enable_if_t, void> +apply_norm_per_row( scalar_t* w_ptr, const scalar_t* v_ptr, const scalar_t* a_ptr, @@ -98,21 +102,23 @@ inline void apply_norm_per_row( size); } -inline void apply_norm_per_row( - BFloat16* w_ptr, - const BFloat16* v_ptr, +template +inline std::enable_if_t, void> +apply_norm_per_row( + scalar_t* w_ptr, + const scalar_t* v_ptr, const float* a_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0; fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1; - bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1); + bVec w_bvec = vec::convert_from_float(w_fvec0, w_fvec1); w_bvec.store(w_ptr + d); } for(; d < size; ++d) { @@ -222,7 +228,8 @@ void weight_norm_backward_first_dim_kernel( } template -inline void sum_product_per_row( +inline std::enable_if_t, void> +sum_product_per_row( scalar_t* out_ptr, const scalar_t* grad_w_ptr, const scalar_t* v_ptr, @@ -237,19 +244,21 @@ inline void sum_product_per_row( size); } -inline void sum_product_per_row( +template +inline std::enable_if_t, void> +sum_product_per_row( float* out_ptr, - const BFloat16* grad_w_ptr, - const BFloat16* v_ptr, + const scalar_t* grad_w_ptr, + const scalar_t* v_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d); - auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec); + auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float(grad_w_bvec); bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0; fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1; @@ -264,7 +273,8 @@ inline void sum_product_per_row( } template -inline void apply_per_row_backward( +inline std::enable_if_t, void> +apply_per_row_backward( scalar_t* grad_v_ptr, const scalar_t* grad_w_ptr, const scalar_t* v_ptr, @@ -282,26 +292,28 @@ inline void apply_per_row_backward( size); } -inline void apply_per_row_backward( - BFloat16* grad_v_ptr, - const BFloat16* grad_w_ptr, - const BFloat16* v_ptr, +template +inline std::enable_if_t, void> +apply_per_row_backward( + scalar_t* grad_v_ptr, + const scalar_t* grad_w_ptr, + const scalar_t* v_ptr, const float* a_ptr, const float* b_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d); - auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec); + auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float(grad_w_bvec); bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0; fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1 - fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1; - bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1); + bVec grad_v_bvec = vec::convert_from_float(grad_v_fvec0, grad_v_fvec1); grad_v_bvec.store(grad_v_ptr + d); } for(; d < size; ++d) { @@ -395,7 +407,7 @@ void weight_norm_kernel( int64_t dim) { TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1, "fused kernels can only be applied for first or last dim"); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(), + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, v.scalar_type(), "weight_norm_kernel", [&]() { using accscalar_t = at::opmath_type; if (dim == 0) { @@ -420,7 +432,7 @@ void weight_norm_backward_kernel( int64_t dim) { TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim"); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(), + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, saved_v.scalar_type(), "weight_norm_backward_kernel", [&]() { using accscalar_t = at::opmath_type; if (dim == 0) { diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index 8c1000f8de47..4807a689e8c2 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -570,10 +570,8 @@ ComputeInternalGradients( at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) { constexpr int64_t K = Vec::size(); const int64_t inner_size = HxW / K * K; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - std::array ds_arr; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - std::array db_arr; + std::array ds_arr{}; + std::array db_arr{}; for (const auto i : c10::irange(start, end)) { const T* dY_ptr = dY + i * HxW; const T* X_ptr = X + i * HxW; diff --git a/aten/src/ATen/native/cuda/AveragePool2d.cu b/aten/src/ATen/native/cuda/AveragePool2d.cu index 41fbddb3c583..25eda2b6eabb 100644 --- a/aten/src/ATen/native/cuda/AveragePool2d.cu +++ b/aten/src/ATen/native/cuda/AveragePool2d.cu @@ -402,11 +402,12 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_cuda) ( bool use_divisor = divisor_override.has_value(); const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 - constexpr int double_threads = 768; -#else - constexpr int double_threads = 1024; -#endif + cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); + const bool gesm10x = properties->major >= 10; + int double_threads = 1024; + if (gesm10x) { + double_threads = 768; + } AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_backward_out_cuda_frame", diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 50043e3e8534..90a1a8ee07f2 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -265,24 +266,16 @@ static bool getDisableAddmmCudaLt() { #ifdef USE_ROCM static bool isSupportedHipLtROCmArch(int index) { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); - std::string device_arch = prop->gcnArchName; static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201" + "gfx1100", "gfx1101", "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" #endif }; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch(archs, index); } #endif @@ -950,43 +943,31 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { return _int_mm_out_cuda(self, mat2, result); } -static bool _scaled_mm_allowed_device() { - auto dprops = at::cuda::getCurrentDeviceProperties(); +static bool _scaled_mm_allowed_device(bool sm90_only=false) { #ifdef USE_ROCM - std::string device_arch = dprops->gcnArchName; static const std::vector archs = { "gfx942", #if ROCM_VERSION >= 60300 - "gfx1200", "gfx1201" + "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" #endif }; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch(archs); #else - return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (sm90_only) { + return dprops->major == 9; + } else { + return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); + } #endif } #ifdef USE_ROCM static bool _scaled_mm_is_fnuz() { - auto dprops = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = dprops->gcnArchName; - static const std::vector archs = {"gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); } #endif @@ -1425,16 +1406,20 @@ namespace { } } - bool transposed(const Tensor& mat) { + bool check_valid_strides_and_return_transposed(const Tensor& mat) { IntArrayRef tensor_strides = mat.strides(); IntArrayRef tensor_sizes = mat.sizes(); int end_dim = mat.dim() - 1; + int alignment = 16 / mat.element_size(); + TORCH_CHECK(uint64_t(mat.data_ptr()) % 16 ==0, "expected data_ptr to be aligned to 16 bytes\n"); if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max(1, tensor_sizes[end_dim - 1]))) { + TORCH_CHECK(tensor_strides[end_dim] % alignment == 0, "strides should be multiple of 16 bytes"); return true; } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max(1, tensor_sizes[end_dim]))) { + TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); return false; } else { - TORCH_CHECK(false, "Tensor should not be self-overlapping"); + TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); } } @@ -1500,13 +1485,13 @@ const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum) { #ifndef USE_ROCM - bool allowed_device = _scaled_mm_allowed_device(); - TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0"); TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); - TORCH_CHECK(!transposed(mat_a), "Expected mat1 to not be transposed"); - TORCH_CHECK(transposed(mat_b), "Expected mat2 to be transposed"); + TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); + TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; @@ -1524,7 +1509,7 @@ bool use_fast_accum) { ")."); - + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); if (offs.has_value()) { @@ -1567,5 +1552,42 @@ bool use_fast_accum) { } +Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, +const std::optional& offs, +const std::optional& bias, +std::optional out_dtype) { +#ifndef USE_ROCM + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + + TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type()); + TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); + TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + // check that the strides are valid, the fn will throw an error if not + check_valid_strides_and_return_transposed(mat_a); + check_valid_strides_and_return_transposed(mat_b); + TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d"); + + if (offs.has_value()) { + TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); + TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); + } + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm"); + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); + + const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); + Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); + at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); + return out; +#else + TORCH_CHECK(false, "grouped gemm is not supported on ROCM") +#endif +} + + } // namespace at::native diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 08d07c4b45a5..06276c72c53a 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -16,7 +16,7 @@ #include #include -namespace at { namespace native { namespace detail { +namespace at::native::detail { // Enum representing the FFT type enum class CuFFTTransformType : int8_t { @@ -58,7 +58,7 @@ struct CuFFTParams } }; -static_assert(std::is_trivial_v, ""); +static_assert(std::is_trivial_v ); // Returns true if the transform type has complex input inline bool cufft_complex_input(CuFFTTransformType type) { @@ -491,4 +491,4 @@ void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_si int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); void cufft_clear_plan_cache_impl(DeviceIndex device_index); -}}} // namespace at::native::detail +} // namespace at::native::detail diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu new file mode 100644 index 000000000000..d43875e3c8a6 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -0,0 +1,383 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + + +// Two warninngs in Cutlass included header files +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") + +// Determine if the architecture supports rowwise scaled mm +// Currently failing on windows with: +// https://github.com/NVIDIA/cutlass/issues/1571 +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \ + CUDA_VERSION >= 12000 + +#define BUILD_GG_KERNEL +#endif + +#if defined(BUILD_GG_KERNEL) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { +using Strides = at::cuda::detail::Strides; // std::array; + +template +struct Schedule { + using CooperativeSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using CooperativeEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using PongEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using KernelSchedule = + cute::conditional_t; + using EpilogueSchedule = cute:: + conditional_t; +}; + +int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +int round_up_to_nearest_multiple(int a, int b) { + return ceildiv(a, b) * b; +} + +template < + bool a_row_major, + bool b_row_major, + bool Pong, + typename TB_M, + typename TB_N, + typename TB_K> +void bf16bf16_grouped_gemm_impl_sm90( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + using DtypeA = cutlass::bfloat16_t; + using DtypeB = cutlass::bfloat16_t; + using DtypeOutput = cutlass::bfloat16_t; + using DtypeAccum = float; + using LayoutA = cute::conditional_t< + a_row_major, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor>; + constexpr int AlignmentA = 16 / sizeof(DtypeA); + + using LayoutB = cute::conditional_t< + b_row_major, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor>; + constexpr int AlignmentB = 16 / sizeof(DtypeB); + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using KernelSchedule = + typename Schedule::KernelSchedule; + using EpilogueSchedule = + typename Schedule::EpilogueSchedule; + using ProblemShape = cutlass::gemm::GroupProblemShape< + cute::Shape>; // per + // group + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, + DtypeAccum, + DtypeOutput, + LayoutOutput*, + AlignmentOutput, + DtypeOutput, + LayoutOutput*, + AlignmentOutput, + EpilogueSchedule, + cutlass::epilogue::fusion:: + LinearCombination>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + DtypeA, + LayoutA*, + AlignmentA, + DtypeB, + LayoutB*, + AlignmentB, + DtypeAccum, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; + int32_t M, N, K, group_count; + + M = mat_a.size(-2); + K = mat_a.size(-1); + N = mat_b.size(-1); + + if (mat_a.dim() == 2 && mat_b.dim() == 2) { + // if both inputs are ragged, K is dynamic, M and N come from inputs + group_count = offs->size(0); + K = -1; + } else if (mat_a.dim() == 2) { + group_count = mat_b.size(0); + M = -1; + } else if (mat_b.dim() == 2) { + group_count = mat_a.size(0); + N = -1; + } else { + // regular bmm + group_count = mat_a.size(0); + } + + TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); + const int64_t problem_shape_size = + group_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape)); + + const int64_t stride_size = 3 * group_count * ((int64_t)sizeof(StrideA)); + + // dummy tmas are created based on these pointer-to-pointers + // the actual values are never used, they are replaced + // by real addresses, but for dummy tma creation to succeed + // due to bug in cuda < 12.4 the pointers have to be aligned to 128 bits + const int group_alignment = 16 / sizeof(void*); + const int aligned_group_count = + round_up_to_nearest_multiple(group_count, group_alignment); + int64_t input_args_size = aligned_group_count * 3 * sizeof(void*) + + problem_shape_size + stride_size; + + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto input_buf = allocator.allocate(input_args_size); + void* buf_ptr = input_buf.get(); + DtypeA** inputA_ptrs = reinterpret_cast(buf_ptr); + DtypeB** inputB_ptrs = + reinterpret_cast(inputA_ptrs + aligned_group_count); + DtypeOutput** output_ptrs = + reinterpret_cast(inputB_ptrs + aligned_group_count); + static_assert( + sizeof(StrideA) == 8, "expected StrideA to be 8 bytes for alignment"); + StrideA* stride_A = + reinterpret_cast(output_ptrs + aligned_group_count); + StrideB* stride_B = reinterpret_cast(stride_A + group_count); + StrideOutput* stride_output = + reinterpret_cast(stride_B + group_count); + ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast( + stride_output + group_count); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto make_strides = [](at::IntArrayRef strides) -> Strides { + Strides out; + std::copy(strides.begin(), strides.end(), out.begin()); + return out; + }; + + Strides tensor_StrideA = make_strides(mat_a.strides()); + Strides tensor_StrideB = make_strides(mat_b.strides()); + Strides tensor_StrideOutput = make_strides(out.strides()); + + at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( + reinterpret_cast(mat_a.data_ptr()), + reinterpret_cast(mat_b.data_ptr()), + reinterpret_cast(out.data_ptr()), + static_cast(nullptr), // type for template inference + static_cast(nullptr), // type for template inference + inputA_ptrs, + inputB_ptrs, + output_ptrs, + static_cast(nullptr), // type for template inference + static_cast(nullptr), // type for template inference + problem_sizes, + stride_A, + stride_B, + stride_output, + offs.has_value() ? offs->const_data_ptr() : nullptr, + M, + N, + K, + tensor_StrideA, + tensor_StrideB, + tensor_StrideOutput, + 0, + 0, + a_row_major, + b_row_major); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {group_count, problem_sizes, nullptr}, + {(const DtypeA**)inputA_ptrs, + stride_A, + (const DtypeB**)inputB_ptrs, + stride_B}, + {{}, + (const DtypeOutput**)output_ptrs, + stride_output, + output_ptrs, + stride_output}}; + + arguments.epilogue.thread.alpha = 1.0; + arguments.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + + int sm_count = + at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); + } + arguments.hw_info.sm_count = sm_count; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = allocator.allocate(workspace_size); + Gemm gemm; + TORCH_CHECK( + gemm.can_implement(arguments) == cutlass::Status::kSuccess, + "cutlass cannot implement"); + TORCH_CHECK( + gemm.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, + "cutlass cannot initialize"); + auto status = gemm(at::cuda::getCurrentCUDAStream()); + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "cutlass cannot run, error ", + int(status)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_bf16_grouped_kernel_on_tile_size( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + int32_t M, N, K, group_count; + + M = mat_a.size(-2); + K = mat_a.size(-1); + N = mat_b.size(-1); + + // below we assume that gemms are approx same size + if (mat_a.dim() == 2 && mat_b.dim() == 2) { + // if both inputs are ragged, K is dynamic, M and N come from inputs + group_count = offs->size(0); + K = K / group_count; + } else if (mat_a.dim() == 2) { + group_count = mat_b.size(0); + M = M / group_count; + } else if (mat_b.dim() == 2) { + group_count = mat_a.size(0); + N = N / group_count; + } + // bool large = + // ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + // (K >= 2048 && N >= 2048)); + bool small = (M <= 128 || N <= 128); + if (small) { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ true, + cute::_64, + cute::_128, + cute::_128>(mat_a, mat_b, offs, bias, out); + } else { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ false, + cute::_128, + cute::_256, + cute::_64>(mat_a, mat_b, offs, bias, out); + } +} + +void dispatch_bf16_grouped_kernel_on_ab_transpose( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + // we already checked that one of the strides is 1 + bool a_row_major = mat_a.stride(-1) == 1; + bool b_row_major = mat_b.stride(-1) == 1; + if (a_row_major && b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else if (a_row_major && !b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else if (!a_row_major && b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } +} + +} // namespace +#endif + +namespace at::cuda::detail { + +void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { +#if defined(BUILD_GG_KERNEL) + dispatch_bf16_grouped_kernel_on_ab_transpose(mat_a, mat_b, offs, bias, out); +#else + TORCH_CHECK(false, "grouped mm is not supported on your system"); +#endif +} + +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/GroupMM.h b/aten/src/ATen/native/cuda/GroupMM.h new file mode 100644 index 000000000000..1fc23207a090 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMM.h @@ -0,0 +1,12 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh new file mode 100644 index 000000000000..613e2a6331d1 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -0,0 +1,122 @@ +#pragma once +#include + +namespace at::cuda::detail { + +using Strides = std::array; + +template < + typename DtypeA, + typename DtypeB, + typename DtypeOutput, + typename DtypeScale, + typename ProblemShape, + typename StrideA, + typename StrideB, + typename StrideOutput> +__global__ void prepare_grouped_gemm_data( + DtypeA* A, + DtypeB* B, + DtypeOutput* output, + DtypeScale* scale_A, + DtypeScale* scale_B, + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + DtypeScale** inputA_scale_ptrs, + DtypeScale** inputB_scale_ptrs, + ProblemShape* problem_sizes, + // Strides for cutlass, cute::Stride + StrideA* stride_A, + StrideB* stride_B, + StrideOutput* stride_output, + const int32_t* offs, + int32_t M, + int32_t N, + int32_t K, + // Original strides of the input tensors + Strides tensor_StrideA, + Strides tensor_StrideB, + Strides tensor_StrideOutput, + int64_t a_scale_stride, + int64_t b_scale_stride, + bool a_row_major = true, + bool b_row_major = false) { + int32_t tid = threadIdx.x; + int32_t delta = 0; + if (offs != nullptr) { + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + delta = offs[tid] - start; + int align = 16 / sizeof(DtypeA); + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected dynamic dimension byte size to be multiple of 16 \n"); + } + int64_t lda, ldb, ldoutput; + if (M < 0) { + // A and output is 2d + M = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + } else if (N < 0) { + N = delta; + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; + } + } else if (K < 0) { + // A, B is 2d, output is 3d + K = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * M; + inputB_scale_ptrs[tid] = scale_B + tid * N; + } + } else { + // A, B, output are 3D + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + } + problem_sizes[tid] = ProblemShape(M, N, K); + + // make_cute_packed_stride only replaces one of the stride elements with + // one the provided values in the shape arguments + // the indices of the src/dst depend on whether A/B are row-major + // so constructing shape argument with two similar lda values + // while it looks non-sensical (and it is a nonsensical shape) + // is fine for these stride construction purposes - the one that will be used + // for replacement is correct, the other one is ignored, and we don't have to + // branch on whether A/B are row-major + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); + stride_output[tid] = + cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); +} +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/MiscUtils.h b/aten/src/ATen/native/cuda/MiscUtils.h index e616a7d1fcfb..f733f3a38099 100644 --- a/aten/src/ATen/native/cuda/MiscUtils.h +++ b/aten/src/ATen/native/cuda/MiscUtils.h @@ -4,8 +4,8 @@ #include #include -namespace at { -namespace native { + +namespace at::native { static inline int cuda_int_cast(int64_t value, const char* varname) { auto result = static_cast(value); @@ -28,5 +28,4 @@ static inline Storage pin_memory(int64_t size) { /*resizable=*/false); } -} // namespace native -} // namespace at +} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Resize.h b/aten/src/ATen/native/cuda/Resize.h index d5de128cac1d..b2c3efe5a719 100644 --- a/aten/src/ATen/native/cuda/Resize.h +++ b/aten/src/ATen/native/cuda/Resize.h @@ -5,7 +5,7 @@ #include -namespace at { namespace native { +namespace at::native { TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); @@ -50,4 +50,4 @@ inline TensorImpl* resize_impl_cuda_( return self; } -}} +} diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index 7573dd943498..fe6fb2dba0b6 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -22,16 +22,15 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") #if defined(BUILD_ROWWISE_FP8_KERNEL) +#include #include #include #include #include #include -#include #include #include -#include #include #include @@ -51,101 +50,7 @@ C10_DIAGNOSTIC_POP() namespace { -using Strides = std::array; - -template < - typename DtypeA, - typename DtypeB, - typename DtypeOutput, - typename DtypeScale, - typename ProblemShape, - typename StrideA, - typename StrideB, - typename StrideOutput> -__global__ void prepare_gemm_data( - DtypeA* A, - DtypeB* B, - DtypeOutput* output, - DtypeScale* scale_A, - DtypeScale* scale_B, - DtypeA** A_ptrs, - DtypeB** B_ptrs, - DtypeOutput** output_ptrs, - DtypeScale** inputA_scale_ptrs, - DtypeScale** inputB_scale_ptrs, - ProblemShape* problem_sizes, - // Strides for cutlass, cute::Stride - StrideA* stride_A, - StrideB* stride_B, - StrideOutput* stride_output, - const int32_t* offs, - int32_t M, - int32_t N, - int32_t K, - // Original strides of the input tensors - Strides tensor_StrideA, - Strides tensor_StrideB, - Strides tensor_StrideOutput, - int64_t a_scale_stride, - int64_t b_scale_stride) { - int32_t tid = threadIdx.x; - int32_t delta = 0; - if (offs != nullptr) { - int32_t start = tid == 0 ? 0 : offs[tid - 1]; - delta = offs[tid] - start; - CUDA_KERNEL_ASSERT(delta % 16 == 0 && "expected dynamic dimension to be multiple of 16\n"); - } - int64_t lda, ldb, ldoutput; - if (M < 0) { - // A and output is 2d - M = delta; - lda = tensor_StrideA[0]; - ldb = tensor_StrideB[2]; // B is transposed - ldoutput = tensor_StrideOutput[0]; - A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * lda; - inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; - output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; - B_ptrs[tid] = B + tid * tensor_StrideB[0]; - inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; - } else if (N < 0) { - N = delta; - lda = tensor_StrideA[1]; - ldb = tensor_StrideB[1]; // B is transposed - ldoutput = tensor_StrideOutput[0]; - A_ptrs[tid] = A + tid * tensor_StrideA[0]; - inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; - output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; - B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * ldb; - inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; - } else if (K < 0) { - // A, B is 2d, output is 3d - K = delta; - lda = tensor_StrideA[0]; - ldb = tensor_StrideB[1]; // B is transposed - ldoutput = tensor_StrideOutput[1]; - A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1]; - B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1]; - inputA_scale_ptrs[tid] = scale_A + tid * M; - inputB_scale_ptrs[tid] = scale_B + tid * N; - output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; - } else { - // A, B, output are 3D - lda = tensor_StrideA[1]; - ldb = tensor_StrideB[2]; - ldoutput = tensor_StrideOutput[1]; - A_ptrs[tid] = A + tid * tensor_StrideA[0]; - B_ptrs[tid] = B + tid * tensor_StrideB[0]; - inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; - inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; - output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; - } - problem_sizes[tid] = ProblemShape(M, N, K); - - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {M, lda, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {N, ldb, 1}); - stride_output[tid] = - cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); -} +using Strides = at::cuda::detail::Strides; using DtypeScale = float; using DtypeAccum = float; @@ -205,7 +110,6 @@ struct Schedule { using ClusterShape = cute::Shape; }; - int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -257,8 +161,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( typename Schedule:: EpilogueSchedule; // TODO remove *BroadcastPtrArrays and replace with just Broadcast - // when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass version - // Implement rowwise scaling epilogue. + // when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass + // version Implement rowwise scaling epilogue. using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcastPtrArray< 0, TileShape, @@ -345,6 +249,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( group_count = mat_a.size(0); } + TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); + const int64_t problem_shape_size = group_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape)); @@ -383,7 +289,6 @@ void f8f8bf16_grouped_gemm_impl_sm90( reinterpret_cast( stride_output + group_count); - TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto make_strides = [](at::IntArrayRef strides) -> Strides { @@ -400,7 +305,7 @@ void f8f8bf16_grouped_gemm_impl_sm90( int64_t a_scale_stride = scale_a.stride(0); int64_t b_scale_stride = scale_b.stride(0); - prepare_gemm_data<<<1, group_count, 0, stream>>>( + at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( reinterpret_cast(mat_a.data_ptr()), reinterpret_cast(mat_b.data_ptr()), reinterpret_cast(out.data_ptr()), @@ -427,46 +332,50 @@ void f8f8bf16_grouped_gemm_impl_sm90( C10_CUDA_KERNEL_LAUNCH_CHECK(); -// auto buf_cpu = mat_a.new_empty( -// input_args_size, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); -// AT_CUDA_CHECK(cudaMemcpy( -// (char*)buf_cpu.data_ptr(), -// buf_ptr, -// input_args_size, -// cudaMemcpyDeviceToHost)); -// char* buf_ptr_cpu = (char*)buf_cpu.data_ptr(); -// DtypeA** inputA_ptrs_h = reinterpret_cast(buf_ptr_cpu); -// DtypeB** inputB_ptrs_h = -// reinterpret_cast(inputA_ptrs_h + aligned_group_count); -// DtypeOutput** output_ptrs_h = -// reinterpret_cast(inputB_ptrs_h + aligned_group_count); -// DtypeScale** inputA_scale_ptrs_h = -// reinterpret_cast(output_ptrs_h + aligned_group_count); -// DtypeScale** inputB_scale_ptrs_h = -// reinterpret_cast(inputA_scale_ptrs_h + aligned_group_count); -// StrideA* stride_A_h = -// reinterpret_cast(inputB_scale_ptrs_h + aligned_group_count); -// StrideB* stride_B_h = reinterpret_cast(stride_A_h + group_count); -// StrideOutput* stride_output_h = -// reinterpret_cast(stride_B_h + group_count); -// ProblemShape::UnderlyingProblemShape* problem_sizes_h = -// reinterpret_cast( -// stride_output_h + group_count); - -// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << " " -// << out.data_ptr() << " " << scale_a.data_ptr() << " " -// << scale_b.data_ptr() << "\n"; -// for (int i = 0; i < group_count; i++) { -// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n"; -// std::cout << "B " << (void*)inputB_ptrs_h[i] << "\n"; -// std::cout << "O " << (void*)output_ptrs_h[i] << "\n"; -// std::cout << "A_scale " << (void*)inputA_scale_ptrs_h[i] << "\n"; -// std::cout << "B_scale " << (void*)inputB_scale_ptrs_h[i] << "\n"; -// std::cout << "sizes " << problem_sizes_h[i] << "\n"; -// std::cout << "strideA" << stride_A_h[i] << "\n"; -// std::cout << "strideB" << stride_B_h[i] << "\n"; -// std::cout << "stride_output" << stride_output_h[i] << "\n"; -// } + // auto buf_cpu = mat_a.new_empty( + // input_args_size, + // at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + // AT_CUDA_CHECK(cudaMemcpy( + // (char*)buf_cpu.data_ptr(), + // buf_ptr, + // input_args_size, + // cudaMemcpyDeviceToHost)); + // char* buf_ptr_cpu = (char*)buf_cpu.data_ptr(); + // DtypeA** inputA_ptrs_h = reinterpret_cast(buf_ptr_cpu); + // DtypeB** inputB_ptrs_h = + // reinterpret_cast(inputA_ptrs_h + aligned_group_count); + // DtypeOutput** output_ptrs_h = + // reinterpret_cast(inputB_ptrs_h + aligned_group_count); + // DtypeScale** inputA_scale_ptrs_h = + // reinterpret_cast(output_ptrs_h + aligned_group_count); + // DtypeScale** inputB_scale_ptrs_h = + // reinterpret_cast(inputA_scale_ptrs_h + + // aligned_group_count); + // StrideA* stride_A_h = + // reinterpret_cast(inputB_scale_ptrs_h + + // aligned_group_count); + // StrideB* stride_B_h = reinterpret_cast(stride_A_h + + // group_count); StrideOutput* stride_output_h = + // reinterpret_cast(stride_B_h + group_count); + // ProblemShape::UnderlyingProblemShape* problem_sizes_h = + // reinterpret_cast( + // stride_output_h + group_count); + + // std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << " + // " + // << out.data_ptr() << " " << scale_a.data_ptr() << " " + // << scale_b.data_ptr() << "\n"; + // for (int i = 0; i < group_count; i++) { + // std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n"; + // std::cout << "B " << (void*)inputB_ptrs_h[i] << "\n"; + // std::cout << "O " << (void*)output_ptrs_h[i] << "\n"; + // std::cout << "A_scale " << (void*)inputA_scale_ptrs_h[i] << "\n"; + // std::cout << "B_scale " << (void*)inputB_scale_ptrs_h[i] << "\n"; + // std::cout << "sizes " << problem_sizes_h[i] << "\n"; + // std::cout << "strideA" << stride_A_h[i] << "\n"; + // std::cout << "strideB" << stride_B_h[i] << "\n"; + // std::cout << "stride_output" << stride_output_h[i] << "\n"; + // } // int device_id = 0; // cutlass::KernelHardwareInfo kernel_hw_info = // cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); @@ -484,7 +393,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( output_ptrs, stride_output}}; - int sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; + int sm_count = + at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index b2fd2dc85895..e2eb2226acf4 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -27,7 +27,8 @@ namespace at::native { constexpr int CAT_ARRAY_BATCH_SIZE = 128; constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4; -constexpr int ALIGNED_VEC_LOAD_BYTES = 16; +constexpr int ALIGNED_VEC_LOAD_BYTES_16 = 16; +constexpr int ALIGNED_VEC_LOAD_BYTES_8 = 8; namespace { @@ -72,14 +73,14 @@ inline std::tuple getCatGridRocm(unsigned int max_elements_per_tenso return std::make_tuple(grid, block); } -template +template inline std::tuple getCatGridContig(unsigned int max_elements_per_tensor, ptrdiff_t nTensors) { constexpr unsigned int threads_per_block = 128; constexpr unsigned int min_aligned_vec_per_thread = 1; constexpr unsigned int max_tb_per_sm = 32; - unsigned int elements_per_thread = ALIGNED_VEC_LOAD_BYTES / sizeof(T) * + unsigned int elements_per_thread = aligned_vec_load_bytes / sizeof(T) * min_aligned_vec_per_thread; unsigned int max_threads = ceil_div(max_elements_per_tensor, elements_per_thread); unsigned int thread_blocks = ceil_div(max_threads, threads_per_block); @@ -230,16 +231,19 @@ __global__ void CatArrayBatchedCopy_contig( to improve memory bandwidth throughput. */ -template -__global__ void CatArrayBatchedCopy_aligned16_contig( +template +__global__ void CatArrayBatchedCopy_alignedK_contig( T* output, CatArrInputTensorMetadata inputs, TensorSizeStride os, const int concatDim, IndexType dimStride) { - // This kernel tries to use 128 bit loads - constexpr int kILP = ALIGNED_VEC_LOAD_BYTES / sizeof(T); + // This kernel tries to use aligned_vec_load_bytes*8 bit loads + // Special case 2-byte types to use 8-byte vec loads to reduce register pressure + // The below lambda is to allow cc compiler to pass kILP>0 checks for large types (e.g. ComplexDouble, 16 bytes) + constexpr int kILP = aligned_vec_load_bytes / sizeof(T) > 0 ? aligned_vec_load_bytes / sizeof(T) : ALIGNED_VEC_LOAD_BYTES_16/sizeof(T); + IndexType inputOffset = (blockIdx.x * blockDim.x + threadIdx.x) * kILP; IndexType inputStride = gridDim.x * blockDim.x * kILP; @@ -349,7 +353,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i isAligned = false; #else // If at least one of the inputs is not aligned, we can't call the - // CatArrayBatchedCopy_aligned16_contig + // CatArrayBatchedCopy_alignedK_contig isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]); #endif @@ -385,7 +389,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i #else dim3 applyBlock, catGrid; if (isContig && sizeof(scalar_t) > 2) { - std::tie(catGrid, applyBlock) = getCatGridContig( + std::tie(catGrid, applyBlock) = getCatGridContig( + max_elements_per_tensor, batchCounter); + } else if (isContig && sizeof(scalar_t) == 2) { + std::tie(catGrid, applyBlock) = getCatGridContig( max_elements_per_tensor, batchCounter); } else { applyBlock = dim3(32 * 16); @@ -406,8 +413,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i } // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - if (isContig && isAligned && sizeof(scalar_t) >= 4 && sizeof(scalar_t) <= 8) {\ - CatArrayBatchedCopy_aligned16_contig<<<\ + if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\ + CatArrayBatchedCopy_alignedK_contig<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ + } else if (isContig && isAligned && sizeof(scalar_t) == 2) { \ + CatArrayBatchedCopy_alignedK_contig<<<\ catGrid, applyBlock, 0, stream.stream()>>>(\ data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ } else if (isContig) {\ diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index dcc9237d737e..7fc3947879f4 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -135,16 +135,7 @@ template using VecT = T __attribute__((ext_vector_type(Rank))); static bool isCDNA2orLater(int index) { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); - std::string device_arch = prop->gcnArchName; - static const std::vector archs = {"gfx90a", "gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index); } #else diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 0d63a2f979c9..ee573e2e566f 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( } } + template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, @@ -657,6 +658,7 @@ bool aligned_grid > __global__ void +__launch_bounds__(block_dim_x * block_dim_y) GammaBetaBackwardCUDAKernelTemplate( int64_t M, int64_t N, diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h index 5e1f255ebe08..4ab411d9a025 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h @@ -36,8 +36,8 @@ // The current pytorch implementation sets gesvdj tolerance to epsilon of a C++ data type to target the best possible precision. constexpr int cusolver_gesvdj_max_sweeps = 400; -namespace at { -namespace native { + +namespace at::native { void geqrf_batched_cublas(const Tensor& input, const Tensor& tau); void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular); @@ -90,4 +90,4 @@ C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&); }} // namespace cuda::detail #endif -}} // namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/RNNUtils.h b/aten/src/ATen/native/cudnn/RNNUtils.h index 7e2869a80574..841164622172 100644 --- a/aten/src/ATen/native/cudnn/RNNUtils.h +++ b/aten/src/ATen/native/cudnn/RNNUtils.h @@ -6,9 +6,8 @@ #include // Declares utilities used by RNN.cpp and also needed by external consumers -namespace at { -namespace native { -namespace cudnn_rnn { + +namespace at::native::cudnn_rnn { TORCH_CUDA_CPP_API std::tuple> copy_weights_to_flat_buf_views( @@ -27,6 +26,4 @@ copy_weights_to_flat_buf_views( bool allow_type_change = false, bool include_bias = true); -} // namespace cudnn_rnn -} // namespace native -} // namespace at +} // namespace at::native::cudnn_rnn diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index 14756167b142..552f0de84541 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -598,9 +598,7 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) { template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { + if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half)); } else{ dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half)); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index cc5739825d7e..03ef7ce450c1 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -20,7 +20,7 @@ #endif #endif -namespace at { namespace native { +namespace at::native { // Mapping ScalarType to ideep tensor data_type TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type); @@ -62,6 +62,6 @@ TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_cons // Set MKLDNN verbose level TORCH_API int set_verbose(int level); -}} +} #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index cc3d4ec9555d..d2abeda0e6ff 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -418,4 +418,53 @@ TORCH_IMPL_FUNC(addmv_out_xpu) xpu::addmv_out(self, mat, vec, beta, alpha, const_cast(result)); } +Tensor _weight_int4pack_mm_xpu( + const Tensor& A, + const Tensor& B, + int64_t qGroupSize, + const Tensor& qScale, + const Tensor& qZeros) { + auto M = A.size(0); // M + auto N = B.size(0); // N1=LCM(N, K) + TORCH_CHECK( + A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, + __func__, + " : expect A to be either 32-bit or 16-bit float tensor."); + TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous."); + TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor."); + + TORCH_CHECK(B.dtype() == kInt, __func__, " : expect B to be int32 tensor."); + TORCH_CHECK( + qZeros.dtype() == kChar, + __func__, + " : expect qZeros to be int8 tensor currently."); + TORCH_CHECK(B.dim() == 2, __func__, " : expect B to 2d tensor."); + + TORCH_CHECK( + qGroupSize > 1 && qGroupSize % 32 == 0, + __func__, + " : expect qGroupSize to be multiple of 32 and greater than 1, got ", + qGroupSize); + + TORCH_CHECK( + qScale.dim() == 2 && qScale.size(1) == N, + __func__, + ": expect qScale to be 2d tensor with sizes [:, ", + N, + "]"); + TORCH_CHECK( + qZeros.dim() == 2 && qZeros.size(1) == N, + __func__, + ": expect qZeros to be 2d tensor with sizes [:, ", + N, + "]"); + + auto C = at::empty({M, N}, A.options()); + + // qscale:[K/qGroupSize, N] + // qzp:[K/qGroupSize, N] + at::native::onednn::woq_matmul_int4(C, A, B, qScale, qZeros, qGroupSize); + + return C; +} } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h index df14020466f5..eb09d37c4b75 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h @@ -131,7 +131,7 @@ struct PostOpParam { class Attr { public: - Attr() : q_scale_(1.f), q_zero_point_(0) {} + Attr() : q_scale_(1.f) {} Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} /***** eltwise *****/ diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp new file mode 100644 index 000000000000..66d4ffaa9b8a --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp @@ -0,0 +1,179 @@ +#include + +#include +#include + +#include +#include + +namespace at::native::onednn { + +void woq_matmul_int4( + Tensor& result, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 + const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 + const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 + int64_t group_size) { + size_t dims = result.dim(); + TORCH_CHECK( + dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); + TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); + + at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device()); + TORCH_CHECK( + cur_device == mat1_.device(), + "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); + auto& engine = GpuEngineManager::Instance().get_engine(); + auto& stream = GpuStreamManager::Instance().get_stream(); + + Tensor m1 = mat1_; + Tensor m2 = mat2_; + Tensor scale_ = scale; + Tensor zp_ = zp; + Tensor dst = result; + + int m = m1.size(-2); // M + int n = dst.size(-1); // N + int k = m1.size(-1); // K + + // Construct usr md from input + // xxx_usr_md would describe the real layout of inputs + auto m1_usr_dt = get_onednn_dtype(m1); // e.g., half <==> f16 + auto m2_usr_dt = get_onednn_dtype(m2); // int32 tensor, pack 8 int4 + auto scale_usr_dt = get_onednn_dtype(scale_); // bf16 + auto zp_usr_dt = get_onednn_dtype(zp_); // s8 expected currently + auto dst_usr_dt = get_onednn_dtype(dst); // bf16 + + dnnl::memory::dims m1_usr_dims, m2_usr_dims, scale_usr_dims, zp_usr_dims, + dst_usr_dims; + dnnl::memory::dims m1_usr_strides, m2_usr_strides, scale_usr_strides, + zp_usr_strides, dst_usr_strides; + int compressed_k = (int)(k / 8); + int num_groups = (int)(k / group_size); + m1_usr_dims = {m, k}; + m1_usr_strides = {m1.stride(0), m1.stride(1)}; + m2_usr_dims = {compressed_k, n}; + m2_usr_strides = {1, compressed_k}; // k dim contiguous, 4bit pack into s32 + + scale_usr_dims = {num_groups, n}; + scale_usr_strides = {n, 1}; + zp_usr_dims = {num_groups, n}; + zp_usr_strides = {n, 1}; + dst_usr_dims = {m, n}; + dst_usr_strides = {dst.stride(0), dst.stride(1)}; + + dnnl::memory::desc m1_usr_md, m2_usr_md, scale_usr_md, zp_usr_md, dst_usr_md; + + m1_usr_md = dnnl::memory::desc(m1_usr_dims, m1_usr_dt, m1_usr_strides); + m2_usr_md = dnnl::memory::desc(m2_usr_dims, m2_usr_dt, m2_usr_strides); + scale_usr_md = + dnnl::memory::desc(scale_usr_dims, scale_usr_dt, scale_usr_strides); + zp_usr_md = dnnl::memory::desc(zp_usr_dims, zp_usr_dt, zp_usr_strides); + dst_usr_md = dnnl::memory::desc(dst_usr_dims, dst_usr_dt, dst_usr_strides); + + // create usr memory + auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr()); + auto scale_usr_m = make_onednn_memory(scale_usr_md, engine, scale.data_ptr()); + auto zp_usr_m = make_onednn_memory(zp_usr_md, engine, zp.data_ptr()); + + // Construct md for primitive creation + // The xxx_md describes what kinds of matmul the oneDNN does. + // The problem for this op is [m, k] x [k, n] => [m, n] matmul. + auto m1_dt = m1_usr_dt; // bf16 + // Tell oneDNN the weight dtype we want manipulate is u4, + // library needs infer how to unpack u4 data based on the m2_usr_md (s32). + auto m2_dt = dnnl::memory::data_type::u4; + auto scale_dt = scale_usr_dt; // bf16 + // Tell oneDNN the zp dtype we want manipulate is s8 + // library needs infer how to unpack s8 data based on the m2_usr_md. + auto zp_dt = zp_usr_dt; // should be s8, currently + auto dst_dt = dst_usr_dt; + + dnnl::memory::desc m1_md, m2_md, scale_md, zp_md, dst_md; + dnnl::memory::dims m1_dims, m2_dims, scale_dims, zp_dims, dst_dims; + dnnl::memory::dims m1_strides, m2_strides, scale_strides, zp_strides, + dst_strides; + + m1_dims = m1_usr_dims; // {m, k} + m1_strides = m1_usr_strides; // {k, 1} + m2_dims = {k, n}; + m2_strides = {n, 1}; + scale_dims = scale_usr_dims; // {k//group_size, n} + scale_strides = scale_usr_strides; + zp_dims = zp_usr_dims; + zp_strides = zp_usr_strides; + dst_dims = dst_usr_dims; + dst_strides = dst_usr_strides; + + m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides); + m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides); + scale_md = dnnl::memory::desc(scale_dims, scale_dt, scale_strides); + zp_md = dnnl::memory::desc(zp_dims, zp_dt, zp_strides); + dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides); + + std::unordered_map args; + + dnnl::matmul matmul_p; + dnnl::matmul::primitive_desc matmul_pd; + + auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr()); + auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr()); + + void* handle_b = m2_usr_m.get_data_handle(); + // reinterpret m2_usr_memory as u4 + dnnl::memory m2_u4_m( + {{k, n}, dnnl::memory::data_type::u4, dnnl::memory::format_tag::ba}, + engine, + handle_b); + + dnnl::primitive_attr pattr; + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#if ONEDNN_SUPPORT_DETERMINISTIC + if (at::globalContext().deterministicAlgorithms() || + at::globalContext().deterministicMkldnn()) { + pattr.set_deterministic(true); + } +#endif + + // Set scales with multiple scales along K dimension and with groups along K. + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 0) + (1 << 1), + {group_size, 1}, + scale_dt); + // Set a single zero point with s8 data type. + pattr.set_zero_points( + DNNL_ARG_WEIGHTS, + (1 << 0) + (1 << 1), + {group_size, 1}, + dnnl::memory::data_type::s8); + + if (m1_dt == dnnl::memory::data_type::f16) + pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); + else if (m1_dt == dnnl::memory::data_type::bf16) + pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true); + + matmul_pd = dnnl::matmul::primitive_desc( + engine, m1_md, m2_u4_m.get_desc(), dst_md, pattr); + matmul_p = dnnl::matmul(matmul_pd); + + dnnl::memory m1_m = m1_usr_m, m2_m = m2_u4_m, dst_m = dst_usr_m; + dnnl::memory scale_m = scale_usr_m; // zp_m = zp_u4_m; + Tensor m1_, m2_, zp_new, dst_; + + int scratchpad_size = matmul_pd.scratchpad_desc().get_size(); + Tensor scratchpad_tensor = + at::empty({scratchpad_size}, m1.options().dtype(at::kByte), c10::nullopt); + auto scratchpad_memory = make_onednn_memory( + matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr()); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory}); + + args.insert({DNNL_ARG_SRC, m1_m}); + args.insert({DNNL_ARG_WEIGHTS, m2_u4_m}); + args.insert({DNNL_ARG_DST, dst_m}); + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_m}); + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m}); + dnnl::sycl_interop::execute(matmul_p, stream, args); +} +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index a4f993eebcd6..9d8e9fe50df5 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -89,6 +89,14 @@ TORCH_API sycl::event deconvolution_backward_weights( int64_t groups, const std::vector& deps = {}); +TORCH_API void woq_matmul_int4( + at::Tensor& result, // dst, [M, N] + const at::Tensor& mat1_, // src, [M, K] + const at::Tensor& mat2_, // quantized weight, [K/8, N] + const at::Tensor& scale, // [K/group_size, N] + const at::Tensor& zp, // [k/group_size, N] + int64_t group_size); + dnnl::memory::dims conv_dst_size( int64_t ndim, IntArrayRef src_tz, diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index dff66520ccfc..acd2bf66101f 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -46,9 +46,12 @@ constexpr bool has_size_type_v = has_size_type::value; } // namespace detail +// Returns `gpuAddress` of respective `id` plus storage offset +void* get_tensor_gpu_address(const at::TensorBase&); + class MetalKernelFunction { public: - MetalKernelFunction(MTLComputePipelineState_t cps_); + MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_); ~MetalKernelFunction(); MetalKernelFunction(MetalKernelFunction&) = delete; // Shader properties @@ -56,7 +59,7 @@ class MetalKernelFunction { uint64_t getThreadExecutionWidth() const; uint64_t getStaticThreadGroupMemoryLength() const; void runCommandBlock(std::function f); - // Methods below should be called from runCommandBlock functionT + // Methods below should be called from runCommandBlock function void startEncoding(); void setArg(unsigned idx, const at::TensorBase& t); void setArg(unsigned idx, const void* ptr, uint64_t size); @@ -88,6 +91,7 @@ class MetalKernelFunction { private: MTLComputePipelineState_t cps; + MTLFunction_t func; MTLComputeCommandEncoder_t encoder = nullptr; }; diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 9655988e082a..57fa278b01d8 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -922,7 +922,8 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} } std::shared_ptr MetalShaderLibrary::getKernelFunction(const std::string& name) { - return std::make_shared(getPipelineStateForFunc(name)); + auto [cpl, func] = getLibraryPipelineState(getLibrary(), name); + return std::make_shared(cpl, func); } class BundledShaderLibary : public MetalShaderLibrary { @@ -1088,10 +1089,12 @@ static dispatch_data_t getSectionData(const std::string& name) { } // MetalKernelFunction implementation -MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_) : cps([cps_ retain]) {} +MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_) + : cps([cps_ retain]), func([f_ retain]) {} MetalKernelFunction::~MetalKernelFunction() { [cps release]; + [func release]; } void MetalKernelFunction::runCommandBlock(std::function run) { @@ -1152,6 +1155,10 @@ static dispatch_data_t getSectionData(const std::string& name) { return [cps staticThreadgroupMemoryLength]; } +void* get_tensor_gpu_address(const at::TensorBase& t) { + return reinterpret_cast(getMTLBufferStorage(t).gpuAddress + t.storage_offset() * t.element_size()); +} + } // namespace at::native::mps // Check that c10::metal::ScalarType is strict subset (with matching values) of c10::ScalarType diff --git a/aten/src/ATen/native/mps/kernels/Amp.metal b/aten/src/ATen/native/mps/kernels/Amp.metal new file mode 100644 index 000000000000..abe852798f44 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Amp.metal @@ -0,0 +1,130 @@ +#include +using namespace metal; + +constant constexpr unsigned kmaxThreadGroups = 32; +constant constexpr unsigned kmaxTensors = 32; +constant constexpr unsigned kChunkSize = 65536; + +template +struct AmpNonFiniteCheckAndUnscaleArgs { + metal::array data [[id(0)]]; +}; + +struct MetadataArguments { + ulong numels[kmaxTensors]; + ulong threadgroup_to_tensor[kmaxThreadGroups]; + ulong threadgroup_to_chunk[kmaxThreadGroups]; +}; + +template +kernel void ampNonFiniteCheckAndUnscale( + constant AmpNonFiniteCheckAndUnscaleArgs& pointerArgs [[buffer(0)]], + constant MetadataArguments& metadata [[buffer(1)]], + device float& foundInf [[buffer(2)]], + constant T& invScale [[buffer(3)]], + uint local_tid [[thread_position_in_threadgroup]], + uint tgSize [[threads_per_threadgroup]], + uint group_id [[threadgroup_position_in_grid]]) { + uint threadGroupSize = tgSize; + uint tensor_index = metadata.threadgroup_to_tensor[group_id]; + uint chunk = metadata.threadgroup_to_chunk[group_id]; + uint numel = metadata.numels[tensor_index]; + + uint offset = chunk * kChunkSize; + uint chunk_size = + ((offset + kChunkSize) > numel) ? (numel - offset) : kChunkSize; + + device T* data = pointerArgs.data[tensor_index]; + + for (uint i = local_tid; i < chunk_size; i += threadGroupSize) { + uint index = offset + i; + T val = data[index]; + if (!isfinite(val)) { + foundInf = 1.0f; + } + data[index] = (invScale == static_cast(1.0) ? val : val * invScale); + } +} + +template +kernel void ampNonFiniteCheckAndUnscaleSingle( + device T* data [[buffer(0)]], + device float& foundInf [[buffer(1)]], + constant T& invScale [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + T val = data[tid]; + if (!isfinite(val)) { + foundInf = 1.0f; + } + data[tid] = (invScale == T(1.0) ? val : val * invScale); +} + +template +kernel void ampUpdateScale( + device T& scale [[buffer(0)]], + device int& growth_tracker [[buffer(1)]], + device float& foundInf [[buffer(2)]], + constant T& scaleGrowthFactor [[buffer(3)]], + constant T& scaleBackoffFactor [[buffer(4)]], + constant int& growthInterval [[buffer(5)]]) { + if (foundInf != 0.0f) { + scale *= scaleBackoffFactor; + growth_tracker = 0; + } else { + int g = growth_tracker + 1; + if (g >= growthInterval) { + scale *= scaleGrowthFactor; + g = 0; + } + growth_tracker = g; + } +} + +#define INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(DTYPE) \ + template [[host_name("ampNonFiniteCheckAndUnscale_" #DTYPE)]] kernel void \ + ampNonFiniteCheckAndUnscale( \ + constant AmpNonFiniteCheckAndUnscaleArgs & \ + pointerArgs [[buffer(0)]], \ + constant MetadataArguments & metadata [[buffer(1)]], \ + device float& foundInf [[buffer(2)]], \ + constant DTYPE& invScale [[buffer(3)]], \ + uint local_tid [[thread_position_in_threadgroup]], \ + uint tgSize [[threads_per_threadgroup]], \ + uint group_id [[threadgroup_position_in_grid]]) + +#define INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(DTYPE) \ + template \ + [[host_name("ampNonFiniteCheckAndUnscaleSingle_" #DTYPE)]] kernel void \ + ampNonFiniteCheckAndUnscaleSingle( \ + device DTYPE * data [[buffer(0)]], \ + device float& foundInf [[buffer(1)]], \ + constant DTYPE& invScale [[buffer(2)]], \ + uint tid [[thread_position_in_grid]]) + +#define INSTANTIATE_AMP_UPDATE_SCALE(DTYPE) \ + template [[host_name("ampUpdateScale_" #DTYPE)]] kernel void \ + ampUpdateScale( \ + device DTYPE & scale [[buffer(0)]], \ + device int& growth_tracker [[buffer(1)]], \ + device float& foundInf [[buffer(2)]], \ + constant DTYPE& scaleGrowthFactor [[buffer(3)]], \ + constant DTYPE& scaleBackoffFactor [[buffer(4)]], \ + constant int& growthInterval [[buffer(5)]]) + +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float); +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat); +#endif + +INSTANTIATE_AMP_UPDATE_SCALE(float); +INSTANTIATE_AMP_UPDATE_SCALE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_UPDATE_SCALE(bfloat); +#endif + +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); +#endif diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 12b78f32e96e..eb2a038b16be 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -75,6 +75,13 @@ struct chebyshev_polynomial_w_functor { } }; +struct hermite_polynomial_h_functor { + template + inline T operator()(const T a, const T b) { + return static_cast(c10::metal::hermite_polynomial_h_forward(a, b)); + } +}; + struct nextafter_functor { #if __METAL_VERSION__ < 310 template @@ -164,6 +171,8 @@ REGISTER_BINARY_OP(chebyshev_polynomial_v, float, float); REGISTER_BINARY_OP(chebyshev_polynomial_v, half, half); REGISTER_BINARY_OP(chebyshev_polynomial_w, float, float); REGISTER_BINARY_OP(chebyshev_polynomial_w, half, half); +REGISTER_BINARY_OP(hermite_polynomial_h, float, float); +REGISTER_BINARY_OP(hermite_polynomial_h, half, half); #if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(copysign, bfloat, bfloat); @@ -176,6 +185,7 @@ REGISTER_BINARY_OP(chebyshev_polynomial_t, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_u, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat); +REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat); #endif // Complex binary functions diff --git a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal index 2006e768d826..fe5605226748 100644 --- a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal +++ b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal @@ -57,9 +57,9 @@ struct SgdMomentumArguments { }; struct MetadataArguments { - uint32_t numels[kmaxTensors]; - uint32_t threadgroup_to_tensor[kmaxThreadGroups]; - uint32_t threadgroup_to_chunk[kmaxThreadGroups]; + ulong numels[kmaxTensors]; + ulong threadgroup_to_tensor[kmaxThreadGroups]; + ulong threadgroup_to_chunk[kmaxThreadGroups]; }; enum ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; diff --git a/aten/src/ATen/native/mps/operations/Amp.mm b/aten/src/ATen/native/mps/operations/Amp.mm new file mode 100644 index 000000000000..e410d434ec7a --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Amp.mm @@ -0,0 +1,132 @@ +// Copyright © 2022 Apple Inc. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::native { +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif +namespace mps { + +static void _amp_non_finite_check_and_unscale_mps_single_impl(const Tensor& scaled_grad, + at::Tensor& found_inf, + const at::Tensor& inv_scale) { + if (scaled_grad.numel() == 0) { + return; + } + TORCH_CHECK(scaled_grad.is_mps(), "Tensor is not on the MPS device."); + TORCH_CHECK(scaled_grad.numel() <= std::numeric_limits::max(), "scaled_grad is too large"); + float inv_scale_val = inv_scale.item(); + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto ampPipelineState = + lib.getPipelineStateForFunc("ampNonFiniteCheckAndUnscaleSingle_" + mps::scalarToMetalTypeString(scaled_grad)); + + const uint32_t threadsPerThreadgroup = 256; + uint32_t numel = static_cast(scaled_grad.numel()); + MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize gridSize = MTLSizeMake(numel, 1, 1); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:ampPipelineState]; + mtl_setArgs(computeEncoder, scaled_grad, found_inf, inv_scale_val); + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; + }); +} + +static void _amp_update_scale_mps_impl(Tensor& self, + Tensor& growth_tracker, + const Tensor& found_inf, + float scale_growth_factor, + float scale_backoff_factor, + int32_t growth_interval) { + auto stream = getCurrentMPSStream(); + auto ampUpdatePipelineState = lib.getPipelineStateForFunc("ampUpdateScale_" + mps::scalarToMetalTypeString(self)); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:ampUpdatePipelineState]; + + mtl_setArgs( + computeEncoder, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + mtl_dispatch1DJob(computeEncoder, ampUpdatePipelineState, 1); + }); +} + +std::pair, id> getAmpCPLState(const std::string& fname) { + return {lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname)}; +} +} // namespace mps + +void _amp_foreach_non_finite_check_and_unscale_mps_(at::TensorList self, + at::Tensor& found_inf, + const at::Tensor& inv_scale) { + if (self.size() == 0) { + return; + } + TORCH_CHECK(inv_scale.is_mps(), "inv_scale must be a MPS tensor."); + TORCH_CHECK(found_inf.is_mps(), "found_inf must be a MPS tensor."); + TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + // Ensures client code (GradScaler) filtered scaled_grads by API restrictions. + check_foreach_api_restrictions(self); + + // Prepare a vector of tensor lists. + std::vector> tensor_lists; + if (can_use_fast_route(self)) { + TORCH_CHECK(self[0].is_mps(), "scaled_grads must be MPS tensors."); + tensor_lists.emplace_back(self.vec()); + } else { + tensor_lists.resize(1); + tensor_lists[0].reserve(self.size()); + auto expected_device = self[0].device(); + const auto expected_dtype = self[0].scalar_type(); + for (const at::Tensor& t : self) { + // Ensure that GradScaler has filtered by device, layout, and dtype. + TORCH_CHECK(t.is_mps(), "one of scaled_grads was not a MPS tensor."); + TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); + TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); + if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { + // Fall back to the single-tensor implementation + mps::_amp_non_finite_check_and_unscale_mps_single_impl(const_cast(t), found_inf, inv_scale); + } else { + tensor_lists[0].push_back(t); + } + } + if (tensor_lists[0].empty()) { + return; + } + } + + std::string kernel_name = + "ampNonFiniteCheckAndUnscale_" + mps::scalarToMetalTypeString(tensor_lists[0][0].scalar_type()); + mps::multi_tensor_apply<1>(kernel_name, tensor_lists, found_inf, inv_scale); +} + +Tensor& _amp_update_scale_mps_(Tensor& self, + Tensor& growth_tracker, + const Tensor& found_inf, + double scale_growth_factor, + double scale_backoff_factor, + int64_t growth_interval) { + mps::_amp_update_scale_mps_impl( + self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + return self; +} +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 92f3aa011ae0..35a3ec81ca07 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -110,6 +110,12 @@ static void chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "chebyshev_polynomial_w"); } +static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "hermite_polynomial_h_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "hermite_polynomial_h"); +} + static void polar_mps_kernel(TensorIterator& iter) { lib.exec_binary_kernel(iter, "polar"); } @@ -128,6 +134,7 @@ static void complex_mps_kernel(TensorIterator& iter) { REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel) +REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel) REGISTER_DISPATCH(polar_stub, &polar_mps_kernel); REGISTER_DISPATCH(complex_stub, &complex_mps_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm b/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm index 538d04fed999..4057f2dcbac5 100644 --- a/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm +++ b/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm @@ -114,8 +114,6 @@ void _fused_sgd_kernel_mps_(TensorList params, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { - TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); - if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, @@ -163,8 +161,6 @@ void _fused_sgd_kernel_mps_(TensorList params, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { - TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); - if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm deleted file mode 100644 index 5574df89afe5..000000000000 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ /dev/null @@ -1,61 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - -namespace at::native { - -TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { - TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); - TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!"); - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { - TORCH_WARN_ONCE( - "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU."); - auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt); - auto cpu_result = result.to("cpu"); - at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu")); - info.copy_(cpu_info); - result.copy_(cpu_result); - return; - } - - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - MPSStream* stream = getCurrentMPSStream(); - info.zero_(); - - if (A.numel() == 0) { - return; - } - - if (!result.is_contiguous()) { - result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); - } - - @autoreleasepool { - string key = "inv_out_mps" + getTensorsStringKey({A}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A); - MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - auto feeds = dictionaryFromPlaceholders(inputPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } -} - -} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 22aee2307f69..1a9e841cfbcf 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -2,6 +2,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include // For MTLLanguageVersion_3_1 @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, } } -static void linalg_solve_out_mps_impl(const at::Tensor& A, - const at::Tensor& B, +static void linalg_solve_out_mps_impl(const Tensor& A, + const Tensor& B, bool left, bool check_errors, - const at::Tensor& result, - const at::Tensor& LU, - const at::Tensor& pivots, - const at::Tensor& info) { + const Tensor& result, + const Tensor& LU, + const Tensor& pivots, + const Tensor& info) { using namespace mps; TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), @@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A, } } +static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + using namespace mps; + TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); + TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!"); + using CachedGraph = MPSUnaryCachedGraph; + + MPSStream* stream = getCurrentMPSStream(); + info.zero_(); + + if (A.numel() == 0) { + return; + } + + if (!result.is_contiguous()) { + result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); + } + auto A_sizes = A.sizes(); + int ndim = A.dim(); + + Tensor LU = empty_like(A); + Tensor identity = zeros_like(A); + Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt)); + (ndim == 2 ? identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1); + linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info); +} + static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { using namespace mps; static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); @@ -1427,4 +1455,8 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) { mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors); } + +TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info); +} } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h index cb8d65a129c5..2897d643648a 100644 --- a/aten/src/ATen/native/mps/operations/MultiTensorApply.h +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -11,10 +11,10 @@ static constexpr int64_t kChunkSize = 65536; static constexpr int64_t kmaxThreadGroups = 32; static constexpr int64_t kmaxTensors = 32; -struct MetadataArguments { // the size of this struct must be less than 4 bytes - uint numels[kmaxTensors]; - uint threadgroup_to_tensor[kmaxThreadGroups]; - uint threadgroup_to_chunk[kmaxThreadGroups]; +struct MetadataArguments { // the size of this struct must be less than 4 kilobytes + uint64_t numels[kmaxTensors]; + uint64_t threadgroup_to_tensor[kmaxThreadGroups]; + uint64_t threadgroup_to_chunk[kmaxThreadGroups]; }; struct FusedAdamEncodingFunctor { @@ -253,4 +253,110 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam }); } +std::pair, id> getAmpCPLState(const std::string& fname); +template +void multi_tensor_apply(const std::string& kernel_name, + std::vector>& tensor_lists, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists must match depth."); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [pipeline, function] = getAmpCPLState(kernel_name); + [computeEncoder setComputePipelineState:pipeline]; + + id argumentEncoder = [function newArgumentEncoderWithBufferIndex:0]; + auto tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int tensor_loc = 0; + int threadgroup_loc = 0; + MetadataArguments metadata_arguments; + std::memset(&metadata_arguments, 0, sizeof(metadata_arguments)); + + for (size_t t = 0; t < num_tensors; t++) { + if (tensor_lists[0][t].numel() == 0) + continue; + + // bind each tensor in this list to the correct slots across depths + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + + // save number of elements for this tensor + metadata_arguments.numels[tensor_loc] = tensor_lists[0][t].numel(); + int currentTensorIndex = tensor_loc; + tensor_loc++; + + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + ((numel % kChunkSize) ? 1 : 0); + + // process tensor in chunks based on max chunk size + for (uint chunk = 0; chunk < chunks; chunk++) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = currentTensorIndex; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + threadgroup_loc++; + + // dispatch when we've filled the threadgroup array or finished the chunks + const bool dispatch_now = (threadgroup_loc == kmaxThreadGroups) || (chunk == chunks - 1); + if (dispatch_now) { + // check for a partial dispatch (i.e. more chunks remain for the current tensor) + bool partial = (chunk != chunks - 1); + uint carried_numels = 0; + if (partial) { + carried_numels = metadata_arguments.numels[currentTensorIndex]; + } + + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, (uint32_t)64), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // prepare for the next batch: reset threadgroup count and create a new buffer + threadgroup_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + if (partial) { + // for a partial dispatch, rebind the partially processed tensor to slot 0 + // so that its metadata is in the correct location + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + 0); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + metadata_arguments.numels[0] = carried_numels; + // the currently processed tensor now lives at index 0 + currentTensorIndex = 0; + tensor_loc = 1; + } else { + tensor_loc = 0; + } + } + } + } + + if (threadgroup_loc != 0) { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, static_cast(64)), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + } + }); +} + } // namespace at::native::mps diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1336122e2fd6..48be73ac5eea 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4165,6 +4165,10 @@ MPS: _weight_int4pack_mm_mps CUDA: _weight_int4pack_mm_cuda +- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor + dispatch: + XPU: _weight_int4pack_mm_xpu + # Split int4 pack weight between cpu and other devices due to # https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756. - func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor @@ -7076,6 +7080,11 @@ dispatch: CUDA: _scaled_grouped_mm_cuda +- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor + variants: function + dispatch: + CUDA: _grouped_mm_cuda + # NOTE [ Sparse: autograd and API ] # # @@ -10389,6 +10398,7 @@ dispatch: CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ CPU: _amp_foreach_non_finite_check_and_unscale_cpu_ + MPS: _amp_foreach_non_finite_check_and_unscale_mps_ autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out - func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) @@ -10396,6 +10406,7 @@ dispatch: CUDA: _amp_update_scale_cuda_ CPU: _amp_update_scale_cpu_ + MPS: _amp_update_scale_mps_ autogen: _amp_update_scale, _amp_update_scale.out #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor @@ -15262,7 +15273,7 @@ - func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_hermite_polynomial_h_out + CPU, CUDA, MPS: special_hermite_polynomial_h_out python_module: special structured_inherits: TensorIteratorBase structured: True diff --git a/aten/src/ATen/native/quantized/AffineQuantizer.cpp b/aten/src/ATen/native/quantized/AffineQuantizer.cpp index 6bd9bfd687aa..dab9e1cf7fc9 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizer.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizer.cpp @@ -151,6 +151,7 @@ Tensor& quantize_tensor_per_channel_affine( AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); if (qtensor.device().type() != c10::DeviceType::CUDA && + qtensor.device().type() != c10::DeviceType::XPU && qtensor.device().type() != c10::DeviceType::PrivateUse1) { checkZeroPoints(fn_name, zero_points); } // for cuda and privateuse1, this check will occur in the actual device function @@ -242,6 +243,7 @@ Tensor& dequantize_tensor_per_channel_affine( AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); if(qtensor.device().type() != c10::DeviceType::CUDA && + qtensor.device().type() != c10::DeviceType::XPU && qtensor.device().type() != c10::DeviceType::PrivateUse1){ checkZeroPoints(fn_name, zero_points); } // for cuda and privateuse1, this check will occur in the actual device function diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 46b58e9a38a8..06196043e08d 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1758,23 +1758,26 @@ namespace at::native { std::optional algorithm) { #if AT_MKLDNN_ENABLED() - if (act.dim() == 3 || act.dim() == 5) { - // Conv1D/3D post op check - TORCH_CHECK( - attr == "none", - "quantized pointwise conv", - act.dim()-2, - "d doesn't support unary_post_op fusion. Got unary_post_op: ", - attr, - ".") - } else { - // Conv2D post op check - TORCH_CHECK( - attr == "none" || attr == "relu" || attr == "hardtanh" || attr == "hardswish" || attr == "swish", - "none post_op or post_op relu/hardtanh/hardswish is supported for quantized pointwise conv2d. Got unary_post_op: ", - attr, - ".") + std::vector supported_postop = { + "none" + }; + if (act.dim() == 3) { + // Conv1D post op + supported_postop.push_back("relu"); + } else if (act.dim() == 4) { + // Conv2D post op + supported_postop.push_back("relu"); + supported_postop.push_back("hardtanh"); + supported_postop.push_back("hardswish"); + supported_postop.push_back("swish"); } + TORCH_CHECK( + std::find(supported_postop.begin(), supported_postop.end(), attr) != supported_postop.end(), + "Unsupported post op ", + attr, + " for quantized pointwise conv", + act.dim()-2, + "d.") return _quantized_convolution_onednn( act, act_scale, act_zero_point, weight, weight_scales, weight_zero_points, @@ -2079,6 +2082,8 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise"), at::native::QConvoneDNN::run_pointwise); m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor); m.impl(TORCH_SELECTIVE_NAME("onednn::qconv3d_pointwise"), at::native::QConvoneDNN::run_pointwise); + m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise"), at::native::QConvoneDNN::run_pointwise); + m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor); // Conv2D with binary postop m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"), at::native::QConvoneDNN::run_pointwise_binary); diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 27c484c62bb9..8a70fbffc00d 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -258,6 +258,8 @@ TORCH_LIBRARY(onednn, m) { m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); // Conv2D with binary postop m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu index e74d71fe1aff..75d4e8c75c9b 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu @@ -245,7 +245,7 @@ Tensor two_four_sgemm( ElementC(0), {cute::_1{}, cute::_0{}, problem_size.m()}}; } else { - return {ElementC(0)}; + return {{ElementC(0)}}; } }() }; diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu index 35d6559b62ce..47a19d26342e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu @@ -209,7 +209,7 @@ void spgemm_cutlass( std::is_same_v) { return {ElementComputeEpilogue{alpha.to()}}; } else { - return {alpha.to()}; + return {{alpha.to()}}; } }() }; @@ -219,7 +219,7 @@ void spgemm_cutlass( std::is_same_v) { return {ElementComputeEpilogue{beta.to()}}; } else { - return {beta.to()}; + return {{beta.to()}}; } }() }; @@ -230,7 +230,7 @@ void spgemm_cutlass( ElementC(0), {cute::_1{}, cute::_0{}, problem_size.m()}}; } else { - return {ElementC(0)}; + return {{ElementC(0)}}; } }() }; diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index ff4a7730fcc5..948cbe0f4028 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -42,19 +42,25 @@ desc: | This tag indicates if an operator doesn't guarantee bitwise equivalence across different runs of an operator with identical inputs. +- tag: needs_exact_strides + desc: | + This tag indicates that the operator should be passed Tensors following + the same strides as observed in eager when compiled in inductor. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. - tag: needs_fixed_stride_order desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. - Only one of {needs_fixed_stride_order, flexible_layout} can apply; if - multiple are assigned then we assume the most restrictive one. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. - tag: flexible_layout desc: | This tag indicates that the custom operator can accept inputs with varying strides/storage_offset and that when compiled, Inductor is allowed to change the strides/storage_offset of inputs to the custom operator. - Only one of {needs_fixed_stride_order, flexible_layout} can apply; if - multiple are assigned then we assume the most restrictive one. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. # NOTE [Core ATen Ops] - tag: core diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 27397bf78898..66bdaa0baa89 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -28,6 +28,7 @@ #include #else #include +#include #include #include #include @@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) +REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta); int64_t _fused_sdp_choice_meta( const Tensor& query_, @@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta( std::optional scale, bool enable_gqa) { auto query_key_set = query_.key_set(); + bool has_hpu = query_key_set.has(c10::DispatchKey::HPU); + if (has_hpu) { + auto choice_int = at::_ops::_fused_sdp_choice::redispatch( + c10::DispatchKeySet(DispatchKey::HPU), + query_, + key, + value, + attn_mask_, + dropout_p, + is_causal, + scale, + enable_gqa); + return choice_int; + } #if defined(USE_ROCM) bool has_rocm = query_key_set.has(c10::DispatchKey::HIP); if (has_rocm) { diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index bb4c3d9cbc18..05acc275b468 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -553,9 +553,10 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled."); } return false; - } else if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { + } else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) { if (debug) { TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward."); + return false; } } @@ -645,7 +646,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { constexpr auto dense_constraints = c10::array_of( check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense + check_batch_size_and_num_heads_dense ); if (has_only_dense_inputs(params)) { diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 22afbac1d079..4591fa253824 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -333,13 +333,14 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { return true; } +template inline bool check_grouped_query_attention(sdp_params const& params, bool debug) { const auto q_num_heads = params.query.sym_size(-3); const auto k_num_heads = params.key.sym_size(-3); const auto v_num_heads = params.value.sym_size(-3); const bool same_kv_heads = k_num_heads == v_num_heads; - if (!(same_kv_heads)){ + if (requires_same_num_heads && !(same_kv_heads)){ if (debug) { TORCH_WARN( "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", @@ -355,10 +356,10 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) } // Check if grouped query attention is supported and validate the number of // heads - if (q_num_heads % k_num_heads != 0) { + if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) { if (debug) { TORCH_WARN( - "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.", + "The number of heads in key/value must divide number of heads in query.", "Got input Key sizes(): ", params.key.sym_size(-3), ", Value sizes(): ", @@ -372,7 +373,7 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) return true; } -template +template inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { // This is expected to be called after check_tensor_shapes ensuring that the // size() calls won't error since the inputs are all 4 dimensional @@ -407,9 +408,10 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool } if(params.enable_gqa && supports_gqa){ - return check_grouped_query_attention(params, debug); + return check_grouped_query_attention(params, debug); } + // same num heads condition for non-gqa case if (!same_num_heads){ if (debug) { TORCH_WARN( diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp index 4d98cd753159..8d50aa66b4d9 100644 --- a/aten/src/ATen/native/xnnpack/Linear.cpp +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -129,6 +129,7 @@ Tensor run( const IntArrayRef input_size = padded_input.sizes(); std::vector output_size(input_size.cbegin(), input_size.cend()); + // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) output_size.back() = context.output_channels; Tensor output = mobile::empty_with_tail_padding( diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index 127fa78de12d..1763d90cc94e 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -51,8 +51,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { auto impl = c10::make_intrusive( Storage(storage()), key_set(), data_type_, quantizer_); copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), + /*src_q_impl=*/this, + /*dest_q_impl=*/impl.get(), /*version_counter=*/version_counter, /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); impl->refresh_numel(); @@ -72,8 +72,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { auto impl = c10::make_intrusive( Storage(storage()), key_set(), data_type_, quantizer_); copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), + /*src_q_impl=*/this, + /*dest_q_impl=*/impl.get(), /*version_counter=*/std::move(version_counter), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); impl->refresh_numel(); @@ -91,8 +91,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); auto q_impl = static_cast(impl.get()); copy_tensor_metadata( - /*src_impl=*/q_impl, - /*dest_impl=*/this, + /*src_q_impl=*/q_impl, + /*dest_q_impl=*/this, /*version_counter=*/version_counter(), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); refresh_numel(); diff --git a/aten/src/ATen/test/mps_test_metal_library.cpp b/aten/src/ATen/test/mps_test_metal_library.cpp index baee8964364d..3f91516e5a66 100644 --- a/aten/src/ATen/test/mps_test_metal_library.cpp +++ b/aten/src/ATen/test/mps_test_metal_library.cpp @@ -54,6 +54,7 @@ TEST(MPSTestMetalLibrary, ArangeWithArgsShader) { }); ASSERT_TRUE((x==y).all().item().toBool()); } + TEST(MPSTestMetalLibrary, Arange2DShader) { const auto size = 16; auto x = torch::empty({size, size}, at::device(at::kMPS)); @@ -71,3 +72,41 @@ TEST(MPSTestMetalLibrary, Arange2DShader) { }); ASSERT_EQ(x.sum().item().to(), 65280); } + +TEST(MPSTestMetalLibrary, ArgumentBuffers) { + constexpr auto nbuffers = 64; + const auto size = 32; + std::vector ibuffers; + std::vector ibuffers_gpu_ptrs; + for([[maybe_unused]] auto idx: c10::irange(nbuffers)) { + ibuffers.push_back(torch::rand({size}, at::device(at::kMPS))); + ibuffers_gpu_ptrs.push_back(get_tensor_gpu_address(ibuffers.back())); + } + auto output = torch::empty({size}, at::device(at::kMPS)); + DynamicMetalShaderLibrary lib(R"MTL( + constant constexpr auto nbuffers = 64; + struct Inputs { + metal::array args; + }; + + kernel void sum_all(device float* output, constant Inputs& inputs, uint idx [[thread_position_in_grid]]) { + output[idx] = 0; + for(auto i = 0; i < nbuffers; ++i) { + output[idx] += inputs.args[i][idx]; + } + } + )MTL"); + auto func = lib.getKernelFunction("sum_all"); + func->runCommandBlock([&] { + func->startEncoding(); + func->setArg(0, output); + func->setArg(1, ibuffers_gpu_ptrs); + func->dispatch(size); + }); + // Compute sum of all 64 input tensors + auto result = torch::zeros({size}, at::device(at::kMPS)); + for(auto buf: ibuffers) { + result += buf; + } + ASSERT_EQ(result.sum().item().to(), output.sum().item().to()); +} diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 4e0780800906..db37925add67 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -192,6 +192,11 @@ namespace { [](vec v) { return v.neg(); }, createDefaultUnaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_int_minimum)); + test_unary( + NAME_INFO(negate), std::negate>(), + [](vec v) { return -v; }, + createDefaultUnaryTestCase(TestSeed()), + RESOLVE_OVERLOAD(filter_int_minimum)); } TYPED_TEST(SignManipulationHalfPrecision, AbsNegate) { typedef enum { @@ -329,7 +334,7 @@ namespace { test_binary( NAME_INFO(fmod), RESOLVE_OVERLOAD(std::fmod), - [](vec v0, vec v1) { return v0.fmod(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).fmod(v1); }, createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_fmod)); } @@ -371,11 +376,22 @@ namespace { } TYPED_TEST(Hyperbolic, Tanh) { using vec = TypeParam; +// NOTE: Because SVE uses ACL logic, the precision changes, hence the adjusted tolerance. +#if defined(CPU_CAPABILITY_SVE) + using UVT = UvalueType; + UVT tolerance = getDefaultTolerance(); + test_unary( + NAME_INFO(tanH), + RESOLVE_OVERLOAD(std::tanh), + [](vec v) { return v.tanh(); }, + createDefaultUnaryTestCase(TestSeed(), tolerance)); +#else test_unary( NAME_INFO(tanH), RESOLVE_OVERLOAD(std::tanh), [](vec v) { return v.tanh(); }, createDefaultUnaryTestCase(TestSeed())); +#endif } TYPED_TEST(Hyperbolic, Sinh) { using vec = TypeParam; @@ -588,8 +604,8 @@ namespace { test_binary( NAME_INFO(atan2), RESOLVE_OVERLOAD(std::atan2), - [](vec v0, vec v1) { - return v0.atan2(v1); + [](const auto& v0, const auto& v1) { + return vec(v0).atan2(v1); }, createDefaultBinaryTestCase(TestSeed())); } @@ -598,7 +614,7 @@ namespace { test_binary( NAME_INFO(pow), RESOLVE_OVERLOAD(std::pow), - [](vec v0, vec v1) { return v0.pow(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).pow(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(RealTests, Hypot) { @@ -606,7 +622,7 @@ namespace { test_binary( NAME_INFO(hypot), RESOLVE_OVERLOAD(std::hypot), - [](vec v0, vec v1) { return v0.hypot(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).hypot(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(RealTests, NextAfter) { @@ -614,7 +630,7 @@ namespace { test_binary( NAME_INFO(nextafter), RESOLVE_OVERLOAD(std::nextafter), - [](vec v0, vec v1) { return v0.nextafter(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).nextafter(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(Interleave, Interleave) { @@ -664,7 +680,7 @@ namespace { test_binary( NAME_INFO(plus), std::plus(), - [](const vec& v0, const vec& v1) -> vec { + [](const auto& v0, const auto& v1) -> vec { return v0 + v1; }, createDefaultBinaryTestCase(TestSeed()), @@ -676,7 +692,7 @@ namespace { test_binary( NAME_INFO(minus), std::minus(), - [](const vec& v0, const vec& v1) -> vec { + [](const auto& v0, const auto& v1) -> vec { return v0 - v1; }, createDefaultBinaryTestCase(TestSeed()), @@ -687,7 +703,7 @@ namespace { test_binary( NAME_INFO(mult), RESOLVE_OVERLOAD(local_multiply), - [](const vec& v0, const vec& v1) { return v0 * v1; }, + [](const auto& v0, const auto& v1) { return v0 * v1; }, createDefaultBinaryTestCase(TestSeed(), false, true), RESOLVE_OVERLOAD(filter_mult_overflow)); } @@ -697,7 +713,7 @@ namespace { test_binary( NAME_INFO(division), RESOLVE_OVERLOAD(local_division), - [](const vec& v0, const vec& v1) { return v0 / v1; }, + [](const auto& v0, const auto& v1) { return v0 / v1; }, createDefaultBinaryTestCase(seed), RESOLVE_OVERLOAD(filter_div_ub)); } @@ -706,7 +722,7 @@ namespace { test_binary( NAME_INFO(bit_and), RESOLVE_OVERLOAD(local_and), - [](const vec& v0, const vec& v1) { return v0 & v1; }, + [](const auto& v0, const auto& v1) { return v0 & v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Bitwise, BitOr) { @@ -714,7 +730,7 @@ namespace { test_binary( NAME_INFO(bit_or), RESOLVE_OVERLOAD(local_or), - [](const vec& v0, const vec& v1) { return v0 | v1; }, + [](const auto& v0, const auto& v1) { return v0 | v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Bitwise, BitXor) { @@ -722,7 +738,7 @@ namespace { test_binary( NAME_INFO(bit_xor), RESOLVE_OVERLOAD(local_xor), - [](const vec& v0, const vec& v1) { return v0 ^ v1; }, + [](const auto& v0, const auto& v1) { return v0 ^ v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Comparison, Equal) { @@ -785,7 +801,7 @@ namespace { test_binary( NAME_INFO(minimum), minimum, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return minimum(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -796,7 +812,7 @@ namespace { test_binary( NAME_INFO(maximum), maximum, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return maximum(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -807,7 +823,7 @@ namespace { test_binary( NAME_INFO(clamp min), clamp_min, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return clamp_min(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -818,7 +834,7 @@ namespace { test_binary( NAME_INFO(clamp max), clamp_max, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return clamp_max(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 6ff988709582..cb877a9f77eb 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -991,6 +991,10 @@ void test_binary( CACHE_ALIGN VT vals0[el_count]; CACHE_ALIGN VT vals1[el_count]; CACHE_ALIGN VT expected[el_count]; + [[maybe_unused]] CACHE_ALIGN VT expectedWithLeftScalar[el_count]; + [[maybe_unused]] CACHE_ALIGN VT expectedWithRightScalar[el_count]; + [[maybe_unused]] VT scalar0; + [[maybe_unused]] VT scalar1; bool bitwise = testCase.isBitwise(); UVT default_start = std::is_floating_point_v ? std::numeric_limits::lowest() : std::numeric_limits::min(); UVT default_end = std::numeric_limits::max(); @@ -1000,6 +1004,7 @@ void test_binary( int trialCount = getTrialCount(test_trials, domains_size); TestSeed seed = testCase.getTestSeed(); uint64_t changeSeedBy = 0; + constexpr bool kCanUseScalar = std::is_invocable_v && std::is_invocable_v; for (const CheckWithinDomains& dmn : testCase.getDomains()) { size_t dmn_argc = dmn.ArgsDomain.size(); UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; @@ -1012,9 +1017,23 @@ void test_binary( for (const auto k : c10::irange(el_count)) { vals0[k] = generator0.get(); vals1[k] = generator1.get(); + if (k == 0) { + scalar0 = vals0[0]; + scalar1 = vals1[0]; + } call_filter(filter, vals0[k], vals1[k]); + if constexpr (kCanUseScalar) { + call_filter(filter, vals0[k], scalar1); + call_filter(filter, scalar0, vals1[k]); + } + } + for (const auto k : c10::irange(el_count)) { // map operator expected[k] = expectedFunction(vals0[k], vals1[k]); + if constexpr (kCanUseScalar) { + expectedWithLeftScalar[k] = expectedFunction(scalar0, vals1[k]); + expectedWithRightScalar[k] = expectedFunction(vals0[k], scalar1); + } } // test auto input0 = vec_type::loadu(vals0); @@ -1024,8 +1043,27 @@ void test_binary( AssertVectorized vecAssert( testNameInfo, seed, vec_expected, actual, input0, input1); if (vecAssert.check( - bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) { return; + } + if constexpr (kCanUseScalar) { + auto actualWithLeftScalar = actualFunction(scalar0, input1); + auto actualWithRightScalar = actualFunction(input0, scalar1); + auto vec_expectedWithLeftScalar = vec_type::loadu(expectedWithLeftScalar); + auto vec_expectedWithRightScalar = vec_type::loadu(expectedWithRightScalar); + AssertVectorized vecAssertWithLeftScalar( + testNameInfo, seed, vec_expectedWithLeftScalar, actualWithLeftScalar, scalar0, input1); + if (vecAssertWithLeftScalar.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) { + return; + } + AssertVectorized vecAssertWithRightScalar( + testNameInfo, seed, vec_expectedWithRightScalar, actualWithRightScalar, input0, scalar1); + if (vecAssertWithRightScalar.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) { + return; + } + } } // trial changeSeedBy += 1; } diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv index 00fc3c9e0949..96e54bf6f0df 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv @@ -290,7 +290,7 @@ soft_actor_critic,pass,0 -speech_transformer,fail_to_run,5 +speech_transformer,pass,10 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index a354501a75ae..45fe1fb9f7bf 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1395,9 +1395,11 @@ def load(cls, model, example_inputs): with torch.no_grad(): # copy.deepcopy is required to prevent any surprising side-effect, # see https://github.com/pytorch/pytorch/issues/113029 + # This will cause memory stats to be overshadowed by this eager run. + # To fix that, memory stats will be reset later. example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs) - if pytree._is_namedtuple_instance(example_outputs): + if pytree.is_namedtuple_instance(example_outputs): typ = type(example_outputs) pytree._register_namedtuple( typ, @@ -1411,6 +1413,14 @@ def load(cls, model, example_inputs): _produce_dynamic_shapes_for_export, combined_args ) + # delete example_outputs and reset memory stats here + del example_outputs + if current_device == "cuda": + torch.cuda.reset_peak_memory_stats() + empty_gpu_cache(current_device) + elif current_device == "hpu": + torch.hpu.reset_peak_memory_stats() + ep = torch.export.export( model, example_args, @@ -3542,17 +3552,9 @@ def run(runner, args, original_dir=None): }: # some of the models do not support use_deterministic_algorithms torch.use_deterministic_algorithms(True) + if args.devices == ["xpu"]: + torch.use_deterministic_algorithms(True, warn_only=True) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # TODO(eqy): revisit when cuBLASLt workspace size is bumped - # if args.only is not None and args.only in { - # "DebertaForQuestionAnswering", - # "RobertaForQuestionAnswering", - # "nvidia_deeprecommender", - # "volo_d1_224", - # }: - # # These seem unhappy with numerics of larger cuBLASLt workspace - # # sizes following #145130 (due to enabling split-k?) - # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cudnn.deterministic = True torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.benchmark = False diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 47fe5eafcd0c..46c979979fdf 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,32 +1,32 @@ -add_loop_eager,compile_time_instruction_count,2866000000,0.015 +add_loop_eager,compile_time_instruction_count,2944000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5460000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,5633000000,0.025 -add_loop_inductor,compile_time_instruction_count,27660000000,0.015 +add_loop_inductor,compile_time_instruction_count,28950000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40640000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,23970000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,25350000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,953800000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,963100000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17190000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18110000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15410000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16130000000,0.015 @@ -34,43 +34,44 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,97140000 -update_hint_regression,compile_time_instruction_count,1523000000,0.02 +update_hint_regression,compile_time_instruction_count,1608000000,0.02 -float_args,compile_time_instruction_count,413700000,0.015 +float_args,compile_time_instruction_count,417400000,0.015 -sum_floordiv_regression,compile_time_instruction_count,970100000,0.015 +sum_floordiv_regression,compile_time_instruction_count,985300000,0.015 -symint_sum,compile_time_instruction_count,3080000000,0.015 +symint_sum,compile_time_instruction_count,3214000000,0.015 -symint_sum_loop,compile_time_instruction_count,3988000000,0.015 +symint_sum_loop,compile_time_instruction_count,4204000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1989000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2057000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5759000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5917000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,7873000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8561000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1746000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1876000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3579000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3779000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9830000000,0.015 + +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10260000000,0.015 diff --git a/buckbuild.bzl b/buckbuild.bzl index 29addd3bf724..f7fac4bf49dd 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -194,6 +194,9 @@ def get_pt_compiler_flags(): return select({ "DEFAULT": _PT_COMPILER_FLAGS, "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS), + }) + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], }) _PT_COMPILER_FLAGS = [ @@ -228,6 +231,9 @@ ATEN_COMPILER_FLAGS = [ # Not supported by clang on Windows "DEFAULT": ["-fPIC"], "ovr_config//compiler:clang-windows": [], +}) + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], }) def get_aten_compiler_flags(): @@ -982,6 +988,10 @@ def define_buck_targets( fb_xplat_cxx_library( name = "torch_mobile_headers", header_namespace = "", + compiler_flags = select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), exported_headers = subdir_glob( [ ("", "torch/csrc/jit/mobile/*.h"), @@ -1185,7 +1195,10 @@ def define_buck_targets( srcs = [ "torch/csrc/jit/mobile/observer.cpp", ] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]), - compiler_flags = ["-fexceptions"], + compiler_flags = ["-fexceptions"] + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), header_namespace = "", exported_headers = subdir_glob( [ @@ -1712,6 +1725,7 @@ def define_buck_targets( compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DUSE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", # Need this otherwise USE_KINETO is undefed # for mobile "-DEDGE_PROFILER_USE_KINETO", @@ -1737,6 +1751,7 @@ def define_buck_targets( exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DUSE_KINETO", "-DEDGE_PROFILER_USE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", ], # @lint-ignore BUCKLINT link_whole link_whole = True, @@ -1823,6 +1838,7 @@ def define_buck_targets( # Need this otherwise USE_KINETO is undefed # for mobile "-DEDGE_PROFILER_USE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", ] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []), extra_flags = { "fbandroid_compiler_flags": ["-frtti"], @@ -2035,7 +2051,7 @@ def define_buck_targets( "ovr_config//os:xtensa-xos": [ "-fdata-sections", "-ffunction-sections", - ], + ] }), exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DMIN_EDGE_RUNTIME", diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index 1b5269c9da13..63fcf064e01b 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -72,6 +72,22 @@ bool SymBool::guard_size_oblivious(const char* file, int64_t line) const { return a->guard_size_oblivious(file, line); } +bool SymBool::guard_or_false(const char* file, int64_t line) const { + if (auto ma = maybe_as_bool()) { + return *ma; + } + SymNode a = toSymNodeImpl(); + return a->guard_or_false(file, line); +} + +bool SymBool::guard_or_true(const char* file, int64_t line) const { + if (auto ma = maybe_as_bool()) { + return *ma; + } + SymNode a = toSymNodeImpl(); + return a->guard_or_true(file, line); +} + bool SymBool::expect_true(const char* file, int64_t line) const { if (auto ma = maybe_as_bool()) { return *ma; diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index c7b1fe5ff316..875377b2eb37 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -62,6 +62,8 @@ class C10_API SymBool { bool guard_bool(const char* file, int64_t line) const; bool expect_true(const char* file, int64_t line) const; bool guard_size_oblivious(const char* file, int64_t line) const; + bool guard_or_false(const char* file, int64_t line) const; + bool guard_or_true(const char* file, int64_t line) const; bool has_hint() const; @@ -113,7 +115,40 @@ inline bool guard_size_oblivious( return b.guard_size_oblivious(file, line); } +inline bool guard_or_false( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_false( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_false(file, line); +} + +inline bool guard_or_true( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_true( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_true(file, line); +} + #define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \ c10::guard_size_oblivious((cond), __FILE__, __LINE__) +#define TORCH_GUARD_OR_FALSE(cond) \ + c10::guard_or_false((cond), __FILE__, __LINE__) + +#define TORCH_GUARD_OR_TRUE(cond) c10::guard_or_true((cond), __FILE__, __LINE__) + } // namespace c10 diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 36652e1800ac..6589a1e0b780 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -186,6 +186,16 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { // with a better implementation! return guard_bool(file, line); } + virtual bool guard_or_false(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } + virtual bool guard_or_true(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } virtual bool expect_true(const char* file, int64_t line) { // No improvement for unbacked SymBools by default, replace this // with a better implementation! diff --git a/c10/cuda/CUDADeviceAssertion.h b/c10/cuda/CUDADeviceAssertion.h index 063c7836932a..6b98e78aa469 100644 --- a/c10/cuda/CUDADeviceAssertion.h +++ b/c10/cuda/CUDADeviceAssertion.h @@ -6,6 +6,7 @@ namespace c10::cuda { #ifdef TORCH_USE_CUDA_DSA +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") // Copy string from `src` to `dst` static __device__ void dstrcpy(char* dst, const char* src) { int i = 0; @@ -64,6 +65,7 @@ static __device__ void dsa_add_new_assertion_failure( self.thread_id[1] = thread_id.y; self.thread_id[2] = thread_id.z; } +C10_CLANG_DIAGNOSTIC_POP() // Emulates a kernel assertion. The assertion won't stop the kernel's progress, // so you should assume everything the kernel produces is garbage if there's an diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 65cbdfe878dc..d2eb495e8833 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -3,6 +3,8 @@ #define NVML_NO_UNVERSIONED_FUNC_DEFS #include +#include + #define C10_CUDA_DRIVER_CHECK(EXPR) \ do { \ CUresult __err = EXPR; \ diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 5445d53039b1..b6f7f6bc83ee 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -6,27 +6,88 @@ namespace c10 { namespace metal { +constant constexpr ushort simdgroup_size = 32; + template -opmath_t threadgroup_sum(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee - ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - opmath_t rc = data[0]; - // TODO: Use `simd_shuffle_down` - for (unsigned idx = 1; idx < size; ++idx) { - rc += data[idx]; +inline ::metal::enable_if_t, T> simd_sum(T val) { + return ::metal::simd_sum(val); +} + +template +inline ::metal::enable_if_t, T> simd_prod(T val) { + return ::metal::simd_product(val); +} + +// Metal does not support SIMD reductions over 64-bit types, but it could be +// implement using simd_shuffle_down, that yields result in log2(simdgroup_size) +// iterations Use fill variant, as shuffle down returns garbage if inactive +// thread is referenced (on M1/M2, works fine on M4) and broadcast result to all +// threads in the end. Implementation heavily borrows from +// https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L16 +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_sum(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val += as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); } - return rc; + return as_type(::metal::simd_broadcast(as_type(val), 0)); } template -opmath_t threadgroup_prod(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_prod(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val *= as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); + } + return as_type(::metal::simd_broadcast(as_type(val), 0)); +} + +// Below algorithms are written with hardcoded assumption that simdgroup is 32 +// and threadgroup_max is 1024, i.e. reduction can be done in two stages max +template +opmath_t threadgroup_sum( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_sum(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_sum(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - opmath_t rc = data[0]; - for (unsigned idx = 1; idx < size; ++idx) { - rc *= data[idx]; + return data[0]; +} + +template +opmath_t threadgroup_prod( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_prod(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; } - return rc; + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_prod(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; } template diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 26e1da619e35..1b60563b205d 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -1716,5 +1716,42 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { return r; } // chebyshev_polynomial_w_forward(T x, int64_t n) +template +// TODO: Add 512 if/when double will be supported in Metal +inline constexpr int getHermitianLimit() { + return 128; +} + +template +inline float hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + float p = 1.0; + float q = x + x; + float r = 0.0; + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + } // namespace metal } // namespace c10 diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 20959f64180e..1140fc703b59 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -477,7 +477,7 @@ class C10_API TypeMeta final { /** * convert TypeMeta handles to ScalarType enum values */ - inline ScalarType toScalarType() { + inline ScalarType toScalarType() const { if (C10_LIKELY(isScalarType())) { return static_cast(index_); } diff --git a/c10/xpu/XPUStream.h b/c10/xpu/XPUStream.h index 903986253d23..fea64d7c109e 100644 --- a/c10/xpu/XPUStream.h +++ b/c10/xpu/XPUStream.h @@ -59,6 +59,11 @@ class C10_XPU_API XPUStream { return queue(); } + /// Implicit conversion to sycl::queue*. + operator sycl::queue*() const { + return &queue(); + } + /// Implicit conversion to Stream (a.k.a., forget that the stream is a /// XPU stream). operator Stream() const { diff --git a/c10/xpu/test/impl/XPUStreamTest.cpp b/c10/xpu/test/impl/XPUStreamTest.cpp index 581e7e69c6fa..661022dbe18e 100644 --- a/c10/xpu/test/impl/XPUStreamTest.cpp +++ b/c10/xpu/test/impl/XPUStreamTest.cpp @@ -223,6 +223,9 @@ TEST(XPUStreamTest, ExternalTest) { ASSERT_TRUE(curStream == myStream); ASSERT_TRUE(&(curStream.queue()) == stream); + sycl::queue* q_ptr = curStream; + ASSERT_TRUE(q_ptr == stream); + delete stream; } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b850644fe977..71cc4b31a995 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -568,6 +568,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu diff --git a/caffe2/perfkernels/embedding_lookup_idx_sve.cc b/caffe2/perfkernels/embedding_lookup_idx_sve.cc index 873823536b55..3e211a5ba1f5 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_sve.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_sve.cc @@ -28,517 +28,406 @@ static bool EmbeddingLookupIdx_int32_t_float_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + const float* const ip8 = &input[idx8 * block_size]; + const float* const ip9 = &input[idx9 * block_size]; + const float* const ip10 = &input[idx10 * block_size]; + const float* const ip11 = &input[idx11 * block_size]; + const float* const ip12 = &input[idx12 * block_size]; + const float* const ip13 = &input[idx13 * block_size]; + const float* const ip14 = &input[idx14 * block_size]; + const float* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const float* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -611,517 +500,406 @@ static bool EmbeddingLookupIdx_int64_t_float_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + const float* const ip8 = &input[idx8 * block_size]; + const float* const ip9 = &input[idx9 * block_size]; + const float* const ip10 = &input[idx10 * block_size]; + const float* const ip11 = &input[idx11 * block_size]; + const float* const ip12 = &input[idx12 * block_size]; + const float* const ip13 = &input[idx13 * block_size]; + const float* const ip14 = &input[idx14 * block_size]; + const float* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const float* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -1194,895 +972,530 @@ static bool EmbeddingLookupIdx_int32_t_half_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + const at::Half* const ip8 = &input[idx8 * block_size]; + const at::Half* const ip9 = &input[idx9 * block_size]; + const at::Half* const ip10 = &input[idx10 * block_size]; + const at::Half* const ip11 = &input[idx11 * block_size]; + const at::Half* const ip12 = &input[idx12 * block_size]; + const at::Half* const ip13 = &input[idx13 * block_size]; + const at::Half* const ip14 = &input[idx14 * block_size]; + const at::Half* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip15[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -2155,895 +1568,530 @@ static bool EmbeddingLookupIdx_int64_t_half_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + const at::Half* const ip8 = &input[idx8 * block_size]; + const at::Half* const ip9 = &input[idx9 * block_size]; + const at::Half* const ip10 = &input[idx10 * block_size]; + const at::Half* const ip11 = &input[idx11 * block_size]; + const at::Half* const ip12 = &input[idx12 * block_size]; + const at::Half* const ip13 = &input[idx13 * block_size]; + const at::Half* const ip14 = &input[idx14 * block_size]; + const at::Half* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip15[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -3116,958 +2164,530 @@ static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + const at::BFloat16* const ip8 = &input[idx8 * block_size]; + const at::BFloat16* const ip9 = &input[idx9 * block_size]; + const at::BFloat16* const ip10 = &input[idx10 * block_size]; + const at::BFloat16* const ip11 = &input[idx11 * block_size]; + const at::BFloat16* const ip12 = &input[idx12 * block_size]; + const at::BFloat16* const ip13 = &input[idx13 * block_size]; + const at::BFloat16* const ip14 = &input[idx14 * block_size]; + const at::BFloat16* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -4140,958 +2760,530 @@ static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + const at::BFloat16* const ip8 = &input[idx8 * block_size]; + const at::BFloat16* const ip9 = &input[idx9 * block_size]; + const at::BFloat16* const ip10 = &input[idx10 * block_size]; + const at::BFloat16* const ip11 = &input[idx11 * block_size]; + const at::BFloat16* const ip12 = &input[idx12 * block_size]; + const at::BFloat16* const ip13 = &input[idx13 * block_size]; + const at::BFloat16* const ip14 = &input[idx14 * block_size]; + const at::BFloat16* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -5164,743 +3356,555 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + bio += wgt8 * scale_bias[2 * idx8 + 1]; + wgt8 = wgt8 * scale_bias[2 * idx8]; + bio += wgt9 * scale_bias[2 * idx9 + 1]; + wgt9 = wgt9 * scale_bias[2 * idx9]; + bio += wgt10 * scale_bias[2 * idx10 + 1]; + wgt10 = wgt10 * scale_bias[2 * idx10]; + bio += wgt11 * scale_bias[2 * idx11 + 1]; + wgt11 = wgt11 * scale_bias[2 * idx11]; + bio += wgt12 * scale_bias[2 * idx12 + 1]; + wgt12 = wgt12 * scale_bias[2 * idx12]; + bio += wgt13 * scale_bias[2 * idx13 + 1]; + wgt13 = wgt13 * scale_bias[2 * idx13]; + bio += wgt14 * scale_bias[2 * idx14 + 1]; + wgt14 = wgt14 * scale_bias[2 * idx14]; + bio += wgt15 * scale_bias[2 * idx15 + 1]; + wgt15 = wgt15 * scale_bias[2 * idx15]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + const uint8_t* const ip8 = &input[idx8 * block_size]; + const uint8_t* const ip9 = &input[idx9 * block_size]; + const uint8_t* const ip10 = &input[idx10 * block_size]; + const uint8_t* const ip11 = &input[idx11 * block_size]; + const uint8_t* const ip12 = &input[idx12 * block_size]; + const uint8_t* const ip13 = &input[idx13 * block_size]; + const uint8_t* const ip14 = &input[idx14 * block_size]; + const uint8_t* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k])); + auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k])); + auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k])); + auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k])); + auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k])); + auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k])); + auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k])); + auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k])); + auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k])); + auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k])); + auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k])); + auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k])); + auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k])); + auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k])); + auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -5973,743 +3977,555 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + bio += wgt8 * scale_bias[2 * idx8 + 1]; + wgt8 = wgt8 * scale_bias[2 * idx8]; + bio += wgt9 * scale_bias[2 * idx9 + 1]; + wgt9 = wgt9 * scale_bias[2 * idx9]; + bio += wgt10 * scale_bias[2 * idx10 + 1]; + wgt10 = wgt10 * scale_bias[2 * idx10]; + bio += wgt11 * scale_bias[2 * idx11 + 1]; + wgt11 = wgt11 * scale_bias[2 * idx11]; + bio += wgt12 * scale_bias[2 * idx12 + 1]; + wgt12 = wgt12 * scale_bias[2 * idx12]; + bio += wgt13 * scale_bias[2 * idx13 + 1]; + wgt13 = wgt13 * scale_bias[2 * idx13]; + bio += wgt14 * scale_bias[2 * idx14 + 1]; + wgt14 = wgt14 * scale_bias[2 * idx14]; + bio += wgt15 * scale_bias[2 * idx15 + 1]; + wgt15 = wgt15 * scale_bias[2 * idx15]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + const uint8_t* const ip8 = &input[idx8 * block_size]; + const uint8_t* const ip9 = &input[idx9 * block_size]; + const uint8_t* const ip10 = &input[idx10 * block_size]; + const uint8_t* const ip11 = &input[idx11 * block_size]; + const uint8_t* const ip12 = &input[idx12 * block_size]; + const uint8_t* const ip13 = &input[idx13 * block_size]; + const uint8_t* const ip14 = &input[idx14 * block_size]; + const uint8_t* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k])); + auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k])); + auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k])); + auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k])); + auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k])); + auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k])); + auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k])); + auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k])); + auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k])); + auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k])); + auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k])); + auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k])); + auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k])); + auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k])); + auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py index 643b614c9081..4c5ad01bdc10 100644 --- a/caffe2/perfkernels/sve_emblookup_codegen.py +++ b/caffe2/perfkernels/sve_emblookup_codegen.py @@ -4,289 +4,105 @@ # Unroll loops when block_size is a multiple of vector length. -def unroll(num_unrolls, IndexType, InType, OutType, use_weights): - def compute(regid, InType, use_weights): +def unroll(num_unrolls, IndexType, InType, OutType): + def compute_output(num_unrolls, InType, is_main): code = [] + pred = "svAll" if is_main else "pg" if InType == "float": - code.append( - f" vsum{regid} =\n" - " svmad_f32_x(" - f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen])," - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, svld1(svAll, &ip{i}[k]), wgt{i});") elif InType == "at::Half": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " svAll,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])))),\n" - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" auto input{i} = svcvt_f32_x({pred}, svreinterpret_f16(\n" + f" svld1uh_u32({pred}, reinterpret_cast(&ip{i}[k]))));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") elif InType == "at::BFloat16": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " svAll,\n" - " svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])),\n" - " 16)),\n" - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" auto input{i} = svreinterpret_f32(svlsl_x({pred},\n" + f" svld1uh_u32({pred}, reinterpret_cast(&ip{i}[k])), 16));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") elif InType == "uint8_t": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_u32_x(svAll," - f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" - f" svadd_f32_x(svAll, vsum{regid}, vbio));" - ) + code.append(f" output = svadd_x({pred}, output, bio);") + for i in range(num_unrolls): + code.append(f" auto input{i} = svcvt_f32_x({pred}, svld1ub_u32({pred}, &ip{i}[k]));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") else: raise ValueError(f'Unknown datatype "{InType}"') return code code = [] - code.append(f" // unrolling {num_unrolls} times") - code.append(" for (int64_t i = 0; i < output_size; ++i) {") - - code.append(" " + OutType + "* const op = &out[i * block_size];") - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - - # Initialise vector sum registers + if num_unrolls == 1: + code.append(f" // tail loop") + code.append(" if (j < end_offset) {") + else: + code.append(f" // unrolling {num_unrolls} times") + code.append(f" while (j + {num_unrolls - 1} < end_offset) {{") for i in range(num_unrolls): - code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);") - - # inner loop - code.append("""\ - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1];""") - code.append( - " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) + code.append(f" const auto idx{i} = indices[pos + {i}];") - if InType == "uint8_t": - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") + # check indices + for i in range(num_unrolls): code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" + f" if (idx{i} < 0 || idx{i} >= data_size) {{\n" + + " return false;\n" + + " }" ) - code.append(" }") - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* const ip = &input[idx * block_size];") - code.append(" // weight * input + out") + if InType == "uint8_t": + for i in range(num_unrolls): + code.append(f" {OutType} wgt{i} = 1.f;") + code.append(f" {OutType} bio = 0.f;") + else: + for i in range(num_unrolls): + code.append(f" {OutType} wgt{i} = 1.f;") + code.append(" if (weights) {") for i in range(num_unrolls): - code.extend(compute(i, InType, use_weights)) - - code.append(" ++pos;") + code.append(f" wgt{i} = weights[IS_WEIGHT_POSITIONAL ? (j + {i} - start_offset) : pos + {i}];") code.append(" }") + if InType == "uint8_t": + code.append(" if (scale_bias) {") + for i in range(num_unrolls): + code.append(f" bio += wgt{i} * scale_bias[2 * idx{i} + 1];") + code.append(f" wgt{i} = wgt{i} * scale_bias[2 * idx{i}];") + code.append(" }") - code.append(" // Normalisation") - code.append(" const int64_t length = end_offset - start_offset;") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - - for i in range(num_unrolls): - code.append( - f" svst1_f32(svAll, &op[{i} * vLen]," - + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));" - ) - - code.append(" } else {") - # inv of length for i in range(num_unrolls): - code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});") - + code.append(f" const {InType}* const ip{i} = &input[idx{i} * block_size];") + + # compute and store + code.append(" svbool_t pg;") + code.append(" int64_t k = 0;") + # main loop + code.append(" while (k + vLen - 1 < block_size) {") + code.append(" auto output = svld1(svAll, &op[k]);") + code.extend(compute_output(num_unrolls, InType, True)) + code.append(" svst1(svAll, &op[k], output);") + code.append(" k += vLen;") code.append(" }") - code.append(" }") - return code - - -# Handle the case where block_size is not a multiple of vector length. -def generic(IndexType, InType, OutType, use_weights): - def compute(InType, use_weights): - code = [] - if InType == "float": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg, vwgt, svld1_f32(pg, &ip[k])," - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::Half": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " pg,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])))),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::BFloat16": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " pg,\n" - " svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])),\n" - " 16)),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "uint8_t": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_u32_x(pg," - " svld1ub_u32(pg, &ip[k])),\n" - " svadd_f32_x(pg," - " svld1_f32(pg, &op[k]), vbio)));" - ) - else: - raise ValueError(f'Unknown datatype "{InType}"') - - return code - - code = [] - - code.append(" for (int64_t i = 0; i < output_size; ++i) {") - - code.append(" " + OutType + "* const op = &out[i * block_size];") - - # initialize to 0 - code.append(" memset(op, 0, sizeof(float) * block_size);") - - # inner loop - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - code.append( - " int64_t start_offset = offsets[i];\n" - + " int64_t end_offset = offsets[i + 1];" - ) - code.append( - " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) - - if InType == "uint8_t": - code.append(" // unimplemented") - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" - ) - code.append(" }") - - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* ip = &input[idx * block_size];") - - # compute and store main loop - code.append(" svbool_t pg;") - code.append(" for (int64_t k = 0;") - code.append( - " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));" - ) - code.append(" k += vLen) {") - code.extend(compute(InType, use_weights)) - code.append(" }\n") - code.append(" ++pos;") + # tail loop + code.append(" if (k < block_size) {") + code.append(" pg = svwhilelt_b32_s64(k, block_size);") + code.append(" auto output = svld1(pg, &op[k]);") + code.extend(compute_output(num_unrolls, InType, False)) + code.append(" svst1(pg, &op[k], output);") + code.append(" k += vLen;") code.append(" }") + if num_unrolls == 1: + code.append(" pos ++;") + else: + code.append(f" j += {num_unrolls};") + code.append(f" pos += {num_unrolls};") - code.append(" const int64_t length = end_offset - start_offset;\n") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - code.append(" svbool_t pg;") - code.append( - " for (int64_t j = 0;\n" - " svptest_first(svAll, pg = svwhilelt_b32_s64(" - "j, block_size));" - ) - code.append(" j += vLen) {") - code.append( - " svst1_f32(\n" - " pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));" - ) - code.append(" }") - code.append(" }") code.append(" }") - return code + return code def main(): parser = argparse.ArgumentParser() @@ -352,22 +168,47 @@ def main(): code.append(" const auto vLen = static_cast(svcntw());") code.append(" int64_t pos = 0;") - code.append(" if (block_size == 32 * vLen) {") - code += unroll(32, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 16 * vLen) {") - code += unroll(16, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 8 * vLen) {") - code += unroll(8, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 4 * vLen) {") - code += unroll(4, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 2 * vLen) {") - code += unroll(2, IndexType, InType, OutType, True) - code.append(" } else {") - code.append(" // generic code:") - code += generic(IndexType, InType, OutType, True) + code.append(" for (int64_t i = 0; i < output_size; ++i) {") + code.append(" " + OutType + "* const op = &out[i * block_size];") + + # initialize to 0 + code.append(" memset(op, 0, sizeof(float) * block_size);") + + # inner loop + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + code.append( + " int64_t start_offset = offsets[i];\n" + + " int64_t end_offset = offsets[i + 1];" + ) + code.append(" int64_t j = start_offset;") + + code += unroll(16, IndexType, InType, OutType) + code += unroll(8, IndexType, InType, OutType) + code += unroll(4, IndexType, InType, OutType) + code += unroll(2, IndexType, InType, OutType) + code += unroll(1, IndexType, InType, OutType) + + code.append(" const int64_t length = end_offset - start_offset;\n") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" svbool_t pg;") + code.append(" int64_t j = 0;") + code.append(" while (j + vLen - 1 < block_size) {") + code.append(" svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));") + code.append(" j += vLen;") + code.append(" }") + code.append(" if (j < block_size) {") + code.append(" pg = svwhilelt_b32_s64(j, block_size);") + code.append(" svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));") + code.append(" }") + code.append(" }") + code.append(" }") code.append(" return pos == index_size;") - code.append("}") for is_weight_positional in ["false", "true"]: diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 5ca808f20c8a..dc6cf8db5c37 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -120,7 +120,11 @@ if(INTERN_BUILD_ATEN_OPS) "89;90a;100a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" - "89;90a") + "90a") + _BUILD_FOR_ADDITIONAL_ARCHS( + "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" + "90a") + endif() set(GEN_ROCM_FLAG) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index bd8f7792214e..b6c51e639eee 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -737,6 +737,12 @@ if(USE_FBGEMM) set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON) + # TODO: Remove next two lines after fbgemm pin is updated + + # For more details see https://github.com/pytorch/pytorch/issues/150846 + target_compile_options_if_supported(fbgemm_avx512 -Wno-maybe-uninitialized) + target_compile_options_if_supported(fbgemm_avx512 -Wno-uninitialized) + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) # See https://github.com/pytorch/pytorch/issues/74352 target_compile_options_if_supported(asmjit -Wno-deprecated-copy) @@ -816,9 +822,18 @@ if(NOT TARGET fp16 AND NOT USE_SYSTEM_FP16) set(FP16_BUILD_TESTS OFF CACHE BOOL "") set(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") - add_subdirectory( - "${FP16_SOURCE_DIR}" - "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "FP16 is only cmake-2.8 compatible") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + add_subdirectory( + "${FP16_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + unset(CMAKE_POLICY_VERSION_MINIMUM) + else() + add_subdirectory( + "${FP16_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + endif() elseif(NOT TARGET fp16 AND USE_SYSTEM_FP16) add_library(fp16 STATIC "/usr/include/fp16.h") set_target_properties(fp16 PROPERTIES LINKER_LANGUAGE C) @@ -1206,15 +1221,7 @@ if(USE_GLOO) set(NCCL_EXTERNAL ON) endif() set(GLOO_USE_CUDA_TOOLKIT ON CACHE BOOL "" FORCE) - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - # Remove me when https://github.com/facebookincubator/gloo/pull/424 is landed - message(WARNING "Downgrading cmake-policy-version for gloo build") - set(CMAKE_POLICY_VERSION_MINIMUM 3.5) - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) - unset(CMAKE_POLICY_VERSION_MINIMUM) - else() - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) - endif() + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) # Here is a little bit hacky. We have to put PROJECT_BINARY_DIR in front # of PROJECT_SOURCE_DIR with/without conda system. The reason is that # gloo generates a new config.h in the binary diretory. @@ -1714,7 +1721,7 @@ if(USE_KINETO) set_property(TARGET kineto PROPERTY POSITION_INDEPENDENT_CODE ON) endif() list(APPEND Caffe2_DEPENDENCY_LIBS kineto) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO") + string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND") if(LIBKINETO_NOCUPTI) string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI") endif() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 5741cf7d0952..28d15a5ea1b7 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -154,7 +154,15 @@ if(HIP_FOUND) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(hipsolver REQUIRED) - find_package_and_print_version(hiprtc REQUIRED) + # workaround cmake 4 build issue + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "Work around hiprtc cmake failure for cmake >= 4") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + find_package_and_print_version(hiprtc REQUIRED) + unset(CMAKE_POLICY_VERSION_MINIMUM) + else() + find_package_and_print_version(hiprtc REQUIRED) + endif() find_package_and_print_version(hipblaslt REQUIRED) if(UNIX) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 7092a836417f..8e8d14e17e54 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -284,6 +284,13 @@ The machine with rank 0 will be used to set up all connections. This is the default method, meaning that ``init_method`` does not have to be specified (or can be ``env://``). +Improving initialization time +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* ``TORCH_GLOO_LAZY_INIT`` - establishes connections on demand rather than + using a full mesh which can greatly improve initialization time for non all2all + operations. + Post-Initialization ------------------- diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst index 694212296e35..75cedd809fdc 100644 --- a/docs/source/distributed.tensor.parallel.rst +++ b/docs/source/distributed.tensor.parallel.rst @@ -46,6 +46,10 @@ the ``parallelize_plan`` when calling ``parallelize_module``: :members: :undoc-members: +.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInputOutput + :members: + :undoc-members: + .. note:: when using the ``Shard(dim)`` as the input/output layouts for the above ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. For instance, diff --git a/docs/source/export.rst b/docs/source/export.rst index 3f947e54e568..8db7fb118334 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -797,6 +797,12 @@ API Reference .. automethod:: dynamic_shapes +.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs + + .. automethod:: add + .. automethod:: dynamic_shapes + .. automethod:: verify + .. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes .. autoclass:: Constraint .. autoclass:: ExportedProgram diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index dce6d126dce3..d5f140a3db0b 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -4,27 +4,46 @@ Getting Started on Intel GPU Hardware Prerequisite --------------------- +For Intel Data Center GPU + .. list-table:: - :widths: 50 50 + :widths: 50 50 50 50 :header-rows: 1 - * - Supported OS - - Validated Hardware - * - Linux - - Intel® Client GPUs / Intel® Data Center GPU Max Series - * - Windows - - Intel® Client GPUs - * - WSL2 (experimental feature) - - Intel® Client GPUs - -Intel GPUs support (Prototype) is ready in PyTorch* 2.6 for Intel® Client GPUs and Intel® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. + * - Device + - Red Hat* Enterprise Linux* 9.2 + - SUSE Linux Enterprise Server* 15 SP5 + - Ubuntu* Server 22.04 (>= 5.15 LTS kernel) + * - Intel® Data Center GPU Max Series (CodeName: Ponte Vecchio) + - yes + - yes + - yes + +For Intel Client GPU + ++-------------------------------------+----------------------------------------------------------------------------------------------+ +| Supported OS | Validated Hardware | ++=====================================+==============================================================================================+ +|| Windows 10/11 & Ubuntu 24.10 || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Arc B-Series Graphics (CodeName: Battlemage) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | +|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | ++-------------------------------------+----------------------------------------------------------------------------------------------+ +|| Ubuntu 24.04 & WSL2 (Ubuntu 24.04) || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | +|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | ++-------------------------------------+----------------------------------------------------------------------------------------------+ + +Intel GPUs support (Prototype) is ready from PyTorch* 2.5 for Intel® Client GPUs and Intel® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. Software Prerequisite --------------------- -To use PyTorch on Intel GPUs, you need to install the Intel GPUs driver first. For installation guide, visit `Intel GPUs Driver Installation `_. +To use PyTorch on Intel GPUs, you need to install the Intel GPUs driver first. For installation guide, visit `Intel GPUs Driver Installation `_. -Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. +Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. Installation @@ -33,7 +52,7 @@ Installation Binaries ^^^^^^^^ -Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. +Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. For release wheels @@ -52,7 +71,7 @@ For nightly wheels From Source ^^^^^^^^^^^ -Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed. Follow guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. +Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed. Follow guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. Build from source for ``torch`` refer to `PyTorch Installation Build from source `_. @@ -88,7 +107,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index 4e4ce90c6055..83ad054514ef 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -146,6 +146,20 @@ torch.ao.quantization.pt2e.export_utils .. currentmodule:: torch.ao.quantization +torch.ao.quantization.pt2e.lowering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torch.ao.quantization.pt2e.lowering + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + lower_pt2e_quantized_to_x86 + +.. currentmodule:: torch.ao.quantization + PT2 Export (pt2e) Numeric Debugger ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 1b808136ef11..226bb143d322 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1341,6 +1341,7 @@ Please take a look at `Limitations of Symbolic Tracing list[Measurement]: + rc = [] + + # Bench 2D with reduction over dim=0 + def f(t): + return reduction_func(t, dim=0) + + f.__name__ = reduction_func.__name__ + f_c = torch.compile(f, dynamic=False) + + for size in (512, 1024, 2048, 4096): + x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) + rc_c, rc_e = f(x), f_c(x) + rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e) + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) + rc.append(bench_unary_op(f, x, f"eager-{size}x{size}")) + rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}")) + return rc + + def main() -> None: dtypes = [torch.float16, torch.float32] + if torch.backends.mps.is_macos_or_newer(14, 0): + dtypes.append(torch.bfloat16) + # Profile unary ops rc = [] for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes): rc.extend(bench_unary(op, dtype=dtype)) Compare(rc).print() + # Profile reduction ops + rc = [] + for op in [torch.sum, torch.max]: + rc.extend(bench_reduction(op)) + Compare(rc).print() + # Profile binary ops rc = [] ops = [torch.fmax, torch.add] for op, dtype in itertools.product(ops, dtypes): rc.extend(bench_binary(op, dt_a=dtype)) - for op in ops: - rc.extend(bench_binary(op, dt_b=torch.float16)) + if dtype == torch.float32: + rc.extend(bench_binary(op, dt_b=torch.float16)) Compare(rc).print() diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 8a9c36db683c..1bf6ecc1ecfe 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -230,6 +230,16 @@ void test_aoti_constants_update( actual_output_tensors = runner->run(input_tensors); ASSERT_FALSE( torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : missing_map) { + delete pair.second; + } + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } void test_aoti_extract_constants_map(const std::string& device) { @@ -395,6 +405,13 @@ void test_aoti_double_buffering( runner->swap_constant_buffer(); actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } #if defined(USE_CUDA) || defined(USE_ROCM) @@ -435,11 +452,14 @@ void test_aoti_double_buffering_with_tensor_constants() { runner->swap_constant_buffer(); actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : real_map) { + delete pair.second; + } } void test_aoti_free_buffer(bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; - size_t allocated, reserved, active; std::string data_path = (std::filesystem::path( @@ -490,7 +510,11 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } c10::cuda::CUDACachingAllocator::DeviceStats stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + size_t initTorchActive = stats.active_bytes[0].current; + size_t initTorchReserved = stats.reserved_bytes[0].current; // This should contain one set of weight (128MB) loaded from .so + size_t torchActive1, torchActive2; + size_t torchReserved1, torchReserved2; size_t initMemory = 0; size_t totalMemory = 0; cudaStatus = cudaMemGetInfo(&initMemory, &totalMemory); @@ -511,18 +535,30 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { // (64MB). if (use_runtime_constant_folding) { runner->run_const_fold(/* use_inactive = */ true); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; size_t constFoldMemory = 0; cudaStatus = cudaMemGetInfo(&constFoldMemory, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - DATASIZE - FOLDEDDATASIZE, constFoldMemory); + ASSERT_EQ( + initMemory - DATASIZE - (torchReserved1 - initTorchReserved), + constFoldMemory); + ASSERT_EQ(torchActive1 - initTorchActive, FOLDEDDATASIZE); } // We swap and free the inactive buffer. (Use #2 and free #1) - // Note that buffer #1 do not include folded-const + // Note that buffer #1 does not include folded-const + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; runner->swap_constant_buffer(); runner->free_inactive_constant_buffer(); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; size_t postFreeMemory = 0; cudaStatus = cudaMemGetInfo(&postFreeMemory, &totalMemory); if (cudaStatus != cudaSuccess) { @@ -530,60 +566,84 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } // We should only have one set of buffer (#2), available memory should equal // initial memory minus the folded constants. - ASSERT_EQ(initMemory - FOLDEDDATASIZE, postFreeMemory); + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), postFreeMemory); + // Buffer #1 does not include folded-consts + ASSERT_EQ(torchActive2 - torchActive1, 0); // We update random weights to buffer #1 and run const fold. // We will have 2 full set of data plus 2 set of const-folded data. runner->update_inactive_constant_buffer(rand_map); runner->run_const_fold(/* use_inactive = */ true); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; size_t updateMemory1 = 0; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1); + ASSERT_EQ( + initMemory - DATASIZE - (torchReserved1 - initTorchReserved), + updateMemory1); + ASSERT_EQ(torchActive1 - initTorchActive, 2 * FOLDEDDATASIZE); // We directly free the buffer #1. This would free the DATASIZE weight. // If folded constant exists, it will not directly free the cudaMalloc, but // decrease the active buffer in CachingAllocator instead. - size_t active1, active2; - size_t allocated1, allocated2; stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active1 = stats.active_bytes[0].current; - allocated1 = stats.allocated_bytes[0].current; + torchActive1 = stats.active_bytes[0].current; runner->free_inactive_constant_buffer(); cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active2 = stats.active_bytes[0].current; - allocated2 = stats.allocated_bytes[0].current; - ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(FOLDEDDATASIZE, active1 - active2); + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), updateMemory1); + ASSERT_EQ(FOLDEDDATASIZE, torchActive1 - torchActive2); // Free buffer #1 again, since #1 is freed, nothing should change. + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; runner->free_inactive_constant_buffer(); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive2 = stats.active_bytes[0].current; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(FOLDEDDATASIZE, active1 - active2); + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), updateMemory1); + ASSERT_EQ(torchActive1 - torchActive2, 0); // Swap and free #2, no data should exist in memory now. - // However, the folded constants still occupies the CUDA memory in + // However, the folded constants might still occupies the CUDA memory in // CachedAllocator. + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; runner->swap_constant_buffer(); runner->free_inactive_constant_buffer(); stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active2 = stats.active_bytes[0].current; + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory + DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(2 * FOLDEDDATASIZE, active1 - active2); + + ASSERT_EQ( + initMemory + DATASIZE - (torchReserved2 - initTorchReserved), + updateMemory1); + ASSERT_EQ(FOLDEDDATASIZE, torchActive1 - torchActive2); + ASSERT_EQ(0, torchActive2 - initTorchActive); + + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } class ThreadPool { diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index a2fa2b467c52..533c50a43fe8 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter { }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { + // Note (kwen2501) 03/07/2025 + // TODO: re-enable + GTEST_SKIP() << "Skipping test as the trace write seems unstable."; int heartBeatIntervalInSec = 2; std::string timeInterval = std::to_string(heartBeatIntervalInSec); ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index cd2eaf761dff..75bf60b0654e 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -27,7 +27,7 @@ add_library(backend_with_compiler SHARED ) if(USE_KINETO) set_target_properties(backend_with_compiler PROPERTIES COMPILE_FLAGS - "-DUSE_KINETO") + "-DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND") endif() target_link_libraries(backend_with_compiler torch) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 6351a74459bd..db460818dad5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -131,7 +131,6 @@ def skipTestForOldSm(self): if not sm_is_or_higher_than(device, 8, 0): self.skipTest("bf16 requires sm >= 8.0") - @skipIfRocm def test_dynamo_trace_use_training_state(self): torch._dynamo.reset() # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager. @@ -169,7 +168,6 @@ def f(x): self.assertEqual(cnt.op_count, 1) self.assertEqual(len(cnt.graphs), 1) - @skipIfRocm def test_trace_fsdp_copy_(self): @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"}) def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None: diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py index 7b7beb30af9d..bb4f28f43a41 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py @@ -13,12 +13,11 @@ ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLP -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_utils import run_tests class TestFullyShardGradientScaler(FSDPTest): @skip_if_lt_x_gpu(4) - @skipIfRocm def test_gradient_scaler(self): self.run_subtests( {"has_inf": [True, False], "test_2d": [True, False]}, diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py index 340fe913c1eb..de6df77479c9 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py @@ -117,6 +117,9 @@ def _test_fully_shard_training_memory( # number is kept much smaller than the actual memory usage, which is on # the order of 100-200+ MB) buffer_mb = 16 + # The default workspace for hipblaslt is larger than for cublas/cublaslt + # which requires a slight increase to this buffer value. + buffer_mb = 16 if torch.version.cuda else 18 if reshard_after_forward: # 3x max unsharded block parameters (current all-gather + copy-out # and next all-gather), non-block parameters, and other diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 839bbcd6920d..3b92dfcb0a9f 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -28,11 +28,10 @@ from torch.testing._internal.common_distributed import ( DistributedTestBase, skip_if_lt_x_gpu, - skip_if_rocm_multiprocess, sm_is_or_higher_than, ) from torch.testing._internal.common_fsdp import get_devtype -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils.checkpoint import checkpoint @@ -194,7 +193,6 @@ def test_compile_cpu_no_sync(self): self._test_compile(no_sync=True, device="cpu") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch( reorder_for_locality=False, reorder_for_peak_memory=False @@ -203,7 +201,6 @@ def test_compile_gpu(self): self._test_compile(no_sync=False, checkpoint=False, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch( reorder_for_locality=False, reorder_for_peak_memory=False @@ -212,11 +209,13 @@ def test_compile_gpu_ac(self): self._test_compile(no_sync=False, checkpoint=True, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_bf16(self): # Check device capability wrt bf16 - if not sm_is_or_higher_than(torch.device(device_type), 8, 0): + if ( + not sm_is_or_higher_than(torch.device(device_type), 8, 0) + and torch.version.hip is None + ): self.skipTest("bf16 requires sm >= 8.0") def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -230,7 +229,6 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(no_sync=False, setup_func=setup, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_fp16(self): def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -247,7 +245,6 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): self._test_compile(no_sync=False, no_compile_forward=True, device=device_type) @@ -387,7 +384,6 @@ def tearDown(self): "Temporarily disabled due to SymInt error: `unhashable type: non-nested SymInt`" ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skipIfRocm def test_ddp_tp(self): ref_model = Net() compiled_replicate_model = deepcopy(ref_model) diff --git a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py index fac49dd2786f..f8d90d3677e1 100644 --- a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py +++ b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py @@ -482,15 +482,6 @@ def _get_base_model(mlp_dim: int = 2): tp_parallelize_plan.pop("0.out_proj") with cm: - tp_parallelize_plan = { - "0.in_proj": ColwiseParallel(), - "0.out_proj": RowwiseParallel(), - "1.in_proj": ColwiseParallel(), - "1.out_proj": RowwiseParallel(), - "2.in_proj": ColwiseParallel(), - "2.out_proj": RowwiseParallel(), - } - # init device mesh dp_size = 2 global_mesh_1d = init_device_mesh( diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 9f099bbd825a..dfa485090c2d 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -190,6 +190,37 @@ def test_metadata_hf(self) -> None: metadata = reader.read_metadata() self.assertEqual(metadata.storage_data, expected_metadata["weight_map"]) + def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: + mock_module = MagicMock() + sys.modules["safetensors.torch"] = mock_module + sys.modules["huggingface_hub"] = mock_module + with tempfile.TemporaryDirectory() as path: + reader = _HuggingFaceStorageReader(path=path) + reader.fs = FileSystem() + # there is one safetensor file, but no metadata file, + # so we create metadata from the safetensor file + file_name = "test.safetensors" + open(os.path.join(path, file_name), "w").close() + + keys = ["tensor_0", "tensor_1"] + mock_module.safe_open.return_value.__enter__.return_value.keys.return_value = ( + keys + ) + + metadata = reader.read_metadata() + + self.assertEqual( + metadata.state_dict_metadata, + { + keys[0]: BytesStorageMetadata(), + keys[1]: BytesStorageMetadata(), + }, + ) + self.assertEqual( + metadata.storage_data, + {keys[0]: file_name, keys[1]: file_name}, + ) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index d3b3441039d4..9dc730379ecf 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -242,6 +242,19 @@ def test_scatter_object(self): expected_objects = rank assert scattered_objects == expected_objects + @with_comms + @skip_if_lt_x_gpu(2) + def test_barrier(self): + mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2)) + torch.random.manual_seed(dist.get_rank()) + + dist_wrapper = _DistWrapper( + mesh_2d.get_group(1), use_dist=True, coordinator_rank=0 + ) + + # No exception should be raised. + dist_wrapper.barrier() + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index fc371979ca3c..1e51938a033f 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -24,7 +24,6 @@ instantiate_parametrized_tests, parametrize, run_tests, - skipIfRocm, TEST_WITH_DEV_DBG_ASAN, ) @@ -275,7 +274,6 @@ def test_grad_acc( ) @skip_if_lt_x_gpu(2) - @skipIfRocm @parametrize("use_orig_params", [False, True]) def test_grad_acc_cpu_offload( self, diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 18128366c8db..ae94d8c3ec68 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -9,6 +9,7 @@ from torch.distributed.tensor.parallel.style import ( ColwiseParallel, PrepareModuleInput, + PrepareModuleInputOutput, PrepareModuleOutput, RowwiseParallel, ) @@ -201,6 +202,29 @@ def test_prepare_module_output(self): inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local() self.assertEqual(inp, output) + @with_comms + def test_prepare_module_input_output(self): + module = DummyModule() + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + parallelize_module( + module, + device_mesh, + PrepareModuleInputOutput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + output_layouts=Replicate(), + desired_output_layouts=Shard(1), + ), + ) + inp = torch.rand(5, 7, device=self.device_type) + output = module(inp) + inp = ( + DTensor.from_local(inp, device_mesh, [Shard(0)], run_check=False) + .redistribute(device_mesh, [Shard(1)]) + .to_local() + ) + self.assertEqual(inp, output) + @with_comms def test_parallelize_module_with_star(self): inp_size = [12, 10] diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 8de5a4db0a98..162acbd000e9 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -157,6 +157,7 @@ def forward(self, x): str(ep.graph_module.code).strip(), """\ def forward(self, b_buffer, x): + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None view_as = torch.ops.aten.view_as.default(to, to); to = None dtensor___init__0 = self.dtensor___init__0 @@ -172,7 +173,8 @@ def forward(self, b_buffer, x): str(ep.run_decompositions({}).graph_module.code).strip(), """\ def forward(self, b_parametrizations_buffer_original0, x): - _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None + _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None + _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index 5c7d7fd43ae2..cd26a31abf7f 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -18,7 +18,8 @@ ) from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_unless_torch_gpu, @@ -33,8 +34,10 @@ def scale_for_fp8( t = t.unsqueeze(0).unsqueeze(-2) else: t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1)) - scale = t.abs().amax(dim=[1, -1]).float() / torch.finfo(torch.float8_e4m3fn).max - t_fp8 = (t / scale[:, None, :, None]).to(torch.float8_e4m3fn) + + scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS + t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type) + return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape) @@ -205,7 +208,7 @@ def test_scaled_mm(self): full_dist_res = dist_res.full_tensor() # Fp8 matmuls are quite inaccurate, we need high tolerances - self.assertEqual(full_dist_res, full_ref_res, atol=1, rtol=7e-2) + self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2) self.assertEqual(comm_mode.get_total_counts(), 0) @@ -448,7 +451,6 @@ def test_scaled_dot_product_attention(self): self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1)) self.assertEqual(dist_value.grad.full_tensor(), value.grad) - @skipIfRocm @skip_unless_torch_gpu @with_comms() def test_dtensor_mm(self): @@ -472,7 +474,9 @@ def test_dtensor_mm(self): lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()]) rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)]) dtensor_result = lhs_dtensor @ rhs_dtensor - self.assertEqual(dtensor_result.full_tensor(), mm_result) + self.assertEqual( + dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6 + ) @with_comms @skip_unless_torch_gpu diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 6d970c379065..ddaee7ab2405 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -649,8 +649,8 @@ def test_slice(self): global_out.backward(gradient=torch.ones_like(global_out)) with comm_mode: - sharded_out_grad = torch.distributed._tensor.ones( - sharded_out.shape, device_mesh=mesh, placements=[Shard(1)] + sharded_out_grad = torch.distributed.tensor.ones( + sharded_out.shape, device_mesh=mesh, placements=shard_spec ) sharded_out.backward(gradient=sharded_out_grad) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index a9798f9d434a..179f4a7913ce 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -3,13 +3,16 @@ import itertools import torch -from torch.distributed._tensor import distribute_tensor, DTensor -from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import distribute_tensor, DTensor from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._utils import ( + _explicit_order_placements, + compute_local_shape_and_global_offset, +) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -19,6 +22,87 @@ c10d_functional = torch.ops.c10d_functional +class LocalTest(TestCase): + def test_explicit_order_placements(self): + # mesh_shape: ShapeType, placements: Sequence[Placement] + test_cases = [ + { + "mesh_shape": [2, 4], + "placements": [Replicate(), Replicate()], + "ordered": [(0, Replicate()), (1, Replicate())], + }, + { + "mesh_shape": [3, 2], + "placements": [Shard(0), Replicate()], + "ordered": [(0, Shard(0)), (1, Replicate())], + }, + { + "mesh_shape": [2, 4], + "placements": [_StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(1, Shard(0)), (0, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=12), + _StridedShard(0, split_factor=4), + Shard(0), + ], + "ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], + }, + ] + for test_case in test_cases: + actual = _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + expected = test_case["ordered"] + + self.assertEqual( + actual, + expected, + f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", + ) + + error_cases = [ + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], + "exception_type": RuntimeError, + "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=3), + Shard(0), + Shard(0), + ], + "exception_type": NotImplementedError, + "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", + }, + { + "mesh_shape": [2, 3], + "placements": [ + Shard(0), + ], + "exception_type": RuntimeError, + "exception_text": "Expected one placement per mesh dim", + }, + ] + for test_case in error_cases: + with self.assertRaisesRegex( + test_case["exception_type"], test_case["exception_text"] + ): + _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + + class UtilTest(DTensorTestBase): @property def world_size(self): diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 9228efdedf34..57ad689179da 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -46,6 +46,7 @@ requires_gloo, simple_sparse_reduce_tests, skip_if_lt_x_gpu, + skip_if_win32, verify_ddp_error_logged, ) from torch.testing._internal.common_utils import ( @@ -219,6 +220,8 @@ def test_default_store_timeout_gloo(self): class ProcessGroupGlooTest(MultiProcessTestCase): + lazy_init = False + def _create_process_group_gloo(self, store, rank, world_size, opts): pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts) dist.barrier(group=pg) @@ -231,7 +234,7 @@ def setUp(self): def opts(self, threads=2): opts = c10d.ProcessGroupGloo._Options() opts._timeout = 50.0 - opts._devices = [create_device(interface=LOOPBACK)] + opts._devices = [create_device(interface=LOOPBACK, lazy_init=self.lazy_init)] opts._threads = threads return opts @@ -241,8 +244,8 @@ def test_multi_device_constructor(self): opts = c10d.ProcessGroupGloo._Options() opts._timeout = 5.0 opts._devices = [ - create_device(interface=LOOPBACK), - create_device(interface=LOOPBACK), + create_device(interface=LOOPBACK, lazy_init=self.lazy_init), + create_device(interface=LOOPBACK, lazy_init=self.lazy_init), ] pg = self._create_process_group_gloo(store, self.rank, self.world_size, opts) @@ -2334,6 +2337,19 @@ def test_forward_backward_optimizer(self): optimizer.step() +@skip_if_win32() +class ProcessGroupGlooLazyInitTest(ProcessGroupGlooTest): + lazy_init = True + + def setUp(self): + os.environ["TORCH_GLOO_LAZY_INIT"] = "1" + super().setUp() + + def tearDown(self) -> None: + del os.environ["TORCH_GLOO_LAZY_INIT"] + return super().tearDown() + + class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): @property def device(self): diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index dadb3b0804b3..c8032a89d523 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4385,7 +4385,7 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] - self.assertEqual(ver, "2.4") + self.assertEqual(ver, "2.5") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index 73bad39956c6..4b8aac29e503 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -733,6 +733,32 @@ def reduce_scatter_base(output_t, input_t): # fails the check because the dtype is different reduce_scatter_base(output_t, tensor) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_reduce_scatter_v(self): + device = torch.device("cuda", self.rank_to_GPU[self.rank][0]) + # A list of tensors with different sizes + input_list = [torch.ones(i, device=device) for i in range(self.world_size)] + # The i-th output should have size i + output = torch.zeros(self.rank, device=device) + work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True) + expected = torch.ones(self.rank, device=device) * self.world_size + work.wait() + self.assertEqual(expected, output) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_all_gather_v(self): + device = torch.device("cuda", self.rank_to_GPU[self.rank][0]) + # A list of tensors with different sizes + output_list = [torch.zeros(i, device=device) for i in range(self.world_size)] + # The i-th input has size i, filled with value i + input = torch.ones(self.rank, device=device) * self.rank + work = c10d.all_gather(output_list, input, group=self.pg, async_op=True) + expected = [torch.ones(i, device=device) * i for i in range(self.world_size)] + work.wait() + self.assertEqual(expected, output_list) + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_reduce_scatter_ops(self): diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py index 91b22a60e74b..812d5d8abc16 100644 --- a/test/distributed/test_composability.py +++ b/test/distributed/test_composability.py @@ -385,7 +385,7 @@ def apply_dp(partial_model): if not ( dist.is_available() and dist.is_nccl_available() - and torch.cuda.device_count() > 1 + and torch.cuda.device_count() > 3 ): print( "c10d NCCL not available or not enough GPUs, skipping tests", diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index b5e961276f87..d25da76a8931 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -771,7 +771,7 @@ def test_subgroup(self) -> None: self.assertTrue(buf.eq(peer_rank + world.size() // 2).all()) -@skipIfRocm +# @skipIfRocm @instantiate_parametrized_tests @requires_cuda_p2p_access() class SymmMemCollectiveTest(MultiProcessTestCase): @@ -912,7 +912,7 @@ def test_two_shot_all_reduce(self) -> None: shift = align_bytes // t.element_size() numel = size_bytes // t.element_size() res = t[shift : shift + numel] - res.normal_().fill_(1) + res.normal_() inp = res.clone() if not inplace: out = torch.empty_like(inp) @@ -940,6 +940,78 @@ def _verify_all_reduce_result(self, inp, res): gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01 ) + @skipIfRocm + @skip_if_lt_x_gpu(4) + def test_reduce_scatter(self) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + + for dtype, size_bytes, align_bytes, split_last_dim in itertools.product( + [torch.float, torch.bfloat16], + [128, 8192, 36 * 1024 * 16], + [4, 8, 16], + [True, False], + ): + t = symm_mem.empty(36 * 1024 * 16, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) + + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + res = t[shift : shift + numel].normal_() + if split_last_dim: + res = res.view(-1, 128 // t.element_size()) + inp = res.clone() + out_size = list(inp.shape) + out_size[-1] = inp.shape[-1] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + torch.ops.symm_mem.reduce_scatter_out(res, group_name, split_last_dim, out) + + # Head and tail should not be written + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_reduce_scatter_result(inp, out) + + dist.destroy_process_group() + + @skipIfRocm + @skip_if_lt_x_gpu(4) + def test_reduce_scatter_corner_cases(self) -> None: + dtype = torch.bfloat16 + self._init_process() + group_name = dist.group.WORLD.group_name + t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + res = t[:0] + out_size = res.shape[0] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out) + res = t[:48] + out_size = res.shape[0] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + with self.assertRaisesRegex(RuntimeError, "divisible"): + torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out) + res = t[: 2 * 48].view(2, 48) + out = torch.empty(2, 48 // self.world_size, dtype=dtype, device=self.device) + with self.assertRaisesRegex(RuntimeError, "divisible"): + torch.ops.symm_mem.reduce_scatter_out(res, group_name, True, out) + + def _verify_reduce_scatter_result(self, inp, res): + gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, *res.shape) + gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, *inp.shape) + sum_inps = gathered_inps.sum(0) + slice_width = sum_inps.shape[-1] // self.world_size + for i in range(self.world_size): + torch.testing.assert_close( + gathered_res[i], + sum_inps[..., i * slice_width : (i + 1) * slice_width], + rtol=1e-01, + atol=1e-01, + ) + @skip_if_lt_x_gpu(4) @parametrize("align_bytes", [4, 8, 16]) def test_multimem_all_gather(self, align_bytes: int) -> None: diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 3f9c23efc1d1..b42c56b21ced 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import unittest +from typing import Any import torch import torch._dynamo.test_case @@ -73,6 +74,194 @@ def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"): """, # NOQA: B950 ) + def _find_hop_schema( + self, gm: torch.fx.GraphModule, target: Any + ) -> list[torch._C.FunctionSchema]: + import torch.utils._pytree as pytree + + schemas = [] + for node in gm.graph.find_nodes(op="call_function", target=target): + + def _get_example_value(node: torch.fx.Node) -> Any: + if node.op == "get_attr": + return getattr(gm, node.target) + else: + return node.meta["example_value"] + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.fx.Node, + _get_example_value, + (node.args, node.kwargs), + ) + schema = node.target.gen_schema(*fake_args, **fake_kwargs) + schemas.append(schema) + return schemas + + def test_schema_gen_single_return(self): + def inner(x, y): + return (x @ y).sin().cos() + + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=False) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + out = f(x.clone(), y) + self.assertEqual(out, inner(x.clone(), y)) + schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test) + self.assertEqual(len(schemas), 1) + self.assertExpectedInline( + str(schemas[0]), + """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> ((Tensor))""", # noqa: B950 + ) + + def test_schema_gen_pytree_in_out(self): + def inner(x_y): + x, y = x_y + return [ + (x @ y).sin().cos(), + (x + y, x - y), + {"out": (x @ y,)}, + ] + + # make x not require grad because we want to inplace mutate it + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=True) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend) + def f(x, y): + return invoke_quant_test(inner, [x, y], scheme="nf4") + + out = f(x.clone(), y) + self.assertEqual(out, inner([x.clone(), y])) + schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test) + self.assertEqual(len(schemas), 1) + self.assertExpectedInline( + str(schemas[0]), + """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950 + ) + + def test_schema_gen_single_return_with_mutation(self): + def inner(x, y): + x.add_(1) + y.mul_(-1) + return (x @ y).sin().cos() + + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=False) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + with self.assertRaisesRegex( + RuntimeError, + "Encountered input mutation during higher order op tracing for HOP", + ): + f(x.clone(), y) + + def test_schema_gen_pytree_in_out_with_mutation(self): + def inner(x_y): + x, y = x_y + x.add_(1) + return [ + (x @ y).sin().cos(), + (x + y, x - y), + {"out": (x @ y,)}, + ] + + # make x not require grad because we want to inplace mutate it + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=True) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, [x, y], scheme="nf4") + + with self.assertRaisesRegex( + RuntimeError, + "Encountered input mutation during higher order op tracing for HOP", + ): + f(x.clone(), y) + + def test_none_input(self): + def inner(x, y): + if x is not None: + return y.sin() + return y.cos() + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + x = None + y = torch.randn(3, 4) + out = f(x, y) + self.assertEqual(out, inner(x, y)) + self.assertExpectedInline( + normalize_graph(backend.graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_y_: "f32[3, 4]"): + l_y_ = L_y_ + + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4'); subgraph_0 = l_y_ = None + getitem: "f32[3, 4]" = invoke_quant_test[0]; invoke_quant_test = None + return (getitem,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_y_: "f32[3, 4]"): + cos: "f32[3, 4]" = l_y_.cos(); l_y_ = None + return (cos,) +""", + ) + + def test_int_input(self): + def inner(x, y): + return x + y + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + x = 1 + y = torch.randn(3, 4) + out = f(x, y) + self.assertEqual(out, inner(x, y)) + self.assertExpectedInline( + normalize_graph(backend.graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_y_: "f32[3, 4]"): + l_y_ = L_y_ + + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4'); subgraph_0 = l_y_ = None + getitem: "f32[3, 4]" = invoke_quant_test[0]; invoke_quant_test = None + return (getitem,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_y_: "f32[3, 4]"): + add: "f32[3, 4]" = 1 + l_y_; l_y_ = None + return (add,) +""", + ) + @torch._dynamo.config.patch(assume_static_by_default=True) def test_aot_eager(self): def inner(x, y): diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 44edc5305e14..74ff84dbb9e3 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -9,14 +9,19 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.exc import InternalTorchDynamoError -from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same +from torch._dynamo.testing import ( + EagerAndRecordGraphs, + normalize_gm, + same, + skipIfNotPy311, +) from torch._dynamo.utils import counters from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + make_dynamo_test, parametrize, - TEST_WITH_ROCM, ) @@ -659,7 +664,7 @@ def fn(a_float32, b_float32): self.assertTrue(same(ref, res)) @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION or TEST_WITH_ROCM, + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Can't run fused SDPA on this platform", ) def test_autocast_sdpa(self): @@ -1745,10 +1750,13 @@ def fn(x): class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): self._prev = torch._dynamo.config.enable_trace_contextlib + self._u_prev = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_contextlib = True + torch._dynamo.config.enable_trace_unittest = True def tearDown(self): torch._dynamo.config.enable_trace_contextlib = self._prev + torch._dynamo.config.enable_trace_unittest = self._u_prev def test_ctx_basic0(self): @contextlib.contextmanager @@ -2692,10 +2700,11 @@ def fn(t): self.assertEqual(y, t.sin()) -class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): +class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py - # https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70 + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py + @make_dynamo_test def test_contextmanager_plain(self): state = [] @@ -2705,24 +2714,14 @@ def woohoo(): yield 42 state.append(999) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - y = t.sum() - with woohoo() as x: - assert state == [1] - assert x == 42 - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - y += x - return y - - t = torch.randn(2, 3) - y = fn(t) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) self.assertEqual(state, [1, 42, 999]) - self.assertEqual(y, t.sum() + 42) - @unittest.expectedFailure + @skipIfNotPy311 + @make_dynamo_test def test_contextmanager_finally(self): state = [] @@ -2734,170 +2733,66 @@ def woohoo(): finally: state.append(999) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - _y = t.sum() - with self.assertRaises(ZeroDivisionError): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError - - fn(torch.randn(2, 3)) + with self.assertRaises(ZeroDivisionError): + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_traceback(self): @contextmanager def f(): yield - frames = [] - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - 1 / 0 - except ZeroDivisionError as e: - frames = traceback.extract_tb(e.__traceback__) + try: + with f(): + 1 / 0 + except ZeroDivisionError as e: + frames = traceback.extract_tb(e.__traceback__) - fn(torch.randn(2, 3)) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "1 / 0") - - @unittest.expectedFailure - def test_contextmanager_traceback2(self): - @contextmanager - def f(): - yield + self.assertEqual(frames[0].line, "1/0") # Repeat with RuntimeError (which goes through a different code path) - class RuntimeErrorSubclass(RuntimeError): - pass - - frames = [] - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - raise RuntimeErrorSubclass(42) - except RuntimeErrorSubclass as e: - frames = traceback.extract_tb(e.__traceback__) + try: + with f(): + raise NotImplementedError(42) + except NotImplementedError as e: + frames = traceback.extract_tb(e.__traceback__) - fn(torch.randn(2, 3)) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "raise RuntimeErrorSubclass(42)") - - @unittest.expectedFailure - def test_contextmanager_traceback3(self): - @contextmanager - def f(): - yield - - frames = [] - - class StopIterationSubclass(StopIteration): - pass - - for stop_exc in ( - StopIteration("spam"), - StopIterationSubclass("spam"), - ): - with self.subTest(type=type(stop_exc)): - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - raise stop_exc - except type(stop_exc) as e: - self.assertIs(e, stop_exc) - frames = traceback.extract_tb(e.__traceback__) - else: - self.fail(f"{stop_exc} was suppressed") - - fn(torch.randn(2, 3)) - self.assertEqual(len(frames), 1) - self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "raise stop_exc") + self.assertEqual(frames[0].line, "raise NotImplementedError(42)") - @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_no_reraise(self): @contextmanager def whee(): yield - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whee() - ctx.__enter__() - # Calling __exit__ should not result in an exception - self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) - return t.sum() - - fn(torch.randn(2, 3)) + ctx = whee() + ctx.__enter__() + # Calling __exit__ should not result in an exception + self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) - @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_trap_yield_after_throw(self): @contextmanager def whoo(): try: yield - except Exception: - yield - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(TypeError, TypeError("foo"), None) - return t.sum() - - fn(torch.randn(2, 3)) - - @unittest.expectedFailure - def test_contextmanager_trap_no_yield(self): - @contextmanager - def whoo(): - if False: + except Exception: # noqa: E722 yield - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whoo() - with self.assertRaises(RuntimeError): - ctx.__enter__() - return t.sum() - - fn(torch.randn(2, 3)) - - @unittest.expectedFailure - def test_contextmanager_trap_second_yield(self): - @contextmanager - def whoo(): - yield - yield - - @torch.compile(backend="eager", fullgraph=True) - def f(t): - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(None, None, None) - - f(torch.randn(2)) + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(TypeError, TypeError("foo"), None) @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_contextmanager_except(self): @@ -2912,18 +2807,58 @@ def woohoo(): state.append(e.args[0]) self.assertEqual(state, [1, 42, 999]) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError(999) - - fn(torch.randn(2, 3)) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_except_stopiter(self): + @contextmanager + def woohoo(): + yield + + class StopIterationSubclass(StopIteration): + pass + + for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")): + with self.subTest(type=type(stop_exc)): + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail(f"{stop_exc} was suppressed") + + @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_except_pep479(self): + code = """\ +from __future__ import generator_stop +from contextlib import contextmanager +@contextmanager +def woohoo(): + yield +""" + locals = {} + exec(code, locals, locals) + woohoo = locals["woohoo"] + + stop_exc = StopIteration("spam") + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail("StopIteration was suppressed") + + @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self): @contextmanager def test_issue29692(): @@ -2932,71 +2867,77 @@ def test_issue29692(): except Exception as exc: raise RuntimeError("issue29692:Chained") from exc - @torch.compile(backend="eager", fullgraph=True) - def f(t): - try: - with test_issue29692(): - raise ZeroDivisionError - except Exception as ex: - self.assertIs(type(ex), RuntimeError) - self.assertEqual(ex.args[0], "issue29692:Chained") - self.assertIsInstance(ex.__cause__, ZeroDivisionError) + try: + with test_issue29692(): + raise ZeroDivisionError + except Exception as ex: + self.assertIs(type(ex), RuntimeError) + self.assertEqual(ex.args[0], "issue29692:Chained") + self.assertIsInstance(ex.__cause__, ZeroDivisionError) + + try: + with test_issue29692(): + raise StopIteration("issue29692:Unchained") + except Exception as ex: + self.assertIs(type(ex), StopIteration) + self.assertEqual(ex.args[0], "issue29692:Unchained") + self.assertIsNone(ex.__cause__) - try: - with test_issue29692(): - raise StopIteration("issue29692:Unchained") - except Exception as ex: - self.assertIs(type(ex), StopIteration) - self.assertEqual(ex.args[0], "issue29692:Unchained") - self.assertIsNone(ex.__cause__) + @unittest.expectedFailure + @make_dynamo_test + def _create_contextmanager_attribs(self): + def attribs(**kw): + def decorate(func): + for k, v in kw.items(): + setattr(func, k, v) + return func - f(torch.randn(2)) + return decorate - @unittest.expectedFailure - def test_contextmanager_wrap_runtimeerror(self): @contextmanager - def woohoo(): - try: - yield - except Exception as exc: - raise RuntimeError(f"caught {exc}") from exc + @attribs(foo="bar") + def baz(spam): + """Whee!""" - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with self.assertRaises(RuntimeError): - with woohoo(): - 1 / 0 - - fn(torch.randn(2, 3)) + return baz - # If the context manager wrapped StopIteration in a RuntimeError, - # we also unwrap it, because we can't tell whether the wrapping was - # done by the generator machinery or by the generator itself. - with self.assertRaises(StopIteration): - with woohoo(): - raise StopIteration + @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__name__, "baz") + self.assertEqual(baz.foo, "bar") + @make_dynamo_test def test_keywords(self): # Ensure no keyword arguments are inhibited @contextmanager def woohoo(self, func, args, kwds): yield (self, func, args, kwds) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with woohoo(self=11, func=22, args=33, kwds=44) as target: - self.assertEqual(target, (11, 22, 33, 44)) + with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) - fn(torch.randn(2, 3)) + @unittest.expectedFailure + @make_dynamo_test + def test_param_errors(self): + @contextmanager + def woohoo(a, *, b): + yield + with self.assertRaises(TypeError): + woohoo() + with self.assertRaises(TypeError): + woohoo(3, 5) + with self.assertRaises(TypeError): + woohoo(b=3) + + @make_dynamo_test def test_recursive(self): depth = 0 - ncols = 0 @contextmanager def woohoo(): - nonlocal ncols - ncols += 1 nonlocal depth before = depth depth += 1 @@ -3009,14 +2950,67 @@ def recursive(): if depth < 10: recursive() - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - recursive() + recursive() + self.assertEqual(depth, 0) - fn(torch.randn(2, 3)) + @skipIfNotPy311 + @make_dynamo_test + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield - self.assertEqual(ncols, 10) - self.assertEqual(depth, 0) + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + + @make_dynamo_test + def test_contextmanager_trap_second_yield(self): + @contextmanager + def whoo(): + yield + yield + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(None, None, None) + + @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f"caught {exc}") from exc + + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + + @make_dynamo_test + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError # noqa: B904 + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) instantiate_parametrized_tests(CtxManagerTests) diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 61cafbcbda2c..dcecc827cb99 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -21,6 +21,7 @@ import torch.nn import torch.utils.checkpoint from torch._dynamo.testing import same +from torch._dynamo.utils import dict_items from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import TestCase @@ -936,6 +937,17 @@ def fn(x, d): self.assertEqual(ref, res) self.assertEqual(d1.calls, d2.calls) + def test_items_type(self): + def fn(): + d = dict({"a": 1, "b": "2", "c": torch.tensor(3)}) + return d.items() + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + ref = fn() + res = opt_fn() + self.assertEqual(ref, res) + self.assertEqual(type(res), dict_items) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 8db90f836310..8310d3df974d 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -11,6 +11,7 @@ import torch._dynamo.test_case import torch.utils._pytree as python_pytree from torch._dynamo.exc import Unsupported +from torch._dynamo.testing import skipIfNotPy312 from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( IS_FBCODE, @@ -35,6 +36,14 @@ """ +class GenericCtxMgr: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + class GraphBreakMessagesTest(LoggingTestCase): def test_dynamic_shape_operator(self): def fn(): @@ -298,7 +307,7 @@ def post_munge(s): Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function. Hint: Please file an issue to PyTorch. - Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS + Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest from user code: @@ -307,38 +316,6 @@ def post_munge(s): post_munge=post_munge, ) - def test_disable(self): - @torch.compiler.disable - def inner(): - return 1 - - def fn(): - return inner() - - def post_munge(s): - return re.sub( - r"\.inner at 0x[0-9A-Fa-f]+>", - "", - s, - ) - - self.assertExpectedInlineMunged( - Unsupported, - lambda: torch.compile(fn, backend="eager", fullgraph=True)(), - """\ -Skip calling `torch.compiler.disable()`d function - Explanation: Skip calling function `` since it was wrapped with `torch.compiler.disable` - Hint: Remove the `torch.compiler.disable` call - - Developer debug context: - - -from user code: - File "test_error_messages.py", line N, in fn - return inner()""", - post_munge=post_munge, - ) - def test_dynamo_graph_break_fn(self): def fn(): torch._dynamo.graph_break() @@ -600,19 +577,12 @@ def fn(mod, x): ) def test_generic_ctx_mgr_graph_break(self): - class CtxMgr: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - def fn(): - with CtxMgr(): - with CtxMgr(): + with GenericCtxMgr(): + with GenericCtxMgr(): pass - with CtxMgr(): - with CtxMgr(): + with GenericCtxMgr(): + with GenericCtxMgr(): pass torch._dynamo.graph_break() @@ -627,7 +597,7 @@ def fn(): Hint: Move the offending context manager(s) to outside the compiled region. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. - Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)] + Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)] from user code: @@ -646,18 +616,42 @@ def fn(): """, ) - def test_unsupported_bytecode(self): + def test_load_build_class(self): def fn(): class Foo: pass return Foo + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +LOAD_BUILD_CLASS bytecode not supported + Explanation: Dynamo does not support tracing classes that are defined in the compiled region. + Hint: Move the class definition out of the compiled region. + Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. + + Developer debug context: + + +from user code: + File "test_error_messages.py", line N, in fn + class Foo:""", + ) + + @skipIfNotPy312 + def test_unsupported_bytecode(self): + async def fn(): + async for i in range(3): + print(i) + return 1 + def post_munge(s): s = re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s) s = re.sub( - r"Instruction\(.*opname='LOAD_BUILD_CLASS'.*\)\n", - "Instruction(LOAD_BUILD_CLASS)", + r"Instruction\(.*opname='GET_AITER'.*\)\n", + "Instruction(GET_AITER)", s, ) return s @@ -667,15 +661,15 @@ def post_munge(s): lambda: torch.compile(fn, backend="eager", fullgraph=True)(), """\ Missing bytecode handler - Explanation: Dynamo does not know how to handle the bytecode instruction `LOAD_BUILD_CLASS`. - Hint: Do not trace code that produces the `LOAD_BUILD_CLASS` bytecode instruction (see https:/docs.python.org/3/library/dis.html for bytecode semantics). + Explanation: Dynamo does not know how to handle the bytecode instruction `GET_AITER`. + Hint: Do not trace code that produces the `GET_AITER` bytecode instruction (see https:/docs.python.org/3/library/dis.html for bytecode semantics). Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: LOAD_BUILD_CLASS with args (, Instruction(LOAD_BUILD_CLASS) + Developer debug context: GET_AITER with args (, Instruction(GET_AITER) from user code: File "test_error_messages.py", line N, in fn - class Foo:""", + async for i in range(3):""", post_munge=post_munge, ) @@ -696,9 +690,9 @@ def post_munge(s): """\ Reconstruction failure Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). - Hint: If Dynamo attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. + Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. - Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't havereconstruction rules may be fundamentally unreconstructable. + Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) @@ -750,9 +744,9 @@ def post_munge(s): """\ Reconstruction failure Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). - Hint: If Dynamo attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. + Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. - Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't havereconstruction rules may be fundamentally unreconstructable. + Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) @@ -841,6 +835,44 @@ def fn(x): """, ) + @unittest.skipIf(IS_FBCODE, "assert gets patched in internal pytest") + @make_logging_test(graph_breaks=True) + def test_assert_failure_in_generic_ctx_mgr(self, records): + def fn(x): + with GenericCtxMgr(): + assert x is None + + with self.assertRaises(AssertionError): + torch.compile(fn, backend="eager")(torch.randn(3)) + + # only 1 graph break message + self.assertEqual(len(records), 1) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break: skip: from user code at: + File "test_error_messages.py", line N, in fn + assert x is None +""", + ) + self.assertExpectedInline( + munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0), + """\ +Data-dependent assertion failed (cannot compile partial graph) + Explanation: Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph. + Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. + Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo wil not trace any exception handling). + Hint: Remove the assert statement. + Hint: Move the assert statement outside of any context managers in order to graph break with partial graph compilation (if fullgraph=False). + + Developer debug context: value: ConstantVariable(bool: False) + + +from user code: + File "test_error_messages.py", line N, in fn + assert x is None""", + ) + def test_no_internal_compiler_stacktrace(self): def fn(): gn() @@ -1090,6 +1122,81 @@ def f3(x): """, ) + def test_disable_message(self): + @torch.compile(backend="eager", fullgraph=True) + def outer(fn, x): + return fn(x) + + @torch.compiler.disable + def f(x): + return x + 1 + + def post_munge(s): + return re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(f, torch.randn(3)), + """\ +Skip calling `torch.compiler.disable()`d function + Explanation: Skip calling function `.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .f at 0xmem_addr> + + +from user code: + File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + + @torch.compiler.disable(reason="test message") + def g(x): + return x + 2 + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(g, torch.randn(3)), + """\ +Skip calling `torch.compiler.disable()`d function + Explanation: Skip calling function `.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .g at 0xmem_addr> + + +from user code: + File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + + class Mod(torch.nn.Module): + def forward(self, x): + return x + 3 + + mod = Mod() + mod.compile() + mod = torch.compiler.disable(mod, reason="test message 2") + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(mod, torch.randn(3)), + """\ +Unsupported function call (delayed) + Explanation: Dynamo determined that a graph break should occur when calling `L['fn']`. Reason: Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: test message 2) + + + Developer debug context: source: LocalSource(local_name='fn', is_input=True, dynamism=None, is_derefed_cell_contents=False) + + +from user code: + File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index c2390e8db449..bfd1f5352645 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -177,7 +177,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_raise_match(self): a = AttributeError @@ -259,7 +258,6 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") opt_fn(x) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_exception_with_ctx_manager(self): def fn(x): x = torch.cos(x) @@ -545,6 +543,34 @@ def fn(x, d, key): self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + @make_dynamo_test + def test_raise_from_None_2(self): + def fn(): + try: + raise ValueError + except Exception: + raise TypeError from None + + try: + fn() + except TypeError as e: + assert e.__cause__ is None + assert e.__suppress_context__ is True + + @make_dynamo_test + def test_raise_from_other(self): + def fn(): + try: + raise ValueError + except Exception as e: + raise TypeError from e + + try: + fn() + except TypeError as e: + assert isinstance(e.__cause__, ValueError) + assert e.__suppress_context__ is True + @make_dynamo_test def test_reraise_first_exc(self): def fn(): @@ -825,7 +851,6 @@ def fn(t): t = torch.randn(2) fn(t) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_user_defined_exception_with_args(self): @torch.compile(backend="eager", fullgraph=True) def fn(t): @@ -861,6 +886,12 @@ def test_raise_set___context__(self): class CPythonExceptionTests(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py + def setUp(self): + self._u_prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._u_prev @make_dynamo_test def testChainingAttrs(self): @@ -948,7 +979,6 @@ def test_context_of_exception_in_else_and_finally(self): assert exc is oe assert exc.__context__ is ve - @unittest.expectedFailure @make_dynamo_test def test_raise_does_not_create_context_chain_cycle(self): A = AssertionError @@ -987,7 +1017,6 @@ def test_raise_does_not_create_context_chain_cycle(self): self.assertIs(c.__context__, b) self.assertIsNone(b.__context__) - @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle1(self): # See issue 25782. Cycle in context chain. @@ -1043,7 +1072,6 @@ def test_no_hang_on_context_chain_cycle2(self): self.assertIs(b.__context__, a) self.assertIs(a.__context__, c) - @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle3(self): # See issue 25782. Longer context chain with cycle. diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 09aee481c0cc..54a0d70727b9 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3821,6 +3821,21 @@ def test_map_unpack_vars(a, b): x, y = map(lambda x: x + 1, [a, b]) return x + y + def test_unsqueeze_inplace(self): + def fn(x): + return torch.Tensor.unsqueeze_(x, dim=1) + 1 + + def self_fn(x): + return x.unsqueeze_(dim=1) + 1 + + v = torch.ones([3], device="cpu") + # identical tensor since modify inplace + v2 = torch.ones([3], device="cpu") + opt_fn = torch.compile(fn) + opt_self_fn = torch.compile(self_fn) + self.assertEqual(v, v2) + self.assertEqual(opt_fn(v), opt_self_fn(v2)) + def test_enumerate_custom(self): class MyClass: def __iter__(self): diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index d1f4289a5793..03b1cf3e5268 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -12,6 +12,7 @@ from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + make_dynamo_test, parametrize, ) @@ -1069,6 +1070,7 @@ def fn(t): self.assertEqual(L, [1, -123, -1, 456]) @parametrize("exc", [RuntimeError, AttributeError]) + @make_dynamo_test def test_close_capture_and_reraise_exc(self, exc): def whoo(t): try: @@ -1079,7 +1081,6 @@ def whoo(t): finally: pass - @torch.compile(backend="eager", fullgraph=True) def fn(t): gen = whoo(t) i = next(gen) @@ -1087,8 +1088,14 @@ def fn(t): return i t = torch.randn(2) - with self.assertRaises(exc): + + z = 0 + try: fn(t) + except exc: + z = 1 + finally: + assert z == 1 def test_close_with_subgen(self): L = [] diff --git a/test/dynamo/test_generator_stop.py b/test/dynamo/test_generator_stop.py index fe6c9961ddf9..7091d3d37137 100644 --- a/test/dynamo/test_generator_stop.py +++ b/test/dynamo/test_generator_stop.py @@ -8,19 +8,9 @@ from torch.testing._internal.common_utils import make_dynamo_test -class TestPEP479(torch._dynamo.test_case.TestCase): +class TestPEP479(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py - - def assertTrue(self, expr, msg=None): - assert bool(expr) is True, msg - - def assertIs(self, expr1, expr2, msg=None): - assert expr1 is expr2, msg - - def assertEqual(self, x, y): - assert x == y - @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") @make_dynamo_test def test_stopiteration_wrapping(self): @@ -30,16 +20,9 @@ def f(): def g(): yield f() - try: + with self.assertRaises(RuntimeError) as cm: next(g()) - except RuntimeError as cm: - self.assertEqual("generator raised StopIteration", cm.args[0]) - except Exception: - self.fail("Error!") - - # with self.assertRaises(RuntimeError) as cm: - # next(g()) - # self.assertEqual("generator raised StopIteration", str(cm.exception)) + self.assertEqual("generator raised StopIteration", str(cm.exception)) @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") @make_dynamo_test diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 805d8f6be2d0..2ff363a5f5c7 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -3,6 +3,7 @@ import torch import torch.fx from torch._dynamo.graph_deduplication import _flatten_args_kwargs +from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.test_case import TestCase from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm @@ -59,18 +60,15 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): subgraph_0 = self.subgraph_0 l_x_ = L_x_ l_y_ = L_y_ - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None o1: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, o1)); o1 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None @@ -114,7 +112,9 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None @@ -207,7 +207,9 @@ def forward(self, primals_1: "f32[10, 10]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]"): mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None + add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None return (add_1,) """, @@ -265,31 +267,27 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): y0: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \ -(x0, y0)); invoke_subgraph_3 = None - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, l_y_)) + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)) getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None o1: "f32[]" = torch.sin(getitem); getitem = None - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, y0)) + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0)) getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \ -(x0, y0)); subgraph_1 = x0 = y0 = None - - getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None - mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None + + getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None return (add_13,) @@ -328,32 +326,36 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): ___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph = None - getitem_1: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None + getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None - sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1) + sin_1: "f32[]" = torch.ops.aten.sin.default(getitem) ___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_post_graph_1 = None - getitem_2: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None - ___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph - invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None - getitem_19: "f32[]" = invoke_subgraph_11[3] - getitem_18: "f32[10, 20]" = invoke_subgraph_11[2] - getitem_17: "f32[10, 10]" = invoke_subgraph_11[1] - getitem_3: "f32[10, 10]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None + getitem_1: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None + + mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None + ___forward_subgraph_0_post_graph_2 = self.___forward_subgraph_0_post_graph - invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None - getitem_4: "f32[]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None + invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None + getitem_2: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None + ___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph + invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None + getitem_19: "f32[]" = invoke_subgraph_12[3] + getitem_18: "f32[10, 20]" = invoke_subgraph_12[2] + getitem_17: "f32[10, 10]" = invoke_subgraph_12[1] + getitem_4: "f32[10, 10]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None - mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None - mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_3); mul = None - add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_4); mul_1 = getitem_4 = None - return (add, primals_1, primals_2, getitem_1, getitem_2, getitem_19, getitem_18, getitem_17, getitem_3) + mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_4); mul = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_2); mul_1 = getitem_2 = None + return (add, primals_1, primals_2, getitem, getitem_1, getitem_19, getitem_18, getitem_17, getitem_4) class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None @@ -362,7 +364,9 @@ def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): class ___forward_subgraph_1_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2) + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3) + cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None @@ -420,6 +424,7 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None return (add_1,) @@ -475,12 +480,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None - repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (add_2, add_3)); repeated_subgraph0 = None - getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None - - clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None + clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2) clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3) add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1) @@ -491,9 +491,11 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None + getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None @@ -551,18 +553,19 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None + repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None - repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None - getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None + + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None + getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None return (add_2,) @@ -570,7 +573,9 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None + mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None @@ -585,6 +590,76 @@ def test_flatten_with_slices(self): str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]""" ) + def test_cycle_detection_no_cycle(self): + def fn(x, y): + x0 = x + 1 + y0 = y + 2 + z = x0.sum() + y0.sum() + return z + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + + _, _, fw_graphs = self.run_and_return_graphs(fn, x, y) + mod = fw_graphs[0] + self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""") + + def test_cycle_detection_simple(self): + def fn(x, y): + x0 = x + 1 + y0 = y + 2 + z = x0.sum() + y0.sum() + return z + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + + _, _, fw_graphs = self.run_and_return_graphs(fn, x, y) + mod = fw_graphs[0] + add_node = next(n for n in mod.graph.nodes if n.name == "add") + add_2 = next(n for n in mod.graph.nodes if n.name == "add_2") + args = add_node.args + add_node.args = (args[0], add_2) + self.assertExpectedInline( + _detect_cycles(mod.graph), + """cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""", + ) + + def test_cycle_detection_complex(self): + def inner_fn(x, y): + x0 = x.view(x.size()) + return x0.view(x.size()) + + def inner_fn2(x, y): + x = x * 2 + y = y * 2 + return x.sum() + y.sum() + + def fn(x, y): + o0 = inner_fn(x, y) + o1 = inner_fn(x, y) + o2 = inner_fn2(x, y) + o3 = inner_fn2(x, y) + return o0 + o1 + o2.sum() + o3.sum() + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + x_clone = x.clone() + y_clone = y.clone() + + _, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone) + mod = fw_graphs[0] + invoke_subgraph_node = next( + n for n in mod.graph.nodes if n.name == "invoke_subgraph" + ) + add_2 = next(n for n in mod.graph.nodes if n.name == "add_2") + args = invoke_subgraph_node.args + invoke_subgraph_node.args = (add_2, args[1]) + self.assertExpectedInline( + _detect_cycles(mod.graph), + """cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index a75bc7ac1af7..3793db65d73f 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -872,6 +872,7 @@ def forward(self, x): mod = ToyModel() mod.register_forward_pre_hook(lambda mod, input: input[0] + 1) + # Case 1: torch.compile(mod) cnts = torch._dynamo.testing.CompileCounter() compiled_mod = torch.compile(mod, backend=cnts) @@ -881,6 +882,13 @@ def forward(self, x): self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) + # Case 2: mod.compile() + cnts = torch._dynamo.testing.CompileCounter() + mod.compile(backend=cnts) + res = mod(x) + self.assertEqual(ref, res) + self.assertEqual(cnts.frame_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a6b3a29eb4d3..53c3e8b624ca 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11,6 +11,7 @@ import gc import importlib import itertools +import json import logging import math import operator @@ -76,6 +77,7 @@ TEST_CUDA, TEST_MULTIGPU, ) +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import ( sample_inputs_take_along_dim, ) @@ -84,8 +86,10 @@ IS_FBCODE, scoped_load_inline, set_default_dtype, + skipIfHpu, skipIfNNModuleInlined, skipIfWindows, + TEST_HPU, wrapDeterministicFlagAPITest, ) from torch.testing._internal.jit_utils import JitTestCase @@ -586,6 +590,22 @@ def f(x): ref = f(x) self.assertEqual(res, ref) + def test_newly_constructed_tensor_attr_mutation(self): + def f(x): + y = x + 10 + y.grad = x + y.foo = 42 + return y + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.ones(5) + + res = opt_f(x) + ref = f(x) + self.assertEqual(res, ref) + self.assertEqual(res.grad, ref.grad) + self.assertEqual(res.foo, ref.foo) + def test_closure_recompiles(self): cnt = CompileCounter() @@ -3166,6 +3186,58 @@ def fn(m, x): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 4) + def test_global_state_guard_serialization(self): + GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard + guards = GlobalStateGuard() + serialized_guards = guards.dump() + json_guards = json.loads(serialized_guards) + + samples = [] + # Test on non autocast state and autocast cache states. + self.assertIn("autocast_state", json_guards) + for key, value in json_guards.items(): + if type(value) == int: + variant = value + 1 + elif type(value) == bool: + variant = not value + elif isinstance(value, dict) and key == "autocast_state": + variant = value.copy() + variant["cached_enabled"] = not variant["cached_enabled"] + continue + else: + self.fail(f"Unknown global state type {key}: {value}") + new_dict = json_guards.copy() + new_dict[key] = variant + samples.append(new_dict) + + for sample in samples: + guards.load(json.dumps(sample)) + self.assertFalse(guards.check()) + + guards.load(json.dumps(json_guards)) + self.assertTrue(guards.check()) + + # Test on autocast states. + def _test_autocast(dtype): + with torch.autocast("cpu", dtype): + guards = GlobalStateGuard() + serialized_guards = guards.dump() + json_guards = json.loads(serialized_guards) + + for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]): + if enabled: + self.assertEqual( + type(json_guards["autocast_state"]["dtype"][i]), int + ) + json_guards["autocast_state"]["dtype"][i] += 1 + guards.load(json.dumps(json_guards)) + self.assertFalse(guards.check()) + + _test_autocast(torch.float16) + _test_autocast(torch.float32) + _test_autocast(torch.float64) + _test_autocast(torch.bfloat16) + def test_type_copy(self): def fn(seq): a, b = seq @@ -4266,27 +4338,6 @@ def test_version_ci(self): # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - def test_rand(self): - cnts = torch._dynamo.testing.CompileCounter() - device = "cuda" - - def fn(): - return torch.randn(10, device=device) - - torch.manual_seed(10) - ref_run1 = fn() - - torch.manual_seed(10) - ref_run2 = fn() - self.assertTrue(same(ref_run1, ref_run2)) - - torch.manual_seed(10) - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) - res = opt_fn() - - self.assertTrue(same(res, ref_run1)) - def test_slice_input(self): cnts = torch._dynamo.testing.CompileCounter() @@ -5969,57 +6020,6 @@ def fn(param, y): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 3) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, - "Can't run fused SDPA on this platform", - ) - def test_parsing_sdpa(self): - class MyModule(torch.nn.Module): - def forward(self, query, key, value): - out = F.scaled_dot_product_attention(query, key, value, None, 0, True) - out = F.scaled_dot_product_attention( - query, key, value, None, 0, True, scale=8 - ) - out = F.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=None, - dropout_p=0, - is_causal=True, - ) - out = F.scaled_dot_product_attention( - query, - key=key, - value=value, - attn_mask=None, - dropout_p=0, - is_causal=True, - ) - out = F.scaled_dot_product_attention( - query, key, value, None, dropout_p=0, is_causal=True - ) - out = F.scaled_dot_product_attention(query, key, value, None, scale=8) - return out - - device = "cuda" - dtype = torch.float16 - seq_len_q = 1 - seq_len_k = 1 - head_dim = 8 - query = torch.ones( - 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True - ) - key = torch.ones( - 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True - ) - value = torch.ones( - 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True - ) - module = MyModule() - opt_mod = torch.compile(module, backend="inductor") - opt_mod(query, key, value) - def test_generate_tensor_from_list_of_numpy_primitive_type(self): # Test sth like torch.LongTensor(list(np.int64, np.int64, ...)) def fn(): @@ -6476,19 +6476,6 @@ def fn(x, obj): res = opt_fn(x, obj) self.assertTrue(same(ref, res)) - def test_torch_cuda_is_available(self): - def fn(x): - if torch.cuda.is_available(): - return x + 1 - else: - return x - 1 - - x = torch.rand(4) - ref = fn(x) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - def test_variable_tracker_recursively_contains(self): # VariableTracker.recursively_contains should be updated correctly when mutation happens def fn(x): @@ -6506,61 +6493,6 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") - def test_torch_cudnn_is_acceptable(self): - def fn(x): - if torch.backends.cudnn.is_acceptable(tensor=x): - return x + 1 - return x - - x = torch.rand(4).cuda() - ref = fn(x) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") - def test_torch_cudnn_is_acceptable_bad_inputs(self): - def fn1(x): - if torch.backends.cudnn.is_acceptable("invalid"): - return x + 1 - return x - - def fn2(x): - if torch.backends.cudnn.is_acceptable(x, 3.14): - return x + 1 - return x - - with self.assertRaisesRegex( - AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" - ): - x1 = torch.rand(4).cuda() - opt_fn1 = torch.compile(fn1, backend="eager", fullgraph=True) - res1 = opt_fn1(x1) - - with self.assertRaisesRegex( - AssertionError, "Expect 1 input to cudnn.is_acceptable" - ): - x2 = torch.rand(4).cuda() - opt_fn2 = torch.compile(fn2, backend="eager", fullgraph=True) - res = opt_fn2(x2) - - @unittest.skipIf(not TEST_CUDA, "requires cuda") - def test_get_device(self): - def fn(x, y): - x = x + 1 - y = y + 1 - return x.get_device(), y.get_device() - - x = torch.rand(4, device="cuda") - y = torch.rand(4, device="cpu") - ref = fn(x, y) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x, y) - self.assertTrue(same(ref, res)) - def test_disable_flag(self): cnt = torch._dynamo.testing.CompileCounter() @@ -6856,17 +6788,6 @@ def guard_export_print(guards): # This guard was created self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") - def test_symint_as_device_kwarg(self): - def f(rank): - # -2 to make device id 0 for easier testing on CI - return torch.ones(10, device=rank.size(0) - 2) - - x = torch.randn(2) - out = f(torch.randn(2)) - opt_out = torch.compile(backend="eager", dynamic=True, fullgraph=True)(f)(x) - self.assertEqual(out, opt_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") def test_symint_as_device_kwarg_multi_gpu(self): def fn(rank): @@ -8315,21 +8236,6 @@ def func(x): self.assertTrue(isinstance(compile_out, torch.Size)) self.assertEqual(eager_out, compile_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def test_cuda_set_device(self): - def fn(): - a = torch.ones(2, device="cuda") - torch.cuda.set_device(1) - return a + 1 - - with torch.cuda.device(0): - counter = CompileCounter() - opt_fn = torch.compile(fn, backend=counter) - res = opt_fn() - self.assertEqual(res.device.type, "cuda") - self.assertEqual(res.device.index, 0) - self.assertEqual(counter.frame_count, 2) - def test_nested_function_resuming_with_correct_globals(self): # https://github.com/pytorch/pytorch/issues/99665 try: @@ -9568,36 +9474,6 @@ def fn(): res = opt_func() self.assertEqual(ref, res) - def test_torch_device_python_type(self): - for device, device_type, index in [ - ("cpu", "cpu", None), - ("cuda:0", "cuda", 0), - ]: - if device == "cuda:0" and not TEST_CUDA: - continue - - def fn(target): - target_device = target.device - a = torch.zeros(2, 3, device=target_device) - # Constant assert at trace time - assert isinstance(target_device, torch.device) - assert target_device.type == device_type - assert target_device.index == index - b = torch.zeros(2, 3, device=target_device) - c = torch.zeros(2, 3, device=target_device) - return a + b + c - - from torch._dynamo.variables import ConstantVariable - - device = torch.device(device) - expected_variable = ConstantVariable(device) - self.assertEqual(expected_variable.python_type(), type(device)) - - opt_func = torch.compile(fn, backend="eager", fullgraph=True) - a = torch.tensor([2, 3], device=device) - res = opt_func(a) - self.assertIsInstance(res, torch.Tensor) - def test_torch_dtype_python_type(self): def fn(target): target_dtype = target.dtype @@ -10388,8 +10264,8 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {44, 93} > Right: {} ==> val_to_var: values don't match. - > Left: {0: 0, 1: 1, 2: s44, 3: s93} - > Right: {0: 0, 1: 1} + > Left: {2: s44, 3: s93} + > Right: {} ==> var_to_range: values don't match. > Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]} > Right: {} @@ -11519,23 +11395,6 @@ def fn(x, d): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @torch._dynamo.config.patch( - capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True - ) - @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) - def test_interpolate_propagate_real_tensors(self): - @torch.compile(backend="eager", fullgraph=True) - def f(mask, box): - # u0, u1 = mask.tolist() - mask = torch.randn(1, 1, 30, 30, device="cuda") - h, w = box.tolist() - return torch.nn.functional.interpolate( - mask, (h, w), mode="bilinear", align_corners=False - ) - - f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda")) - def test_iter_type(self): @torch.compile(fullgraph=True) def fn(y): @@ -12177,6 +12036,222 @@ def fn(x, y): self.assertTrue(y.grad is not None) +class MiscTestsDevice(torch._inductor.test_case.TestCase): + def test_rand(self, device): + cnts = torch._dynamo.testing.CompileCounter() + device = device + + def fn(): + return torch.randn(10, device=device) + + torch.manual_seed(10) + ref_run1 = fn() + + torch.manual_seed(10) + ref_run2 = fn() + self.assertTrue(same(ref_run1, ref_run2)) + + torch.manual_seed(10) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn() + + self.assertTrue(same(res, ref_run1)) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Can't run fused SDPA on this platform", + ) + def test_parsing_sdpa(self, device): + class MyModule(torch.nn.Module): + def forward(self, query, key, value): + out = F.scaled_dot_product_attention(query, key, value, None, 0, True) + out = F.scaled_dot_product_attention( + query, key, value, None, 0, True, scale=8 + ) + out = F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, + dropout_p=0, + is_causal=True, + ) + out = F.scaled_dot_product_attention( + query, + key=key, + value=value, + attn_mask=None, + dropout_p=0, + is_causal=True, + ) + out = F.scaled_dot_product_attention( + query, key, value, None, dropout_p=0, is_causal=True + ) + out = F.scaled_dot_product_attention(query, key, value, None, scale=8) + return out + + device = device + dtype = torch.float16 + seq_len_q = 1 + seq_len_k = 1 + head_dim = 8 + query = torch.ones( + 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True + ) + key = torch.ones( + 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True + ) + value = torch.ones( + 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True + ) + module = MyModule() + opt_mod = torch.compile(module, backend="inductor") + opt_mod(query, key, value) + + def test_torch_device_is_available(self, device): + def fn(x): + if TEST_HPU or TEST_CUDA: + return x + 1 + else: + return x - 1 + + x = torch.rand(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + @unittest.skipIf(not TEST_CUDA, "requires cuda") + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") + def test_torch_cudnn_is_acceptable(self, device): + def fn(x): + if torch.backends.cudnn.is_acceptable(tensor=x): + return x + 1 + return x + + x = torch.rand(4).to(device) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + @unittest.skipIf(not TEST_CUDA, "requires cuda") + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") + def test_torch_cudnn_is_acceptable_bad_inputs(self, device): + def fn1(x): + if torch.backends.cudnn.is_acceptable("invalid"): + return x + 1 + return x + + def fn2(x): + if torch.backends.cudnn.is_acceptable(x, 3.14): + return x + 1 + return x + + with self.assertRaisesRegex( + AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" + ): + x1 = torch.rand(4).to(device) + opt_fn1 = torch.compile(fn1, backend="eager", fullgraph=True) + res1 = opt_fn1(x1) + + with self.assertRaisesRegex( + AssertionError, "Expect 1 input to cudnn.is_acceptable" + ): + x2 = torch.rand(4).to(device) + opt_fn2 = torch.compile(fn2, backend="eager", fullgraph=True) + res = opt_fn2(x2) + + def test_get_device(self, device): + def fn(x, y): + x = x + 1 + y = y + 1 + return x.get_device(), y.get_device() + + x = torch.rand(4, device=device) + y = torch.rand(4, device="cpu") + ref = fn(x, y) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + + def test_symint_as_device_kwarg(self, device): + def f(rank): + # -2 to make device id 0 for easier testing on CI + return torch.ones(10, device=rank.size(0) - 2) + + x = torch.randn(2) + out = f(torch.randn(2)) + opt_out = torch.compile(backend="eager", dynamic=True, fullgraph=True)(f)(x) + self.assertEqual(out, opt_out) + + @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") + def test_cuda_set_device(self, device): + def fn(): + a = torch.ones(2, device=device) + torch.cuda.set_device(1) + return a + 1 + + with torch.cuda.device(0): + counter = CompileCounter() + opt_fn = torch.compile(fn, backend=counter) + res = opt_fn() + self.assertEqual(res.device.type, "cuda") + self.assertEqual(res.device.index, 0) + self.assertEqual(counter.frame_count, 2) + + def test_torch_device_python_type(self): + for device, device_type, index in [ + ("cpu", "cpu", None), + ("cuda:0", "cuda", 0), + ("hpu:0", "hpu", 0), + ]: + if (device == "cuda:0" and not TEST_CUDA) or ( + device == "hpu:0" and not TEST_HPU + ): + continue + + def fn(target): + target_device = target.device + a = torch.zeros(2, 3, device=target_device) + # Constant assert at trace time + assert isinstance(target_device, torch.device) + assert target_device.type == device_type + assert target_device.index == index + b = torch.zeros(2, 3, device=target_device) + c = torch.zeros(2, 3, device=target_device) + return a + b + c + + from torch._dynamo.variables import ConstantVariable + + device = torch.device(device) + expected_variable = ConstantVariable(device) + self.assertEqual(expected_variable.python_type(), type(device)) + + opt_func = torch.compile(fn, backend="eager", fullgraph=True) + a = torch.tensor([2, 3], device=device) + res = opt_func(a) + self.assertIsInstance(res, torch.Tensor) + + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) + def test_interpolate_propagate_real_tensors(self, device): + @torch.compile(backend="eager", fullgraph=True) + def f(mask, box): + # u0, u1 = mask.tolist() + mask = torch.randn(1, 1, 30, 30, device=device) + h, w = box.tolist() + return torch.nn.functional.interpolate( + mask, (h, w), mode="bilinear", align_corners=False + ) + + f(torch.tensor([30, 30], device=device), torch.tensor([68, 32], device=device)) + + +devices = ("cuda", "hpu") +instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_raise.py b/test/dynamo/test_raise.py new file mode 100644 index 000000000000..9a95d23226c0 --- /dev/null +++ b/test/dynamo/test_raise.py @@ -0,0 +1,563 @@ +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import types +import unittest + +import torch +import torch._dynamo.config +import torch._dynamo.test_case +import torch._functorch.config +import torch.nn +import torch.utils.checkpoint +from torch.testing._internal.common_utils import make_dynamo_test + + +def get_tb(): + try: + raise OSError() + except: + return sys.exc_info()[2] + + +class Context: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return True + + +class MyException(Exception): + def __init__(self): + raise RuntimeError() + + +class ContextManager: + def __enter__(self): + pass + + def __exit__(self, t, v, tb): + raise NameError + + +class TestRaise(torch._dynamo.test_case.CPythonTestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + @make_dynamo_test + def test_invalid_reraise(self): + try: + raise + except RuntimeError as e: + self.assertIn("No active exception", str(e)) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_reraise(self): + try: + try: + raise IndexError + except IndexError as e: + exc1 = e + raise + except IndexError as exc2: + self.assertIs(exc1, exc2) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_except_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + except KeyError: + pass + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_finally_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + finally: + raise + + self.assertRaises(KeyError, reraise) + + @make_dynamo_test + def test_nested_reraise(self): + def nested_reraise(): + raise + + def reraise(): + try: + raise TypeError("foo") + except: + nested_reraise() + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_raise_from_None(self): + try: + try: + raise TypeError("foo") + except: + raise ValueError() from None + except ValueError as e: + self.assertIsInstance(e.__context__, TypeError) + self.assertIsNone(e.__cause__) + + @make_dynamo_test + def test_with_reraise1(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + pass + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_with_reraise2(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + raise KeyError("caught") + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_yield_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + yield 1 + raise + + g = reraise() + next(g) + self.assertRaises(TypeError, lambda: next(g)) + self.assertRaises(StopIteration, lambda: next(g)) + + @make_dynamo_test + def test_erroneous_exception(self): + try: + raise MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + @unittest.expectedFailure # object + @make_dynamo_test + def test_new_returns_invalid_instance(self): + # See issue #11627. + class MyException2(Exception): + def __new__(cls, *args): + return object() + + with self.assertRaises(TypeError): + raise MyException2 + + @unittest.expectedFailure # Assertion with non-string message + @make_dynamo_test + def test_assert_with_tuple_arg(self): + try: + assert False, (3,) + except AssertionError as e: + self.assertEqual(str(e), "(3,)") + + +class TestCause(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + @make_dynamo_test + def testCauseSyntax(self): + try: + try: + try: + raise TypeError + except Exception: + raise ValueError from None + except ValueError as exc: + self.assertIsNone(exc.__cause__) + self.assertTrue(exc.__suppress_context__) + exc.__suppress_context__ = False + raise exc + except ValueError as exc: + e = exc + + self.assertIsNone(e.__cause__) + self.assertFalse(e.__suppress_context__) + self.assertIsInstance(e.__context__, TypeError) + + @make_dynamo_test + def test_invalid_cause(self): + try: + raise IndexError from 5 + except TypeError as e: + self.assertIn("exception cause", str(e)) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_class_cause(self): + try: + raise IndexError from KeyError + except IndexError as e: + self.assertIsInstance(e.__cause__, KeyError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_instance_cause(self): + cause = KeyError() + try: + raise IndexError from cause + except IndexError as e: + self.assertIs(e.__cause__, cause) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_erroneous_cause(self): + try: + raise IndexError from MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + +class TestTraceback(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_sets_traceback(self): + try: + raise IndexError() + except IndexError as e: + self.assertIsInstance(e.__traceback__, types.TracebackType) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_accepts_traceback(self): + tb = get_tb() + try: + raise IndexError().with_traceback(tb) + except IndexError as e: + self.assertNotEqual(e.__traceback__, tb) + self.assertEqual(e.__traceback__.tb_next, tb) + else: + self.fail("No exception raised") + + +class TestTracebackType(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + def raiser(self): + raise ValueError + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_attrs(self): + try: + self.raiser() + except Exception as exc: + tb = exc.__traceback__ + + self.assertIsInstance(tb.tb_next, types.TracebackType) + self.assertIs(tb.tb_frame, sys._getframe()) + self.assertIsInstance(tb.tb_lasti, int) + self.assertIsInstance(tb.tb_lineno, int) + + self.assertIs(tb.tb_next.tb_next, None) + + # Invalid assignments + with self.assertRaises(TypeError): + del tb.tb_next + + with self.assertRaises(TypeError): + tb.tb_next = "asdf" + + # Loops + with self.assertRaises(ValueError): + tb.tb_next = tb + + with self.assertRaises(ValueError): + tb.tb_next.tb_next = tb + + # Valid assignments + tb.tb_next = None + self.assertIs(tb.tb_next, None) + + new_tb = get_tb() + tb.tb_next = new_tb + self.assertIs(tb.tb_next, new_tb) + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_constructor(self): + other_tb = get_tb() + frame = sys._getframe() + + tb = types.TracebackType(other_tb, frame, 1, 2) + self.assertEqual(tb.tb_next, other_tb) + self.assertEqual(tb.tb_frame, frame) + self.assertEqual(tb.tb_lasti, 1) + self.assertEqual(tb.tb_lineno, 2) + + tb = types.TracebackType(None, frame, 1, 2) + self.assertEqual(tb.tb_next, None) + + with self.assertRaises(TypeError): + types.TracebackType("no", frame, 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, "no", 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, "no", 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, 1, "nuh-uh") + + +class TestContext(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + @unittest.expectedFailure # missing Exception.__eq__ + @make_dynamo_test + def test_instance_context_instance_raise(self): + context = IndexError() + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertEqual(e.__context__, context) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ + @make_dynamo_test + def test_class_context_instance_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertNotEqual(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ + @make_dynamo_test + def test_class_context_class_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError + except OSError as e: + self.assertNotEqual(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_c_exception_context(self): + try: + try: + raise ZeroDivisionError + except: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_c_exception_raise(self): + try: + try: + raise ZeroDivisionError + except: + raise NameError + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_noraise_finally(self): + try: + try: + pass + finally: + raise OSError + except OSError as e: + self.assertIsNone(e.__context__) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_raise_finally(self): + try: + try: + raise ZeroDivisionError + finally: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_context_manager(self): + try: + with ContextManager(): + raise ZeroDivisionError + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_cycle_broken(self): + # Self-cycles (when re-raising a caught exception) are broken + try: + try: + raise ZeroDivisionError + except ZeroDivisionError as e: + raise e + except ZeroDivisionError as e: + self.assertIsNone(e.__context__) + + @make_dynamo_test + def test_reraise_cycle_broken(self): + # Non-trivial context cycles (through re-raising a previous exception) + # are broken too. + try: + try: + raise NameError + except NameError as a: + try: + raise ZeroDivisionError + except ZeroDivisionError: + raise a + except NameError as e: + self.assertIsNone(e.__context__.__context__) + + @make_dynamo_test + def test_3118(self): + # deleting the generator caused the __context__ to be cleared + def gen(): + try: + yield 1 + finally: + pass + + def f(): + g = gen() + next(g) + try: + try: + raise ValueError + except: + del g + raise KeyError + except Exception as e: + self.assertIsInstance(e.__context__, ValueError) + + f() + + @unittest.expectedFailure # too CPython specific(?) + @make_dynamo_test + def test_3611(self): + # A re-raised exception in a __del__ caused the __context__ + # to be cleared + class C: + def __del__(self): + try: + raise ZeroDivisionError + except: + raise + + def f(): + x = C() + try: + try: + x.x + except AttributeError: + del x + raise TypeError + except Exception as e: + self.assertNotEqual(e.__context__, None) + self.assertIsInstance(e.__context__, AttributeError) + + with support.catch_unraisable_exception() as cm: + f() + + self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 4eecfdf13989..662f5420bfcb 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -300,6 +300,103 @@ def fn(model, states, x): got = opt_fn(model, states, x) self.assertEqual(expected, got) + def test_graph_break_in_wrapped_user_function(self): + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + 2 + + @torch.compile(backend="eager") + def gn(x): + x = torch.no_grad()(fn)(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_user_method(self): + class Foo: + def __init__(self): + self.a = 1 + self.b = 2 + + def fn(self, x): + x = x + self.a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + self.b + + obj = Foo() + + @torch.compile(backend="eager") + def gn(x): + obj.fn = torch.no_grad()(obj.fn) + x = obj.fn(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_nested_function(self): + @torch.compile(backend="eager") + def gn(x): + a = 1 + b = 2 + + @torch.no_grad() + def fn(x): + x = x + a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + b + + x = fn(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_skipped_function(self): + from torch._dynamo import trace_rules + from torch._dynamo.testing import _skipped_function_for_test_reconstruct + from torch._dynamo.variables import SkipFunctionVariable + + self.assertIs( + trace_rules.lookup(_skipped_function_for_test_reconstruct), + SkipFunctionVariable, + ) + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + 2 + + @torch.compile(backend="eager") + def gn(x): + x = torch.no_grad()(_skipped_function_for_test_reconstruct)(fn, x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 6d8c86923cf3..3bd52f981c6b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4075,41 +4075,6 @@ def forward(self, **inp): res = torch.compile(mod, backend="eager", fullgraph=True)(**inputs) self.assertEqual(ref, res) - def test_call_finally_python_3_8(self): - # Issue - https://github.com/pytorch/pytorch/issues/97811 - def make_fn(g): - def fn(): - while True: - try: - print(g) - break - except Exception as _: - break - - return torch.compile(fn, backend="eager") - - make_fn(None)() - - def test_call_finally_python_3_8_2(self): - def f(x): - while x: - try: - pass - except Exception as _: - continue - - torch.compile(f, backend="eager")(0) - - def test_call_finally_opcode_python_3_8(self): - def fn(): - try: - return torch.zeros(4) - finally: - return torch.ones(4) # noqa: SIM107, B012 - - result = torch.compile(fn, backend="aot_eager")() - self.assertEqual(result, torch.ones(4)) - def test_string_format(self): s = "temp{i}" @@ -6650,6 +6615,9 @@ def f(image_latent): torch.cuda.manual_seed_all(54321) expected = f(torch.randn((2, 12, 16, 32, 32))).sum() + # https://github.com/pytorch/pytorch/issues/147171 + torch._inductor.config.fallback_random = True + for backend in ["eager", "aot_eager"]: torch.manual_seed(54321) torch.cuda.manual_seed_all(54321) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ef2acadac89d..0e7d54c28448 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -40,6 +40,10 @@ def traceable_subclass(c): return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) +def nontraceable_subclass(c): + return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c}) + + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) self.assertEqual(actual_recompiles, expected_recompiles) @@ -757,26 +761,22 @@ def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) def sigmoid(self): return None - @torch.compile(backend="eager", fullgraph=True) def fn(x): x.sigmoid() - msg = ( - "Accessing overridden method/attribute sigmoid on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_user_overidden_attr_unsupported(self): class LocalSubclass(torch.Tensor): @@ -792,10 +792,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def fn(x): return x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + msg = "Currently only support accessing overridden attributes that are functions or properties, but got " with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): @@ -804,13 +801,11 @@ def fn(x): def test_user_overidden_property_unsupported(self): class LocalSubclass(torch.Tensor): - def __init__(self) -> None: + def __init__(self, *args, **kwargs) -> None: self._ndim = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) @property @@ -821,19 +816,17 @@ def ndim(self): def ndim(self, value): self._ndim = value - @torch.compile(backend="eager", fullgraph=True) def fn(x): - return x.ndim + return x + x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = LocalSubclass(torch.ones(2, 2)) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_overridden_method_guarding(self): class LocalSubclass(torch.Tensor): @@ -954,6 +947,275 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) + def test_make_subclass(self): + # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo + # models its aliasing relationships correctly. + class MySubclass(torch.Tensor): + pass + + def fn(x): + # Downcast then upcast + y = torch.Tensor._make_subclass(MySubclass, x) + z = torch.Tensor._make_subclass(torch.Tensor, x) + # Now `x, y, z` should have the same underlying data. + x += 1 + y += 2 + z += 3 + res = x * y + z + return res + + with traceable_subclass(MySubclass): + x0 = torch.randn(2, 2) + x1 = x0.clone() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + + def test_subclass_override_shape_and_to(self): + # This is a slight variabtion of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + class MySubclass(torch.Tensor): + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + return new + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + def fn(x): + x_shape = x.shape + y = x.to("cpu") + return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape + + with traceable_subclass(MySubclass): + x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + self.assertEqual(x0.tensor_shape, x1.tensor_shape) + + def test_subclass_dont_invoke_torch_function_on_overriden_method(self): + # We shouldn't fire `__torch_function__` for overriden tensor methods. + class MySubclass(torch.Tensor): + def to(self, device): + return self * len(device) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func is torch.Tensor.to: + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x.to("cpu") + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_subclass_dont_invoke_torch_function_on_overriden_attr(self): + from types import MethodWrapperType + + # We shouldn't fire `__torch_function__` for overriden tensor attrs. + class MySubclass(torch.Tensor): + def ndim(self): + return 42 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if type(func) is MethodWrapperType and func.__name__ == "ndim": + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x + x.ndim() + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): + # This is a slight variation of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + # which basically + # 1. uses tensor subclass to attach quantization metadata onto tensors + # 2. preserve them across torch ops + # 3. use the metadata to dequantize the tensor + # 4. convert it to a regular tensor. + # + # The test is meant to make sure Dynamo won't graph break over it. + class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + return self + + def __init__(self, *args, quant_type=None, **kwargs): + self.quant_type = quant_type + + def as_tensor(self): + return torch.Tensor._make_subclass( + torch.Tensor, self, self.requires_grad + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + quant_type = None + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + quant_type = arg[0].quant_type + break + if isinstance(arg, GGUFParameter): + quant_type = arg.quant_type + break + if isinstance(result, torch.Tensor): + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif isinstance(result, (tuple, list)): + # Preserve the original type (tuple or list) + wrapped = [ + cls(x, quant_type=quant_type) + if isinstance(x, torch.Tensor) + else x + for x in result + ] + return type(result)(wrapped) + else: + return result + + def f(x): + tmp = x * 2 + tmp = tmp + tmp.quant_type + tmp = tmp.as_tensor() + return tmp * 3 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = GGUFParameter(torch.ones(2), quant_type=42) + with traceable_subclass(GGUFParameter): + res = f(x) + ref = opt_f(x) + self.assertEqual(res, ref) + + def test_newly_constructed_tensor_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from constructor call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = MySubclass(torch.ones(2)) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_as_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from as_subclass call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = torch.ones(2).as_subclass(MySubclass) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_tensor_subclass_attr_codegen_tos(self): + # This repros a very subtle interaction between + # `TensorWithTFOverrideVariable` attribute mutation codegen and + # `PyCodegen.top_of_stack`. It was uncovered from + # `test_tensor_subclass_deepcopy`. + class MySubclass(torch.Tensor): + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_subclass(cls, torch.ones(0)) + r.elem = elem + return r + + def f(t): + return MySubclass(t.elem.clone()) + + opt_f = compile_full_eager(f) + + t = MySubclass(torch.ones(2)) + with traceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + self.assertEqual(res.elem, ref.elem) + self.assertEqual(type(res), type(ref)) + + def test_nontraceable_tensor_subclass(self): + # This will error if Dynamo tries to wrap it as a tensor variable, + # because that involves calling certain methods to inspect the tensor + # property, which will blow up in the overriden `__torch_function__`. + class MySubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("one shall not pass") + + def f(t): + return t.foo + torch.ones(10) + + opt_f = torch.compile(f, backend="eager", fullgraph=False) + + t = MySubclass(torch.ones(2)) + t.foo = 42 + # Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo + # from wrapping `t`. + with nontraceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo/test_sys.py b/test/dynamo/test_sys.py index 2f7bd7178695..3b72ecb36d99 100644 --- a/test/dynamo/test_sys.py +++ b/test/dynamo/test_sys.py @@ -25,9 +25,10 @@ def fn(t): self.assertEqual(y, t.sin()) -class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): +class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_sys.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py + @make_dynamo_test def test_exc_info_no_exception(self): self.assertEqual(sys.exc_info(), (None, None, None)) @@ -37,7 +38,6 @@ def test_exc_info_no_exception(self): def test_sys_exception_no_exception(self): self.assertEqual(sys.exception(), None) - @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_instance(self): def f(): @@ -54,7 +54,6 @@ def f(): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) - @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_type(self): def f(): @@ -71,7 +70,6 @@ def f(): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) - @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_instance(self): @@ -87,7 +85,6 @@ def f(): self.assertIsInstance(e, ValueError) self.assertIs(exc, e) - @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_type(self): diff --git a/test/dynamo/test_unittest.py b/test/dynamo/test_unittest.py new file mode 100644 index 000000000000..244785e01bcd --- /dev/null +++ b/test/dynamo/test_unittest.py @@ -0,0 +1,619 @@ +# Owner(s): ["module: dynamo"] +import sys +import unittest +import warnings +from itertools import product + +import torch +import torch._dynamo.test_case +from torch.testing._internal.common_utils import make_dynamo_test + + +class TestUnittest(torch._dynamo.test_case.TestCase): + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + @make_dynamo_test + def test_SkipTest(self): + z = 0 + SkipTest = unittest.SkipTest + try: + raise SkipTest("abcd") + except Exception: + z = 1 + self.assertEqual(z, 1) + + +class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py + # https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py + + @make_dynamo_test + def test_AlmostEqual(self): + self.assertAlmostEqual(1.00000001, 1.0) + self.assertNotAlmostEqual(1.0000001, 1.0) + self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0 + ) + + self.assertAlmostEqual(1.1, 1.0, places=0) + self.assertRaises( + self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1 + ) + + self.assertAlmostEqual(0, 0.1 + 0.1j, places=0) + self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1) + self.assertRaises( + self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1 + ) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0 + ) + + self.assertAlmostEqual(float("inf"), float("inf")) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf") + ) + + @make_dynamo_test + def test_AmostEqualWithDelta(self): + self.assertAlmostEqual(1.1, 1.0, delta=0.5) + self.assertAlmostEqual(1.0, 1.1, delta=0.5) + self.assertNotAlmostEqual(1.1, 1.0, delta=0.05) + self.assertNotAlmostEqual(1.0, 1.1, delta=0.05) + + self.assertAlmostEqual(1.0, 1.0, delta=0.5) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5 + ) + + self.assertRaises( + self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05 + ) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5 + ) + + self.assertRaises( + TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2 + ) + self.assertRaises( + TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2 + ) + + @make_dynamo_test + def test_assertRaises(self): + def _raise(e): + raise e + + self.assertRaises(KeyError, _raise, KeyError) + self.assertRaises(KeyError, _raise, KeyError("key")) + try: + self.assertRaises(KeyError, lambda: None) + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + self.assertRaises(KeyError, _raise, ValueError) + except ValueError: + pass + else: + self.fail("assertRaises() didn't let exception pass through") + with self.assertRaises(KeyError) as cm: + try: + raise KeyError + except Exception as e: + exc = e + raise + self.assertIs(cm.exception, exc) + + with self.assertRaises(KeyError): + raise KeyError("key") + try: + with self.assertRaises(KeyError): + pass + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + with self.assertRaises(KeyError): + raise ValueError + except ValueError: + pass + else: + self.fail("assertRaises() didn't let exception pass through") + + @make_dynamo_test + def testAssertNotRegex(self): + self.assertNotRegex("Ala ma kota", r"r+") + try: + self.assertNotRegex("Ala ma kota", r"k.t", "Message") + except self.failureException as e: + self.assertIn("Message", e.args[0]) + else: + self.fail("assertNotRegex should have failed.") + + +class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase): + """Test that the individual asserts honour longMessage. + This actually tests all the message behaviour for + asserts that use longMessage.""" + + def setUp(self): + super().setUp() + + class TestableTestFalse(unittest.TestCase): + longMessage = False + failureException = self.failureException + + def testTest(self): + pass + + class TestableTestTrue(unittest.TestCase): + longMessage = True + failureException = self.failureException + + def testTest(self): + pass + + self.testableTrue = TestableTestTrue("testTest") + self.testableFalse = TestableTestFalse("testTest") + + def testDefault(self): + self.assertTrue(unittest.TestCase.longMessage) + + def test_formatMsg(self): + self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo") + + self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo") + + # This blows up if _formatMessage uses string concatenation + self.testableTrue._formatMessage(object(), "foo") + + def test_formatMessage_unicode_error(self): + one = "".join(chr(i) for i in range(255)) + # this used to cause a UnicodeDecodeError constructing msg + self.testableTrue._formatMessage(one, "\uFFFD") + + def assertMessages(self, methodName, args, errors): + """ + Check that methodName(*args) raises the correct error messages. + errors should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + + def getMethod(i): + useTestableFalse = i < 2 + if useTestableFalse: + test = self.testableFalse + else: + test = self.testableTrue + return getattr(test, methodName) + + for i, expected_regex in enumerate(errors): + testMethod = getMethod(i) + kwargs = {} + withMsg = i % 2 + if withMsg: + kwargs = {"msg": "oops"} + + # with self.assertRaisesRegex( + # self.failureException, expected_regex=expected_regex + # ): + # testMethod(*args, **kwargs) + with self.assertRaises(self.failureException) as cm: + testMethod(*args, **kwargs) + self.assertRegex(str(cm.exception), expected_regex) + + @make_dynamo_test + def testAssertTrue(self): + self.assertMessages( + "assertTrue", + (False,), + [ + "False is not true", + "oops", + "False is not true", + "False is not true : oops", + ], + ) + + @make_dynamo_test + def testAssertFalse(self): + self.assertMessages( + "assertFalse", + (True,), + [ + "True is not false", + "oops", + "True is not false", + "True is not false : oops", + ], + ) + + @make_dynamo_test + def testNotEqual(self): + self.assertMessages( + "assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"] + ) + + @make_dynamo_test + def testAlmostEqual(self): + self.assertMessages( + "assertAlmostEqual", + (1, 2), + [ + r"^1 != 2 within 7 places \(1 difference\)$", + "^oops$", + r"^1 != 2 within 7 places \(1 difference\)$", + r"^1 != 2 within 7 places \(1 difference\) : oops$", + ], + ) + + @make_dynamo_test + def testNotAlmostEqual(self): + self.assertMessages( + "assertNotAlmostEqual", + (1, 1), + [ + "^1 == 1 within 7 places$", + "^oops$", + "^1 == 1 within 7 places$", + "^1 == 1 within 7 places : oops$", + ], + ) + + @make_dynamo_test + def test_baseAssertEqual(self): + self.assertMessages( + "_baseAssertEqual", + (1, 2), + ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertSequenceEqual(self): + # Error messages are multiline so not testing on full message + # assertTupleEqual and assertListEqual delegate to this method + self.assertMessages( + "assertSequenceEqual", + ([], [None]), + [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"], + ) + + @make_dynamo_test + def testAssertSetEqual(self): + self.assertMessages( + "assertSetEqual", + (set(), set([None])), # noqa: C405 + ["None$", "^oops$", "None$", "None : oops$"], + ) + + @make_dynamo_test + def testAssertIn(self): + self.assertMessages( + "assertIn", + (None, []), + [ + r"^None not found in \[\]$", + "^oops$", + r"^None not found in \[\]$", + r"^None not found in \[\] : oops$", + ], + ) + + @make_dynamo_test + def testAssertNotIn(self): + self.assertMessages( + "assertNotIn", + (None, [None]), + [ + r"^None unexpectedly found in \[None\]$", + "^oops$", + r"^None unexpectedly found in \[None\]$", + r"^None unexpectedly found in \[None\] : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertDictEqual(self): + self.assertMessages( + "assertDictEqual", + ({}, {"key": "value"}), + [ + r"\+ \{'key': 'value'\}$", + "^oops$", + r"\+ \{'key': 'value'\}$", + r"\+ \{'key': 'value'\} : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertMultiLineEqual(self): + self.assertMessages( + "assertMultiLineEqual", + ("", "foo"), + [r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"], + ) + + @make_dynamo_test + def testAssertLess(self): + self.assertMessages( + "assertLess", + (2, 1), + [ + "^2 not less than 1$", + "^oops$", + "^2 not less than 1$", + "^2 not less than 1 : oops$", + ], + ) + + @make_dynamo_test + def testAssertLessEqual(self): + self.assertMessages( + "assertLessEqual", + (2, 1), + [ + "^2 not less than or equal to 1$", + "^oops$", + "^2 not less than or equal to 1$", + "^2 not less than or equal to 1 : oops$", + ], + ) + + @make_dynamo_test + def testAssertGreater(self): + self.assertMessages( + "assertGreater", + (1, 2), + [ + "^1 not greater than 2$", + "^oops$", + "^1 not greater than 2$", + "^1 not greater than 2 : oops$", + ], + ) + + @make_dynamo_test + def testAssertGreaterEqual(self): + self.assertMessages( + "assertGreaterEqual", + (1, 2), + [ + "^1 not greater than or equal to 2$", + "^oops$", + "^1 not greater than or equal to 2$", + "^1 not greater than or equal to 2 : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNone(self): + self.assertMessages( + "assertIsNone", + ("not None",), + [ + "^'not None' is not None$", + "^oops$", + "^'not None' is not None$", + "^'not None' is not None : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNotNone(self): + self.assertMessages( + "assertIsNotNone", + (None,), + [ + "^unexpectedly None$", + "^oops$", + "^unexpectedly None$", + "^unexpectedly None : oops$", + ], + ) + + @make_dynamo_test + def testAssertIs(self): + self.assertMessages( + "assertIs", + (None, "foo"), + [ + "^None is not 'foo'$", + "^oops$", + "^None is not 'foo'$", + "^None is not 'foo' : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNot(self): + self.assertMessages( + "assertIsNot", + (None, None), + [ + "^unexpectedly identical: None$", + "^oops$", + "^unexpectedly identical: None$", + "^unexpectedly identical: None : oops$", + ], + ) + + @make_dynamo_test + def testAssertRegex(self): + self.assertMessages( + "assertRegex", + ("foo", "bar"), + [ + "^Regex didn't match:", + "^oops$", + "^Regex didn't match:", + "^Regex didn't match: (.*) : oops$", + ], + ) + + @make_dynamo_test + def testAssertNotRegex(self): + self.assertMessages( + "assertNotRegex", + ("foo", "foo"), + [ + "^Regex matched:", + "^oops$", + "^Regex matched:", + "^Regex matched: (.*) : oops$", + ], + ) + + def assertMessagesCM(self, methodName, args, func, errors): + """ + Check that the correct error messages are raised while executing: + with method(*args): + func() + *errors* should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"})) + for (cls, kwargs), err in zip(p, errors): + method = getattr(cls, methodName) + # with self.assertRaisesRegex(cls.failureException, err): + with self.assertRaises(cls.failureException) as c: + with method(*args, **kwargs) as cm: # noqa: F841 + func() + self.assertRegex(str(c.exception), err) + + @make_dynamo_test + def testAssertRaises(self): + self.assertMessagesCM( + "assertRaises", + (TypeError,), + lambda: None, + [ + "^TypeError not raised$", + "^oops$", + "^TypeError not raised$", + "^TypeError not raised : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertRaisesRegex(self): + self.assertMessagesCM( + "assertRaisesRegex", + (TypeError, "unused regex"), + lambda: None, + [ + "^TypeError not raised$", + "^oops$", + "^TypeError not raised$", + "^TypeError not raised : oops$", + ], + ) + + # test error raised but with wrong message + def raise_wrong_message(): + raise TypeError("foo") + + self.assertMessagesCM( + "assertRaisesRegex", + (TypeError, "regex"), + raise_wrong_message, + [ + '^"regex" does not match "foo"$', + "^oops$", + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$', + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertWarns(self): + self.assertMessagesCM( + "assertWarns", + (UserWarning,), + lambda: None, + [ + "^UserWarning not triggered$", + "^oops$", + "^UserWarning not triggered$", + "^UserWarning not triggered : oops$", + ], + ) + + @unittest.expectedFailure + @unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13") + @make_dynamo_test + def test_assertNotWarns(self): + def warn_future(): + warnings.warn("xyz", FutureWarning, stacklevel=2) + + self.assertMessagesCM( + "_assertNotWarns", + (FutureWarning,), + warn_future, + [ + "^FutureWarning triggered$", + "^oops$", + "^FutureWarning triggered$", + "^FutureWarning triggered : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertWarnsRegex(self): + # test error not raised + self.assertMessagesCM( + "assertWarnsRegex", + (UserWarning, "unused regex"), + lambda: None, + [ + "^UserWarning not triggered$", + "^oops$", + "^UserWarning not triggered$", + "^UserWarning not triggered : oops$", + ], + ) + + # test warning raised but with wrong message + def raise_wrong_message(): + warnings.warn("foo") + + self.assertMessagesCM( + "assertWarnsRegex", + (UserWarning, "regex"), + raise_wrong_message, + [ + '^"regex" does not match "foo"$', + "^oops$", + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$', + ], + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index dd5a5c4593eb..9f51c11a87c6 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -337,6 +337,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'non_compliant_ops': set(), 'num_graph_breaks': 0, 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, 'post_grad_pass_time_us': 0, 'pre_grad_pass_time_us': 0, 'python_version': None, @@ -424,6 +426,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'non_compliant_ops': None, 'num_graph_breaks': 0, 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, 'post_grad_pass_time_us': 0, 'pre_grad_pass_time_us': None, 'python_version': None, diff --git a/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is new file mode 100644 index 000000000000..f243ff1904b0 --- /dev/null +++ b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is @@ -0,0 +1,10 @@ +- Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._data = data` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines_exit b/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines_exit deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions b/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions_recursive b/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions_recursive deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op b/test/dynamo_expected_failures/TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones index e69de29bb2d1..24f34ca8e8e6 100644 --- a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones +++ b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149975 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator index e69de29bb2d1..880a24b122bb 100644 --- a/test/dynamo_expected_failures/TestIterator.test_iterator +++ b/test/dynamo_expected_failures/TestIterator.test_iterator @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/150005 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestNamedTuple.test_max b/test/dynamo_expected_failures/TestNamedTuple.test_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPickle.test_pickle b/test/dynamo_expected_failures/TestPickle.test_pickle deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_backward_with_inputs b/test/dynamo_expected_failures/TestScript.test_python_frontend similarity index 100% rename from test/dynamo_expected_failures/TestAutograd.test_backward_with_inputs rename to test/dynamo_expected_failures/TestScript.test_python_frontend diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines b/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 similarity index 100% rename from test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines rename to test/dynamo_expected_failures/TestScript.test_python_frontend_py3 diff --git a/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear new file mode 100644 index 000000000000..276a4f74bbca --- /dev/null +++ b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149881 diff --git a/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash new file mode 100644 index 000000000000..beb4bf5d003a --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash @@ -0,0 +1,10 @@ +Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._diag = _diag` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass new file mode 100644 index 000000000000..c2ddc08d1e40 --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass @@ -0,0 +1,3 @@ +Dynamo cannot query properties of the tensor subclass object when wrapping it +into a VT, because it has a `__torch_function__` that only allows limited +torch ops. diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 3faa1186562f..7cb72bda99ae 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -383,6 +383,7 @@ aten::_fw_primal_copy aten::_fw_primal_copy.out aten::_grid_sampler_2d_cpu_fallback aten::_grid_sampler_2d_cpu_fallback.out +aten::_grouped_mm aten::_has_same_storage_numel aten::_histogramdd_bin_edges aten::_histogramdd_bin_edges.out @@ -650,6 +651,7 @@ aten::_values_copy aten::_values_copy.out aten::_weight_int4pack_mm aten::_weight_int4pack_mm_for_cpu +aten::_weight_int4pack_mm_with_scales_and_zeros aten::_weight_int8pack_mm aten::_weight_norm_interface_backward aten::_weight_norm_interface_backward.out diff --git a/test/export/test_db.py b/test/export/test_db.py index 7c8c1860bc5b..a035bdd23916 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -99,6 +99,7 @@ def test_exportdb_not_supported_rewrite( rewrite_case.example_args, rewrite_case.example_kwargs, dynamic_shapes=rewrite_case.dynamic_shapes, + strict=True, ) diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 6fda3fcdb0ad..1f23cb5cee4b 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -303,9 +303,7 @@ def forward(self, a): report = ep._report self.assertEqual(len(report.failures), 1) - self.assertEqual( - report.failures[0].failure_type, FailureType.CONSTRAINT_VIOLATION_ERROR - ) + self.assertEqual(report.failures[0].failure_type, FailureType.GUARD_ADDED) inp = (torch.randn(3, 3),) self.assertEqual(ep.module()(*inp), M()(*inp)) diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index f95484f0a128..bd68bb7cd772 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -60,7 +60,9 @@ def _check_equality_and_annotations(m_func, inps): ) # ExportedProgram from original module. - original_exported_module = torch.export.export_for_training(m_func(), inps) + original_exported_module = torch.export.export_for_training( + m_func(), inps, strict=True + ) # Check whether input annotations are the same as tracing the original module. orig_ph_name_list = [ @@ -116,7 +118,7 @@ def forward(self, x): m = Module() example_inputs = (torch.randn(3),) m(*example_inputs) - ep = torch.export.export_for_training(m, example_inputs) + ep = torch.export.export_for_training(m, example_inputs, strict=True) joint_ep = _export_forward_backward(ep) self.assertExpectedInline( str(joint_ep.graph_module.code).strip(), @@ -226,7 +228,7 @@ def forward(self, x): example_inputs = (torch.randn(3),) m(*example_inputs) ep = torch.export.export_for_training( - m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}} + m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True ) _export_forward_backward(ep) @@ -261,7 +263,7 @@ def forward(self, x, labels): labels = torch.ones(4, dtype=torch.int64) inputs = (x, labels) - ep = export_for_training(net, inputs) + ep = export_for_training(net, inputs, strict=True) ep = _export_forward_backward(ep) def test_joint_loss_index(self): @@ -281,7 +283,7 @@ def forward(self, x): inputs = (torch.randn(4, 4),) for i in [0, 1]: - ep = export_for_training(Foo(i), inputs) + ep = export_for_training(Foo(i), inputs, strict=True) ep_joint = _export_forward_backward(ep, joint_loss_index=i) for j, spec in enumerate(ep_joint.graph_signature.output_specs): if i == j: diff --git a/test/export/test_export.py b/test/export/test_export.py index 104133d379b3..f4898783fb3e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -227,6 +227,12 @@ def is_non_strict_legacy_test(test_name): return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) +def is_legacy_test(test_name): + return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) or test_name.endswith( + LEGACY_EXPORT_STRICT_SUFFIX + ) + + def is_retracebility_test(test_name): return test_name.endswith(RETRACEABILITY_STRICT_SUFFIX) or test_name.endswith( RETRACEABILITY_NON_STRICT_SUFFIX @@ -2377,6 +2383,140 @@ def forward(self, x, y, z): ): export(Foo(), inputs, dynamic_shapes=shapes) + def test_dim_dynamic_specialization(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x + 2 + + # 0/1 specialization + with self.assertRaisesRegex( + ValueError, + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 0 for dimension " + r"inputs\['x'\]\.shape\[0\](.*\n)*.*" + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 1 for dimension " + r"inputs\['x'\]\.shape\[1\].*", + ): + export( + Foo(), + (torch.randn(0, 1),), + dynamic_shapes={ + "x": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + }, + ) + + class Bar(torch.nn.Module): + def forward(self, x): + assert x.shape[0] <= 32 + return x + 2 + + # static specialization + with self.assertRaisesRegex( + ValueError, + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 32 for dimension " + r"inputs\['x'\]\.shape\[0\](.*\n)*.*", + ): + export( + Bar(), + (torch.randn(32),), + dynamic_shapes={ + "x": {0: Dim.DYNAMIC(min=32)}, + }, + ) + + def test_dim_hint_ranges(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + return x + y + + inputs = ( + torch.randn(6, 4), + torch.randn(6, 4), + ) + shapes = { + "x": (Dim.AUTO(min=4), Dim.AUTO), + "y": (Dim.DYNAMIC(max=16), Dim.AUTO(max=32)), + } + ep = export(Foo(), inputs, dynamic_shapes=shapes) + ep.module()(torch.randn(8, 5), torch.randn(8, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be >= 4, but got 3" + ): + ep.module()(torch.randn(3, 5), torch.randn(3, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 16, but got 17" + ): + ep.module()(torch.randn(17, 5), torch.randn(17, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 32, but got 33" + ): + ep.module()(torch.randn(9, 33), torch.randn(9, 33)) + + def test_dim_hint_range_violations(self): + class Foo(torch.nn.Module): + def forward(self, xs): + x, y = xs["data"][0] + assert y.shape[0] <= 32 + return x[6:], y + 2 + + x, y = torch.randn(8), torch.randn(8) + + # conflict with lower bound + shapes = torch.export.ShapesCollection() + shapes[x] = [Dim.DYNAMIC(max=5)] + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[None, 5\], conflicting with the inferred .*" + r"\[6, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]", + ): + export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes) + + # conflict with upper bound + shapes = torch.export.ShapesCollection() + shapes[y] = [Dim.AUTO(min=48, max=62)] + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[48, 62\], conflicting with the inferred .*" + r"\[2, 32\],.* for inputs\['xs'\]\['data'\]\[0\]\[1\]\.shape\[0\]", + ): + export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes) + + class Bar(torch.nn.Module): + def forward(self, x): + return x + 2 + + # conflict with static range + shapes = {"x": [Dim.STATIC(min=6, max=8)]} + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[6, 8\], conflicting with the inferred .*" + r"\[4, 4\],.* for inputs\['x'\].shape\[0\]", + ): + export(Bar(), (torch.randn(4),), dynamic_shapes=shapes) + + # multiple conflicts + class Moo(torch.nn.Module): + def forward(self, x, y): + assert x.shape[0] <= 32 + assert y.shape[0] >= 128 + return x + 2, y + 2 + + inps = (torch.randn(16), torch.randn(256)) + shapes = { + "x": (Dim.DYNAMIC(min=33),), + "y": (Dim.DYNAMIC(max=127),), + } + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[33, None\], conflicting with the inferred .*" + r"\[2, 32\],.* for inputs\['x'\].shape\[0\](.*\n)*.*" + r"Received user-specified .* \[None, 127\], conflicting with the inferred .*" + r"\[128, int_oo\],.* for inputs\['y'\].shape\[0\]", + ): + export(Moo(), inps, dynamic_shapes=shapes) + def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self) -> None: @@ -3768,7 +3908,6 @@ def forward(self, x, y, z): if node.op == "placeholder": self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") - @testing.expectedFailureRetraceability def test_dynamic_shapes_builder_pytree(self): torch.export.register_dataclass( Inp1, @@ -3797,6 +3936,62 @@ def forward(self, inp: Inp1): if node.op == "placeholder": self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") + def test_dynamic_shapes_inferred_basic(self): + class M(torch.nn.Module): + def forward(self, x, y, z): + # x and y[0] must have same dynamic shape (say `dim`) >= 3 + tmp = (x + y[0])[:3] + # z["k"] must have static shape = 3 + return tmp * z["k"] + + m = M() + args = (torch.randn(4), [torch.randn(4)], {"k": torch.randn(3)}) + + additional_inputs = torch.export.AdditionalInputs() + # 4->5, 4->5, 3->3 + good_args = (torch.randn(5), [torch.randn(5)], {"k": torch.randn(3)}) + additional_inputs.add(good_args) + + ep = export(m, args, dynamic_shapes=additional_inputs) + got_shapes = [ + str(tuple(node.meta["val"].shape)) + for node in ep.graph.find_nodes(op="placeholder") + ] + dim = next(iter(ep.range_constraints.keys())) + expected_shapes = [f"({dim},)", f"({dim},)", "(3,)"] + self.assertEqual(got_shapes, expected_shapes) + + def expect_error(bad_args, run_time_msg, compile_time_msg): + with self.assertRaisesRegex(RuntimeError, run_time_msg): + ep.module()(*bad_args) + + additional_inputs = torch.export.AdditionalInputs() + additional_inputs.add(bad_args) + + with self.assertRaisesRegex(RuntimeError, compile_time_msg): + export(m, args, dynamic_shapes=additional_inputs) + + expect_error( + # 4->2, 4->2, 3->3 + bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}), + run_time_msg="Expected input.*to be >= 3, but got 2", + compile_time_msg="Expected input.*to be >= 3, but got 2", + ) + + expect_error( + # 4->6, 4->7, 3->3 + bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}), + run_time_msg="Expected input.*to be equal to 6, but got 7", + compile_time_msg="Expected input.*to be equal to 6, but got 7", + ) + + expect_error( + # 4->5, 4->5, 3->4 + bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}), + run_time_msg="Expected input.*to be equal to 3, but got 4", + compile_time_msg=r"Constraints violated.*\n.*was inferred to be a constant \(3\)", + ) + def test_mismatched_dynamic_shapes(self): AUTO, STATIC = Dim.AUTO, Dim.STATIC @@ -4649,6 +4844,24 @@ def forward(self, scores, score_thr, topk: torch.Tensor, results=None): self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) + def test_multidimensional_slicing(self): + class M(torch.nn.Module): + def forward(self, x, y): + b = x.item() + torch._check(b >= 0) + torch._check(b < y.shape[0]) + return y[0, b] + + if is_non_strict_test(self._testMethodName): + m = M() + inp = (torch.tensor(4), torch.ones(10, 10)) + r = m(*inp) + + epm = export(m, inp).module() + er = epm(*inp) + + self.assertTrue(torch.allclose(er, r)) + def test_sequential_slicing(self): # See https://github.com/pytorch/pytorch/issues/137455 @@ -4950,7 +5163,6 @@ def forward(self, x): ): self.assertTrue("source_fn_stack" in node.meta) - @testing.expectedFailureRetraceability def test_dynamic_shapes_dataclass(self): torch.export.register_dataclass( Inp2, @@ -5692,14 +5904,37 @@ class Module(torch.nn.Module): def forward(self, x): return x.to("cpu") - ep = export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({}) + ep = export(Module(), (torch.tensor(1, device="cpu"),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to will just specialize by decomposing to a no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) def test_device_to_dynamic(self): class Module(torch.nn.Module): @@ -5710,14 +5945,37 @@ def forward(self, x): Module(), (torch.tensor([1, 2], device="cpu"),), dynamic_shapes={"x": {0: Dim("i")}}, - ).run_decompositions({}) + ) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to will just specialize by decomposing to a no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) def test_device_to_mutation(self): class Module(torch.nn.Module): @@ -5726,10 +5984,102 @@ def forward(self, x): y.add_(1) return y, x - with self.assertRaisesRegex( - RuntimeError, "cannot mutate tensors with frozen storage" + ep = export(Module(), (torch.tensor(1, device="cpu"),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to decomposes to no-op, add_ decomposes to functional variant + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.add.Tensor, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + torch.ops.aten.add_.Tensor, + ], + ) + + # test mutation + x = torch.tensor(2, device="cpu") + y, _ = ep.module()(x) + self.assertEqual(x.item(), 3) + self.assertEqual(id(y), id(x)) + + # test decomp ep + ep = ep.run_decompositions({}) + for node in ep.graph.nodes: + if node.op == "call_function": + self.assertNotEqual(node.target, torch.ops.aten.to.dtype_layout) + + # test mutation for decomposed program + y, _ = ep.module()(x) + self.assertEqual(x.item(), 4) + self.assertEqual(id(y), id(x)) + + @requires_gpu + @testing.expectedFailureCppRuntime + def test_device_to_gpu(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.to("cpu") + + ep = export(Foo(), (torch.randn(64).cuda(),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName ): - export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({}) + # aten.to decomposes to _to_copy + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + # Check device assertion + with self.assertRaisesRegex(RuntimeError, "Tensor device mismatch!"): + ep.module()(torch.randn(64)) + + ep = ep.run_decompositions() + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 2) + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + + # Check device assertion again after decomp + with self.assertRaisesRegex(RuntimeError, "Tensor device mismatch!"): + ep.module()(torch.randn(64)) def test_tensor_constant_aten_to(self): class Module(torch.nn.Module): @@ -5757,40 +6107,96 @@ class Module(torch.nn.Module): def forward(self, x): return x.float() - ep = export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions( - {} - ) + ep = export(Module(), (torch.tensor(1, dtype=torch.float),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # .float() decomposes to no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) + + # test aliasing + x = torch.tensor(1, dtype=torch.float) + out = ep.module()(x) + self.assertEqual(id(x), id(out)) def test_float_conversion_from_int(self): class Module(torch.nn.Module): def forward(self, x): return x.float() - ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions( - {} - ) + ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - self.assertIn(torch.ops.aten._to_copy.default, ops) - self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops) - - self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # .float() decomposes to _to_copy() + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + ], + ) # Raises error because the input dtype is not the same as the input # tensor when exporting. with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"): ep.module()(torch.tensor(1, dtype=torch.float32)) + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + + # Check dtype assertion again after decomp + with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"): + ep.module()(torch.tensor(1, dtype=torch.float32)) + + self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1) + def test_device_to_mutation_float(self): class Module(torch.nn.Module): def forward(self, x): @@ -5798,12 +6204,48 @@ def forward(self, x): y.add_(1) return y, x - with self.assertRaisesRegex( - RuntimeError, "cannot mutate tensors with frozen storage" + ep = export(Module(), (torch.tensor(1, dtype=torch.float),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName ): - export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions( - {} + # aten.to decomposes to no-op, add_ decomposes to functional variant + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.add.Tensor, + ], ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + torch.ops.aten.add_.Tensor, + ], + ) + + # test mutation + x = torch.tensor(2, dtype=torch.float) + y, _ = ep.module()(x) + self.assertEqual(x.item(), 3.0) + self.assertEqual(id(y), id(x)) + + # test decomp ep + ep = ep.run_decompositions({}) + for node in ep.graph.nodes: + if node.op == "call_function": + self.assertNotEqual(node.target, torch.ops.aten.to.dtype) + + # test mutation for decomposed program + y, _ = ep.module()(x) + self.assertEqual(x.item(), 4.0) + self.assertEqual(id(y), id(x)) def test_module(self): class MyLinear(torch.nn.Module): @@ -6843,6 +7285,8 @@ def forward(self, x): len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) + # scan is not supported in sigmoid yet + @testing.expectedFailureCppRuntime def test_export_scan_pytree_output(self): def add(carry, accum): return carry + carry, (accum[0]["moo"] + 1, accum[0]["moo2"] + 1) @@ -6857,7 +7301,6 @@ def forward(self, init, accum): self.assertEqual(ep.module()(init, xs), M()(init, xs)) # map_fn references module outside the module hierarchy - @unittest.expectedFailure def test_map_buffers(self): class M1(torch.nn.Module): def __init__(self) -> None: @@ -6997,7 +7440,6 @@ def forward(self): ep = export(m, ()) self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"]) - @testing.expectedFailureRetraceability def test_preserve_shape_dynamism_for_unused_inputs(self): torch.export.register_dataclass( Inp3, @@ -7094,14 +7536,6 @@ def check(inputs, epm): # output shape is (3, 2), with n_row 3 and n_sample 2 <= dist_size 2 check(inputs, epm) - inputs = ( - torch.tensor([[4, 5], [6, 7], [8, 9], [10, 11]], dtype=torch.float32), - torch.ones(1, dtype=torch.int64), - ) - epm = exported_module(inputs) - # output shape is (4, 1), with n_row 4 and n_sample 1 <= dist_size 2 - check(inputs, epm) - inputs = ( torch.tensor([[4, 5], [6, 7], [8, 9]], dtype=torch.float32), torch.ones(3, dtype=torch.int64), diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f5a324c7afdb..9e7a0793879d 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -16,6 +16,7 @@ import torch import torch._dynamo as torchdynamo +import torch._export.serde.schema as schema import torch.export._trace import torch.utils._pytree as pytree from torch._export.db.case import ExportCase, SupportLevel @@ -99,7 +100,7 @@ def op_schema(cls, op): return torch.ops.aten.add.Tensor._schema inp = (torch.ones(10),) - ep = export_for_training(TestModule(), inp) + ep = export_for_training(TestModule(), inp, strict=True) # Register the custom op handler. foo_custom_op = FooExtensionOp() @@ -164,7 +165,9 @@ def forward(self, x, y, use_p=False): model = MyModule().eval() random_inputs = (torch.rand([2, 3]), torch.rand([2, 3])) - exp_program = export_for_training(model, random_inputs, {"use_p": True}) + exp_program = export_for_training( + model, random_inputs, {"use_p": True}, strict=True + ) output_buffer = io.BytesIO() # Tests that example inputs are preserved when saving and loading module. @@ -183,7 +186,7 @@ class M(torch.nn.Module): def forward(self, x): return x.sin() - exp_program = export_for_training(M(), (torch.randn(4, 4),)) + exp_program = export_for_training(M(), (torch.randn(4, 4),), strict=True) output_buffer = io.BytesIO() # Tests that example forward arg names are preserved when saving and loading module. @@ -223,7 +226,7 @@ def forward(self, x): inp = (torch.ones(10),) # Module will only be able to roundtrip if metadata # can be correctly parsed. - ep = export_for_training(MyModule(), inp) + ep = export_for_training(MyModule(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -287,7 +290,7 @@ def forward(self, x): # Check that module can be roundtripped, thereby confirming proper deserialization. inp = (torch.ones(10),) - ep = export_for_training(MyModule(), inp) + ep = export_for_training(MyModule(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -317,6 +320,7 @@ def forward(self, x, w, b): torch.ones([512]), torch.ones([512]), ), + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) @@ -354,7 +358,10 @@ def forward(self, a, b, c) -> torch.Tensor: "c": {0: dim0_ac, 1: dim1_bc}, } exported_module = export_for_training( - DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes + DynamicShapeSimpleModel(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) sym_size_nodes = [ @@ -415,7 +422,10 @@ def forward(self, a, b, c) -> torch.Tensor: "c": {0: dim0_ac, 1: dim1_bc}, } exported_module = export_for_training( - DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes + DynamicShapeSimpleModel(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) for v in serialized.exported_program.range_constraints.values(): @@ -441,7 +451,9 @@ def forward(self, x): return torch.split(x, 2) input = torch.arange(10.0).reshape(5, 2) - exported_module = export_for_training(MyModule(), (input,)).run_decompositions() + exported_module = export_for_training( + MyModule(), (input,), strict=True + ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -503,8 +515,7 @@ def forward(self, x): return torch.ops.aten.var_mean.correction(x, [1])[0] exported_module = export_for_training( - MyModule(), - (torch.ones([512, 512], requires_grad=True),), + MyModule(), (torch.ones([512, 512], requires_grad=True),), strict=True ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) @@ -525,7 +536,7 @@ def forward(self, x): return x + x ep = export_for_training( - M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},) + M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},), strict=True ) range_constraints = list(ep.range_constraints.keys()) @@ -560,7 +571,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f = Foo() x, _ = torch.sort(torch.randn(3, 4)) - exported_module = export_for_training(f, (x,)).run_decompositions() + exported_module = export_for_training(f, (x,), strict=True).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -578,7 +589,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: b = x + y return b + a - ep = export_for_training(Module(), (torch.randn(3, 2), torch.randn(3, 2))) + ep = export_for_training( + Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True + ) s = ExportedProgramSerializer().serialize(ep) c = canonicalize(s.exported_program) g = c.graph_module.graph @@ -592,7 +605,7 @@ class M(torch.nn.Module): def forward(self, x): return torch.ops.aten.sum.dim_IntList(x, []) - ep = torch.export.export_for_training(M(), (torch.randn(3, 2),)) + ep = torch.export.export_for_training(M(), (torch.randn(3, 2),), strict=True) serialized = ExportedProgramSerializer().serialize(ep) for node in serialized.exported_program.graph_module.graph.nodes: if "aten.sum.dim_IntList" in node.target: @@ -918,6 +931,32 @@ def forward(self, a, b, c): inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(2)) self.check_graph(Mod(), inp, use_pre_dispatch=False) + def test_none_input(self): + """ + Testing a backwards-compatibility breakage where old models do not have + an input spec with the node name. + """ + + class M(torch.nn.Module): + def forward(self, x, y, z): + return x + z + + ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) + + serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) + serialized_program.exported_program.graph_module.signature.input_specs[ + 1 + ] = schema.InputSpec.create( + user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) + ) + ep = ExportedProgramDeserializer(None).deserialize( + serialized_program.exported_program, {}, {}, {} + ) + ep.graph_module.recompile() + unflattened = torch.export.unflatten(ep) + inp = (torch.rand(3, 3), None, torch.rand(3, 3)) + self.assertEqual(unflattened(*inp), M()(*inp)) + def test_multi_return(self) -> None: """ Test multiple return from a single node (ex. layer_norm has 2 outputs) @@ -1233,7 +1272,7 @@ def forward(self, x): a = a * 2 return a, b - ep = torch.export.export_for_training(M(), (torch.ones(3),)) + ep = torch.export.export_for_training(M(), (torch.ones(3),), strict=True) # insert another getitem node for node in ep.graph.nodes: @@ -1379,7 +1418,7 @@ def __init__(self) -> None: def forward(self): return self.p * self.p - ep = torch.export.export_for_training(M(), ()) + ep = torch.export.export_for_training(M(), (), strict=True) ep._example_inputs = None roundtrip_ep = deserialize(serialize(ep)) self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()())) @@ -1407,7 +1446,7 @@ def forward(self, x): return x + x f = Module() - ep = export_for_training(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),), strict=True) serialized_program = ExportedProgramSerializer().serialize(ep) serialized_program.exported_program.schema_version.major = -1 @@ -1443,7 +1482,7 @@ def forward(self, x): y = self.linear(y) return y - ep = export_for_training(Module(), inp) + ep = export_for_training(Module(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) @@ -1460,7 +1499,7 @@ def forward(self, x): f = Foo() inp = (torch.randn(2, 2),) - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) with tempfile.NamedTemporaryFile() as f: save(ep, f) @@ -1477,7 +1516,7 @@ def forward(self, x, y): f = Foo() inp = (torch.tensor([6]), torch.tensor([7])) - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) with TemporaryFileName() as fname: path = Path(fname) @@ -1495,7 +1534,7 @@ def forward(self, x): f = Foo() - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) buffer = io.BytesIO() save(ep, buffer, extra_files={"extra.txt": "moo"}) @@ -1513,7 +1552,7 @@ def forward(self, x): f = Foo() - ep = export_for_training(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),), strict=True) with self.assertRaisesRegex( RuntimeError, r"Serialized version .* does not match our current" @@ -1539,7 +1578,7 @@ def forward(self, x): list_tensor = [torch.tensor(3), torch.tensor(4)] return x + self.a + list_tensor[0] + list_tensor[1] - ep = export_for_training(Foo(), (torch.tensor(1),)) + ep = export_for_training(Foo(), (torch.tensor(1),), strict=True) buffer = io.BytesIO() save(ep, buffer) buffer.seek(0) @@ -1565,7 +1604,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) # Replace one of the values with an instance of our custom class for node in ep.graph.nodes: @@ -1673,7 +1712,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1708,7 +1747,7 @@ def forward(self, x): f = Foo() inputs = (torch.ones(2, 2),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1744,7 +1783,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} diff --git a/test/export/test_unflatten_training_ir.py b/test/export/test_unflatten_training_ir.py index 684d9a149ecf..6816787eff22 100644 --- a/test/export/test_unflatten_training_ir.py +++ b/test/export/test_unflatten_training_ir.py @@ -14,7 +14,7 @@ def mocked_training_ir_export(*args, **kwargs): - return export_for_training(*args, **kwargs) + return export_for_training(*args, **kwargs, strict=True) def make_dynamic_cls(cls): diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index dd3d18db1cda..5d3cfd564637 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -20,7 +20,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export_for_training(f, (torch.randn(100), torch.randn(100))) + ep = export_for_training(f, (torch.randn(100), torch.randn(100)), strict=True) verifier = Verifier() verifier.check(ep) @@ -48,7 +48,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() ep = export_for_training( - f, (torch.randn(100), torch.randn(100)) + f, (torch.randn(100), torch.randn(100)), strict=True ).run_decompositions({}) for node in ep.graph.nodes: if node.target == torch.ops.aten.add.Tensor: @@ -72,7 +72,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3))) + ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True) verifier = Verifier() verifier.check(ep) @@ -92,7 +92,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() ep = export_for_training( - f, (torch.randn(3, 3), torch.randn(3, 3)) + f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True ).run_decompositions({}) for node in ep.graph_module.true_graph_0.graph.nodes: if node.target == torch.ops.aten.add.Tensor: @@ -111,7 +111,7 @@ def __init__(self) -> None: def forward(self, x: Tensor) -> Tensor: return self.linear(x) - ep = export_for_training(M(), (torch.randn(10, 10),)) + ep = export_for_training(M(), (torch.randn(10, 10),), strict=True) ep.validate() def test_ep_verifier_invalid_param(self) -> None: @@ -125,7 +125,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True) # Parameter doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -150,7 +150,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True) # Buffer doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -182,7 +182,9 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training( + M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True + ) ep.validate() def test_ep_verifier_invalid_output(self) -> None: @@ -205,7 +207,9 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training( + M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True + ) output_node = list(ep.graph.nodes)[-1] output_node.args = ( diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 03b065a3691a..bfd255c50111 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -126,6 +126,9 @@ ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)), ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)), ("aten::all_reduce", datetime.date(9999, 1, 30)), + # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp + # TODO: add back restriction when c10d ops can be exported + ("c10d::.*", datetime.date(9999, 1, 1)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 0e3f39eb2266..9349e9c103f2 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -6903,6 +6903,7 @@ def forward(self, t): t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec) sum_1: "f32[]" = torch.ops.aten.sum.default(t) + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None sin: "f32[2, 3]" = torch.ops.aten.sin.default(t) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index d552179fc9dc..894aa6f544d7 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4598,7 +4598,6 @@ def test_op_has_batch_rule(self, device, dtype, op): "polygamma", "pow", "remainder", - "scatter_add", "scatter", "square", "sub", diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index e74b90f268da..4e11ed562254 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -232,6 +232,19 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: # %add_ node should not be removed because it has side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False) + def test_impure_random(self): + """ + Test that DCE doesn't remove call_function for torch.rand. + """ + + class TestModule(torch.nn.Module): + def forward(self, a: torch.Tensor) -> torch.Tensor: + x = torch.rand([10]) # noqa: F841 + return a * 2 + + # %torch.rand should not be removed because it has side effects. + self._run_dce_and_test(TestModule(), expect_dce_changes=False) + def test_impure_kwargs(self): """ Test that DCE doesn't remove call_function nodes with side effects on kwargs. diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py index 26caf91485e2..578e0ab07a6a 100644 --- a/test/fx/test_matcher_utils.py +++ b/test/fx/test_matcher_utils.py @@ -173,7 +173,7 @@ def pattern(x, weight): torch.randn(3, 3, 3, 3), ) pattern_gm = export_for_training( - WrapperModule(pattern), example_inputs + WrapperModule(pattern), example_inputs, strict=True ).module() before_split_res = pattern_gm(*example_inputs) pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm) @@ -204,11 +204,11 @@ def pattern(x, weight): torch.randn(3, 3, 3, 3), ) pattern_gm = export_for_training( - WrapperModule(pattern), example_inputs + WrapperModule(pattern), example_inputs, strict=True ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) target_gm = export_for_training( - WrapperModule(target_graph), example_inputs + WrapperModule(target_graph), example_inputs, strict=True ).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: @@ -248,9 +248,11 @@ def forward(self, x): return linear, {"linear": linear, "x": x} example_inputs = (torch.randn(3, 5),) - pattern_gm = export_for_training(Pattern(), example_inputs).module() + pattern_gm = export_for_training( + Pattern(), example_inputs, strict=True + ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = export_for_training(M(), example_inputs).module() + target_gm = export_for_training(M(), example_inputs, strict=True).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index da071c4d20ed..69394e0e6428 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -559,23 +559,27 @@ def test_simple_module(self): @mark_compile_region def gn(x): - return mod(x) + return torch.cos(x), mod(x) def fn(x): - return gn(x) + out = gn(x) + return out[0] + out[1] opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) # requires_grad is False deliberately to force None the joint_graph # outputs x = torch.randn(8, 8, requires_grad=False) + x_clone = x.detach().clone().requires_grad_(False) - ref = mod(x) - res = opt_fn(x) - self.assertEqual(ref, res) + ref = fn(x) + res = opt_fn(x_clone) ref.sum().backward() res.sum().backward() + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + def test_fail_with_direct_invoke_subgraph(self): from torch._higher_order_ops import invoke_subgraph @@ -819,6 +823,110 @@ def run(x, train=True): r1.sum().backward() weight.grad.clone() + def test_return_none_from_fwd(self): + @mark_compile_region + def gn(x): + return x * 2, None, x * 3 + + def fn(x): + ys = gn(x) + return ys[0] + ys[2] + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + x = torch.randn(8, 8, requires_grad=True) + x_clone = x.detach().clone().requires_grad_(True) + + ref = fn(x) + res = opt_fn(x_clone) + + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + + backend = AotEagerAndRecordGraphs() + + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + + x = torch.randn(8, 8, requires_grad=True) + res = opt_fn(x_clone) + res.sum().backward() + + self.assertEqual(len(backend.graphs), 1) + self.assertEqual(len(backend.fw_graphs), 1) + self.assertEqual(len(backend.bw_graphs), 1) + self.count_unique_get_attr_nodes(backend.graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.fw_graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.bw_graphs[0], [], 1) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8, 8]"): + l_x_ = L_x_ + + invoke_subgraph_0 = self.invoke_subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None + getitem: "f32[8, 8]" = invoke_subgraph[0] + getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None + + add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None + return (add,) + + class invoke_subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8, 8]"): + child: "f32[8, 8]" = l_x_ * 2 + child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None + return (child, None, child_1) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[8, 8]"): + ___forward_invoke_subgraph_0_post_graph = self.___forward_invoke_subgraph_0_post_graph + + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_post_graph, '___forward_invoke_subgraph_0_post_graph', (primals_1,)); ___forward_invoke_subgraph_0_post_graph = primals_1 = None + getitem: "f32[8, 8]" = invoke_subgraph_2[0] + getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None + + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_2); getitem = getitem_2 = None + return (add,) + + class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module): + def forward(self, primals_0: "f32[8, 8]"): + mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 2) + mul_1: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 3); primals_0 = None + return (mul, None, mul_1) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[8, 8]"): + ___backward_invoke_subgraph_0_post_graph = self.___backward_invoke_subgraph_0_post_graph + + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_post_graph, '___backward_invoke_subgraph_0_post_graph', (tangents_1, tangents_1)); ___backward_invoke_subgraph_0_post_graph = tangents_1 = None + getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None + return (getitem_3,) + + class ___backward_invoke_subgraph_0_post_graph(torch.nn.Module): + def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): + mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3) + mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + return (add,) +""", + ) + def test_dynamic(self): @mark_compile_region def gn(x): @@ -853,6 +961,27 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_unbacked(self): + @mark_compile_region + def gn(x, y): + b = x.item() + torch._check_is_size(b) + torch._check(b < y.shape[0]) + return y[:b].clone() + + def fn(x, y): + return gn(x, y) + + x = torch.tensor(4) + y = torch.randn(8) + ref = fn(x, y) + opt_fn = torch.compile( + fn, backend="eager", fullgraph=True + ) # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default + res = opt_fn(x, y) + self.assertEqual(ref, res) + def test_bwd_partitioning(self): @mark_compile_region def gn(x, y): @@ -902,7 +1031,9 @@ def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"): class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[8, 8]", primals_1: "f32[8, 8]"): mm: "f32[8, 8]" = torch.ops.aten.mm.default(primals_0, primals_1) + sin: "f32[8, 8]" = torch.ops.aten.sin.default(mm) + t: "f32[8, 8]" = torch.ops.aten.t.default(primals_0); primals_0 = None t_1: "f32[8, 8]" = torch.ops.aten.t.default(primals_1); primals_1 = None return (sin, mm, t, t_1) @@ -927,12 +1058,29 @@ class ___backward_invoke_subgraph_0_post_graph(torch.nn.Module): def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: "f32[8, 8]"): cos: "f32[8, 8]" = torch.ops.aten.cos.default(mm); mm = None mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_0, cos); tangents_0 = cos = None + mm_1: "f32[8, 8]" = torch.ops.aten.mm.default(t, mul); t = None mm_2: "f32[8, 8]" = torch.ops.aten.mm.default(mul, t_1); mul = t_1 = None return (mm_2, mm_1) """, ) + def test_const_tensor(self): + @mark_compile_region + def gn(x): + return torch.tensor(64, dtype=torch.float32) * x + + def fn(x): + return gn(x) + + x = torch.randn(64, requires_grad=True) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + @parameterized_class( [ @@ -981,7 +1129,6 @@ def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"): """, ) - @unittest.expectedFailure def test_unbacked(self): @mark_compile_region def gn(x, y): diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index dd82b5a26f29..d58809ea769e 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -70,6 +70,7 @@ add_kernel_with_tma_2d, mul2_inplace_kernel, strange_config_matmul_kernel, + sub_kernel_autotuned, ) if IS_WINDOWS and IS_CI: @@ -434,6 +435,9 @@ def forward(self, y): self.check_model(model, example_inputs) def test_linear_dynamic_maxautotune(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -550,6 +554,9 @@ def forward(self, x, y): @skip("Test was marked as expected failure, but does not fail always anymore.") def test_dynamic_smem_above_default_limit(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def forward(self, x, y): return x @ y @@ -870,6 +877,9 @@ def forward(self, x, y): ) def test_addmm_multiple_dynamic(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self, n, k, device): super().__init__() @@ -907,6 +917,9 @@ def forward(self, a): ) def test_bmm_multiple_dynamic(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1806,6 +1819,7 @@ def forward(self, x): @skipCUDAIf(True, "Test for x86 backend") @skipIfXpu + @unittest.skipIf(IS_FBCODE, "Need newer ideep") def test_buffer_mutation_and_force_mmap_weights(self): class Model(nn.Module): def __init__(self): @@ -1823,7 +1837,9 @@ def forward(self, x): with config.patch( {"freezing": True, "aot_inductor.force_mmap_weights": True} ), torch.no_grad(): - exported_model = export_for_training(model, example_inputs).module() + exported_model = export_for_training( + model, example_inputs, strict=True + ).module() quantizer = X86InductorQuantizer() quantizer.set_global( xiq.get_default_x86_inductor_quantization_config(reduce_range=True) @@ -2953,6 +2969,9 @@ def forward(self, x): self.check_model(Model(), inputs) def test_convolution(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -3352,6 +3371,55 @@ def forward(self, q, k, v, attn_bias): ) self.check_model(Model(), example_inputs) + def test_aoti_runtime_asserts(self): + from torch.export._draft_export import draft_export, FailureType + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor a, Tensor b) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a[: b.item()] + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_fake_impl(a, b): + ctx = torch.library.get_ctx() + u = ctx.new_dynamic_size() + return torch.empty(u) + + class M(torch.nn.Module): + def forward(self, a, b): + res = torch.ops.mylib.foo(a, b) + s = res.shape[0] + torch._check(s > 3) + torch._check(s < a.shape[0]) + return a[s - 3] + + example_inputs = (torch.randn(100), torch.tensor(10)) + ep = draft_export(M(), example_inputs) + report = ep._report + need_config_patch = any( + not f.xfail and f.failure_type == FailureType.MISMATCHED_FAKE_KERNEL + for f in report.failures + ) + m = ep.module() + + # This should no longer be needed after #150093 + from torch._functorch import config as functorch_config + + with functorch_config.patch( + {"generate_fake_kernels_from_real_mismatches": need_config_patch} + ): + pt2_file = torch._inductor.aoti_compile_and_package(ep) + optimized = torch._inductor.aoti_load_package(pt2_file) + + self.assertTrue(same(optimized(*example_inputs), m(*example_inputs))) + def test_index_put_with_none_index(self): # index_put falls back in the deterministic mode with DeterministicGuard(True): @@ -3501,6 +3569,28 @@ def forward(self, x0, x1): dynamic_shapes=dynamic_shapes, ) + def test_runtime_checks_large(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, *inputs): + result = inputs[0] + for i in range(1, len(inputs)): + result = result + inputs[i] + return result + + inputs = [] + for i in range(1000): + inputs.append(torch.ones(8, 8, 8, dtype=torch.float16, device=self.device)) + inputs = tuple(inputs) + model = Model() + with torch.no_grad(): + AOTIRunnerUtil.compile( + model, + inputs, + ) + def test_runtime_checks_complex(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -3900,6 +3990,47 @@ def forward(self, a): FileCheck().check_not(f"before_launch - {kernel_name}").run(code) FileCheck().check_not(f"after_launch - {kernel_name}").run(code) + @common_utils.parametrize("enable_kernel_profile", (True, False)) + def test_aoti_profiler(self, enable_kernel_profile): + # basic addmm model + class Model(torch.nn.Module): + def __init__(self, n, k, device): + super().__init__() + self.weight = torch.randn(n, k, device=device) + self.bias = torch.randn(n, device=device) + + def forward(self, a): + return torch.nn.functional.linear(a, self.weight, self.bias) + + if sys.platform not in ["linux", "win32"]: + raise unittest.SkipTest( + "enable_kernel_profile only supported on linux and win32" + ) + + M = 8 + N = 6 + K = 16 + model = Model(N, K, self.device) + batch = 2 + a = torch.randn(batch, M, K, device=self.device) + example_inputs = (a,) + kernel_calls = ( + f"aoti_torch_{GPU_TYPE}_addmm_out" + if self.device == GPU_TYPE + else "aoti_torch_cpu_addmm_out" + ) + with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}): + _, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + shim_fn_codes = ( + f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef());' + ) + if enable_kernel_profile: + FileCheck().check(shim_fn_codes).run(code) + else: + FileCheck().check_not(shim_fn_codes).run(code) + def test_aoti_debug_printer_user_defined_triton_kernel(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -4044,6 +4175,10 @@ def forward(self, a, b, c): AOTIRunnerUtil.compile, model, example_inputs ) self.assertEqual("aoti_torch_print_tensor_handle" in code, True) + + # check if the triton kernel is printed as comment + self.assertEqual("def triton_" in code, True) + # check the codegen for debug printing around aoti model inputs is expected for kernel_call, count in kernel_calls: FileCheck().check_count( @@ -4642,6 +4777,42 @@ def forward(self, x): model, example_inputs, "aoti_torch_clone_preserve_strides", 0 ) + def test_autotuning_args_reuse(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Model(torch.nn.Module): + def forward(self, x, y): + x_out = torch.empty_strided( + (x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE + ) + x_out = torch.permute(x_out, [0, 1]) + add_kernel_autotuned[(4,)](x, x, x_out, 16) + + y_out = torch.empty_strided( + (y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE + ) + y_out = torch.permute(y_out, [0, 1]) + add_kernel_autotuned[(64,)](y, y, y_out, 64) + + sub_kernel_autotuned[(4,)](x, x, x_out, 16) + + return x_out, y_out + + example_inputs = ( + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(8, 8, device=GPU_TYPE), + ) + dim0_x = Dim("dim0_x", min=1, max=2048) + dim0_y = Dim("dim0_y", min=1, max=2048) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} + self.check_model( + Model(), + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"max_autotune": True}, + ) + @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_stft(self): N_FFT = 400 diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2d9d7cdb1b80..09398c2c59d1 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -19,12 +19,7 @@ from torch._inductor.test_case import TestCase from torch._inductor.utils import fresh_inductor_cache from torch.export import Dim -from torch.testing._internal.common_utils import ( - IS_FBCODE, - skipIfRocm, - skipIfXpu, - TEST_CUDA, -) +from torch.testing._internal.common_utils import IS_FBCODE, skipIfXpu, TEST_CUDA from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -183,7 +178,6 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfRocm # build system may be different @skipIfXpu # build system may be different def test_compile_after_package(self): if not self.package_cpp_only: diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 6868928957a2..04a268abc3cb 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -58,15 +58,6 @@ def legacy_compile( restore_fqn=False, ) - if IS_FBCODE: - from deeplearning.aot_inductor.extern_node_thrift_serializer import ( - thrift_serializer, - ) - - if options is None: - options = {} - options["extern_node_serializer"] = thrift_serializer - with torch.no_grad(): so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 2192e58f0f3f..ca542c81eea1 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -10,7 +10,7 @@ from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import slowTest -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import get_func_call, HAS_CPU, HAS_CUDA # Make the helper files in test/ importable @@ -24,6 +24,7 @@ check_model, check_model_cuda, copy_tests, + skip_if_cpp_wrapper, ) from torch._inductor import config from torch._inductor.scheduler import Scheduler @@ -126,7 +127,7 @@ def f(a, b): self.common(f, (a, b)) - @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON") + @config.patch(max_autotune_gemm_backends="TRITON") def test_avoid_register_spilling(self): if self.device != "cuda": raise unittest.SkipTest("CUDA only") @@ -196,6 +197,7 @@ class BenchmarkingTest(TestCase): @unittest.skipIf( torch.cuda.device_count() < 2, "The test need at least 2 devices" ) + @skip_if_cpp_wrapper("This tests triton scheduling directly") def test_benchmark_on_non_zero_device(self): hit_count = 0 with torch.cuda.device("cuda:0"): @@ -265,9 +267,7 @@ def foo(m, inp): res, code = run_and_get_code(foo_c, m, inp) torch._dynamo.reset() - with unittest.mock.patch.object( - torch._inductor.config, "benchmark_epilogue_fusion", False - ): + with config.patch(benchmark_epilogue_fusion=False): foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) with torch.no_grad(): res2, code2 = run_and_get_code(foo_c, m, inp) @@ -276,32 +276,34 @@ def foo(m, inp): return code, code2 @fresh_inductor_cache() - @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON") + @config.patch(max_autotune_gemm_backends="TRITON") def test_equivalent_template_code(self): code, code2 = self._equivalent_output_code_impl(256) for out_code in [code, code2]: - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("triton_tem_fused_addmm_relu_0.run").check_count( - "del", 3, exactly=True + FileCheck().check(get_func_call()).check_count( + "empty_strided", 1, exactly=True + ).check("triton_tem_fused_addmm_relu_0").check_count( + ".reset()" if config.cpp_wrapper else "del", 3, exactly=True ).check( - "return" + "" if config.cpp_wrapper else "return" ).run( out_code[0] ) @fresh_inductor_cache() - @torch._inductor.config.patch(max_autotune_gemm_backends="ATEN") + @config.patch(max_autotune_gemm_backends="ATEN") def test_equivalent_extern_code(self): torch._dynamo.reset() code, code2 = self._equivalent_output_code_impl(512, 1, False) for out_code in [code, code2]: - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("extern_kernels.").check_count("del", 3, exactly=True).check( - "return" + FileCheck().check(get_func_call()).check_count( + "empty_strided", 1, exactly=True + ).check("" if config.cpp_wrapper else "extern_kernels.").check_count( + ".reset()" if config.cpp_wrapper else "del", 3, exactly=True + ).check( + "" if config.cpp_wrapper else "return" ).run( out_code[0] ) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1cb4b4f96dfc..bb86d143621e 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +import functools import os import pickle import shutil @@ -1770,6 +1771,69 @@ def f(a, b, c, d, e, f): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") + @requires_triton() + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not SM80OrLater, "Requires SM80+") + @config.patch({"fx_graph_cache": False}) + @config.patch({"fx_graph_remote_cache": False}) + @config.patch({"bundled_autotune_remote_cache": False}) + @config.patch({"max_autotune": True}) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly + @parametrize("remote_cache", (True, False)) + def test_modified_autotune_cache(self, remote_cache): + """ + If a developer changes the way the autotune cache is handled, + there's a chance it'll break the cache. This happened with + #150122. This test ensures that if torch code changes, then + old cache entries will be invalidated. + """ + + def mock_torch_key(value: str) -> bytes: + return value.encode("utf-8") + + def get_autotune_stats(): + if remote_cache: + return global_stats.autotune_remote + return global_stats.autotune_local + + def fn(x, y): + return (x + y).relu() + + x = torch.randn(100, 100).cuda() + y = torch.randn(100, 100).cuda() + + with config.patch( + { + "autotune_local_cache": not remote_cache, + "autotune_remote_cache": remote_cache, + } + ): + with PatchCaches(): + with mock.patch( + "torch._inductor.codecache.torch_key", + functools.partial(mock_torch_key, "torchkey1"), + ): + f_compiled = torch.compile(fn, fullgraph=True) + res1 = f_compiled(x, y) + + self.assertEqual(get_autotune_stats(), Stats(1, 0, 1)) + + torch._dynamo.reset() + PyCodeCache.cache_clear() + + with mock.patch( + "torch._inductor.codecache.torch_key", + functools.partial(mock_torch_key, "torchkey2"), + ): + f_compiled = torch.compile(fn, fullgraph=True) + res2 = f_compiled(x, y) + + self.assertEqual(get_autotune_stats(), Stats(2, 0, 2)) + + self.assertEqual(res1, res2) + class TestRemoteAOTAutogradCache(TestCase): @unittest.skipIf(not HAS_CUDA, "Requires CUDA") diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 7294417ad08e..ec0ba10b9cb2 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2801,7 +2801,12 @@ def test_cudagraphs_cpu_division(self): loss.backward() torch._inductor.config.triton.cudagraphs = False - self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue()) + if inductor_config.cpp_wrapper: + self.assertIn("skipping cudagraphs", stderr_msgs.getvalue()) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + else: + self.assertNotIn("skipping cudagraphs", stderr_msgs.getvalue()) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) def test_cudagraphs_cpu_graph(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -2834,7 +2839,10 @@ def test_cudagraphs_sdpa(self): opt_bwd() self.assertEqual(counters["compiled_autograd"]["captures"], 1) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 2 if inductor_config.cpp_wrapper else 0, + ) @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): @@ -2927,7 +2935,10 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 2 if inductor_config.cpp_wrapper else 1, + ) def test_logs(self): logs, ctx = logs_to_string( @@ -3924,6 +3935,9 @@ def backward(ctx, gO): x = torch.randn(10, 10, requires_grad=True) + # https://github.com/pytorch/pytorch/issues/147171 + torch._inductor.config.fallback_random = True + @torch.compile(backend="aot_eager") def fn(x): return SideEffectfulBackward.apply(x).sum() diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 5b10044fb648..7716898c5424 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -189,7 +189,7 @@ class BaseTest(NamedTuple): BaseTest( "test_conv2d_unary", "cpu", - test_mkldnn_pattern_matcher.TestPatternMatcher(), + test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(), condition=torch.backends.mkldnn.is_available(), slow=True, ), @@ -220,9 +220,9 @@ class BaseTest(NamedTuple): ], BaseTest("test_polar"), BaseTest( - "test_linear_binary", + "test_linear_binary_cpu", "", - test_mkldnn_pattern_matcher.TestPatternMatcher(), + test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(), torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), @@ -297,7 +297,7 @@ class BaseTest(NamedTuple): condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, func_inputs=[ [ - "aoti_torch_cpu__qconv2d_pointwise_tensor", + "aoti_torch_cpu__qconv_pointwise_tensor", "torch.ops.quantized.max_pool2d", "aoti_torch_cpu__qlinear_pointwise_tensor", ] @@ -359,7 +359,9 @@ class BaseTest(NamedTuple): BaseTest("test_view_as_complex"), BaseTest("test_view_as_real"), BaseTest( - "test_woq_int4", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher() + "test_woq_int4", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), ), ]: make_test_case( diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index e46dfff708ab..e7722a1eee8f 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -37,12 +37,17 @@ DeterministicGuard, freeze_rng_state, IS_FBCODE, - skipIfRocm, TEST_WITH_ASAN, + TEST_WITH_ROCM, xfailIfPy312Plus, ) +if TEST_WITH_ROCM: + config.force_layout_optimization = 1 + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + + DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -187,7 +192,6 @@ def f(q, k, v, mask): self.assertEqual(out, f(*inputs)) - @skipIfRocm def test_input_channels_last(self): m = torch.nn.Sequential( torch.nn.Conv2d(3, 3, 1, 1), @@ -1323,6 +1327,185 @@ def fn(x, y, z): self.assertEqual(ref, res) + @torch._inductor.config.patch(emulate_precision_casts=True) + def test_dont_inplace_disjoint_accesses(self): + # TODO - would not need mms if we could annotate donated buffer.. + def forward( # noqa: F821, F722 + arg0_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0", # noqa: F821, F722 + arg2_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg3_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg4_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 + arg5_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 + arg6_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 + arg7_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 + ): + permute = torch.ops.aten.permute.default(arg0_1, [1, 0]) + arg0_1 = None + view = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + mm = torch.ops.aten.mm.default(view, permute) + view = permute = None + view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048]) + mm = None + permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0]) + arg2_1 = None + view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + mm_1 = torch.ops.aten.mm.default(view_2, permute_1) + view_2 = permute_1 = None + view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048]) + mm_1 = None + permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0]) + arg3_1 = None + view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + arg1_1 = None + mm_2 = torch.ops.aten.mm.default(view_4, permute_2) + view_4 = permute_2 = None + view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048]) + mm_2 = None + convert_element_type_6 = torch.ops.prims.convert_element_type.default( + view_1, torch.float32 + ) + view_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2) + mean = torch.ops.aten.mean.dim(pow_1, [-1], True) + pow_1 = None + add = torch.ops.aten.add.Tensor(mean, 1e-06) + mean = None + rsqrt = torch.ops.aten.rsqrt.default(add) + add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt) + convert_element_type_6 = rsqrt = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default( + arg4_1, torch.float32 + ) + arg4_1 = None + mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul) + convert_element_type_7 = mul = None + convert_element_type_8 = torch.ops.prims.convert_element_type.default( + mul_1, torch.bfloat16 + ) + mul_1 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default( + view_3, torch.float32 + ) + view_3 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True) + pow_2 = None + add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06) + mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_1) + add_1 = None + mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1) + convert_element_type_9 = rsqrt_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default( + arg5_1, torch.float32 + ) + arg5_1 = None + mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2) + convert_element_type_10 = mul_2 = None + convert_element_type_11 = torch.ops.prims.convert_element_type.default( + mul_3, torch.bfloat16 + ) + mul_3 = None + view_6 = torch.ops.aten.view.default( + convert_element_type_8, [8, 4096, -1, 128] + ) + convert_element_type_8 = None + view_7 = torch.ops.aten.view.default( + convert_element_type_11, [8, 4096, -1, 128] + ) + convert_element_type_11 = None + view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128]) + view_5 = None + convert_element_type_12 = torch.ops.prims.convert_element_type.default( + view_6, torch.float32 + ) + view_6 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default( + view_7, torch.float32 + ) + view_7 = None + unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0) + unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2) + unsqueeze = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0) + unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2) + unsqueeze_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3) + unsqueeze_3 = None + view_9 = torch.ops.aten.view.default( + convert_element_type_12, [8, 4096, 16, 2, 64] + ) + convert_element_type_12 = None + unbind = torch.ops.aten.unbind.int(view_9, -2) + view_9 = None + getitem = unbind[0] + getitem_1 = unbind[1] + unbind = None + neg = torch.ops.aten.neg.default(getitem_1) + getitem_1 = None + cat = torch.ops.aten.cat.default([neg, getitem], -1) + neg = getitem = None + mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1) + cat = unsqueeze_1 = None + add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5) + mul_4 = mul_5 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0) + arg6_1 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2) + unsqueeze_4 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0) + arg7_1 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2) + unsqueeze_6 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7) + unsqueeze_7 = None + view_10 = torch.ops.aten.view.default( + convert_element_type_13, [8, 4096, 16, 2, 64] + ) + convert_element_type_13 = None + unbind_1 = torch.ops.aten.unbind.int(view_10, -2) + view_10 = None + getitem_2 = unbind_1[0] + getitem_3 = unbind_1[1] + unbind_1 = None + neg_1 = torch.ops.aten.neg.default(getitem_3) + getitem_3 = None + cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1) + neg_1 = getitem_2 = None + mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5) + cat_1 = unsqueeze_5 = None + add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7) + mul_6 = mul_7 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default( + add_2, torch.bfloat16 + ) + add_2 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default( + add_3, torch.bfloat16 + ) + add_3 = None + permute_3 = torch.ops.aten.permute.default( + convert_element_type_14, [0, 2, 1, 3] + ) + convert_element_type_14 = None + permute_4 = torch.ops.aten.permute.default( + convert_element_type_15, [0, 2, 1, 3] + ) + convert_element_type_15 = None + permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3]) + view_8 = None + return (permute_3, permute_4, permute_5) + + from torch._dynamo.debug_utils import aot_graph_input_parser + + kwargs = aot_graph_input_parser(forward) + out, code = run_and_get_code(torch.compile(forward), **kwargs) + # ignore tiny values.. prior to this fix absolute error was ~28 + self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2) + FileCheck().check_not("in_out").run(code[0]) + # https://github.com/pytorch/pytorch/issues/104937 def test_linear_with_zero_infeature_size(self): m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda") @@ -1403,7 +1586,6 @@ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3): fn(*args) torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address - @skipIfRocm def test_non_commutative_scan_op(self): from torch._higher_order_ops.associative_scan import associative_scan @@ -1450,7 +1632,6 @@ def outer_reduce(x): self.assertEqual(outer_reduce(a), out) self.assertTrue("for roffset" not in code) - @skipIfRocm def test_scaled_dot_product_efficient_attention_backward(self): from torch import nn, Tensor diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 742000347a35..a536aa7ab74f 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -3048,6 +3048,23 @@ def run(shape_x, shape_y): self.assertEqual(self.get_manager().new_graph_id().id, 3) + def test_meta_tensor(self): + def foobar(x, y): + return x * 2, y * 3 + + foo_c = torch.compile(mode="reduce-overhead")(foobar) + t = torch.empty((1, 16, 128, 128), device="meta") + y = torch.rand([64], device="cuda") + + eager_out = foobar(t, y) + + for _ in range(3): + compiled_out = foo_c(t, y) + + compiled_out = foo_c(t, y) + self.assertEqual(eager_out, compiled_out) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + class TestSAC(TestCase): def _make_observer_mode(self): class ObserverMode(TorchDispatchMode): diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 64086e5071c6..e208565081a1 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -12,7 +12,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - TEST_WITH_ROCM, ) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.utils._triton import has_triton_tma_device @@ -118,7 +117,6 @@ def _fix_fp8_dtype_for_rocm( @instantiate_parametrized_tests class TestFP8Types(TestCase): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): """ @@ -129,6 +127,7 @@ def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): We should not pick a XBLOCK larger than xnumel """ + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") def f(x): return x.to(dtype=float8_dtype) @@ -139,7 +138,6 @@ def f(x): torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("dtype", (torch.float16, torch.bfloat16)) def test_eager_fallback(self, dtype: torch.dtype): weight_shape = (32, 16) @@ -247,7 +245,6 @@ def fp8_saturated(x, dtype): torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) @@ -303,7 +300,6 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @@ -413,7 +409,6 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): @instantiate_parametrized_tests class TestFP8Lowering(TestCase): - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.bfloat16, torch.float32)) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) @@ -435,6 +430,7 @@ def test_tensorwise_scaling( device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) shape = [int(dim) for dim in shape.split(",")] M, K, N = shape # Matmul Y = X [M, K] x W [N, K] @@ -491,7 +487,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # setting a small absolute tolerance in these tests torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) @parametrize("has_bias", (False, True)) @@ -506,6 +501,7 @@ def test_rowwise_scaling( dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) shape = [int(dim) for dim in shape.split(",")] M, K, N = shape # Matmul Y = X [M, K] x W [N, K] @@ -557,7 +553,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @@ -573,6 +568,7 @@ def test_tensorwise_scaling_acceptable_input_dims( use_fast_accum = True device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -615,7 +611,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @@ -630,6 +625,7 @@ def test_rowwise_scaling_acceptable_input_dims( use_fast_accum = True device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -674,13 +670,14 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_unacceptable_input_dims(self): # for compiled ops, type checking is in torch/_meta_registrations.py dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) + M, K, N = 64, 15, 2048 # K needs to be a multiple of 16 x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -714,12 +711,13 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): in str(cm.exception) ) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_unacceptable_scale_dims_rowwise_scaling(self): dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) + M, K, N = 233, 32, 128 x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index a62711196c88..3aa7ee276fc6 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -45,8 +45,15 @@ from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu +from torch.testing._internal.inductor_utils import ( + get_func_call, + get_kernel_launch, + GPU_TYPE, + HAS_CPU, + HAS_CUDA, + HAS_GPU, +) torch.set_float32_matmul_precision("high") @@ -54,14 +61,6 @@ torch.cuda.memory._set_allocator_settings("expandable_segments:False") -def _get_func_call() -> str: - return "void inductor_entry_impl(" if config.cpp_wrapper else "def call(" - - -def _get_kernel_launch() -> str: - return "call_triton_" if config.cpp_wrapper else ".run(" - - def benchmark_choice(choice, args, out, expected_out, timings): result = choice.benchmark(*args, out=out) if expected_out is not None: @@ -672,7 +671,7 @@ def fn(x, number): torch._export.aot_compile(fn, args=inputs) @config.patch(autotune_local_cache=False, autotune_remote_cache=False) - @skipIfRocm + @runOnRocmArch(MI300_ARCH) def test_precompilations(self): def fn(a, b, c): a = (a @ b) @ c @@ -899,8 +898,8 @@ def f(x, y): # mm kernel, and cos kernel count = 2 if using_triton_mm else 1 - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), count, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), count, exactly=True ).run(code[0]) def f(x, y): @@ -912,8 +911,8 @@ def f(x, y): f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) _, code = run_and_get_code(f_c, inps[0], inps[1]) self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 2, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 2, exactly=True ).run(code[0]) def f(x, y): @@ -1362,24 +1361,65 @@ def setUpClass(cls): ) def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), num_kernels, exactly=True, ).run(code_str) if num_allocs is not None: - FileCheck().check(_get_func_call()).check_count( + FileCheck().check(get_func_call()).check_count( "empty_strided", num_allocs, exactly=True ).run(code_str) # skip the deallocation check when using cpp_wrapper; most deallocations happen # outside of our control via RAIIAtenTensorHandle if num_deallocs is not None and not config.cpp_wrapper: - FileCheck().check(_get_func_call()).check_count( + FileCheck().check(get_func_call()).check_count( "del", num_deallocs, exactly=True ).run(code_str) + @parametrize("prologue", (False, True)) + @unittest.skipIf(TEST_WITH_ROCM, "ROCM Different layout decisions") + def test_conv1x1_cast(self, prologue): + with torch._inductor.config.patch( + prologue_fusion=prologue, force_layout_optimization=True + ): + conv1x1 = ( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) + .to(memory_format=torch.channels_last) + .to(GPU_TYPE) + .to(dtype=torch.float16) + ) + input_tensor = ( + torch.randn(4, 3, 32, 32) + .contiguous(memory_format=torch.channels_last) + .to(GPU_TYPE) + ) + + def foo(mod, input): + return torch.nn.functional.conv2d( + input, + mod.weight.to(input.dtype), + None, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + ) + + with torch.no_grad(): + out_eager = foo(conv1x1, input_tensor) + foo_c = torch.compile(foo) + out, code = run_and_get_code(foo_c, conv1x1, input_tensor) + + FileCheck().check_not("extern_kernels.convolution").run(code[0]) + if prologue: + self.check_code( + code[0], num_kernels=1, num_allocs=1, num_deallocs=2 + ) + self.assertEqual(out_eager, out, atol=1e-2, rtol=0) + @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes @@ -1516,8 +1556,8 @@ def multi_use(x, y): out, code = run_and_get_code(torch.compile(multi_use), x, y) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 2, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 2, exactly=True ).run(code[0]) self.assertEqual(out, multi_use(x, y), atol=0.05, rtol=0.05) @@ -1526,8 +1566,8 @@ def resolve_pending(x): x = torch.rand([128, 128], device=GPU_TYPE) out, code = run_and_get_code(torch.compile(resolve_pending), x) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 1, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 1, exactly=True ).run(code[0]) self.assertEqual(out, resolve_pending(x), atol=0.05, rtol=0.05) @@ -1550,8 +1590,8 @@ def test_multiple_fusions(x): x = torch.rand([128, 128], dtype=torch.float16, device=GPU_TYPE) out, code = run_and_get_code(torch.compile(test_multiple_fusions), x) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 1, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 1, exactly=True ).run(code[0]) self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 52a705911166..e3727df7dc87 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -13,6 +13,7 @@ from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -33,7 +34,11 @@ TEST_MKL, xfailIfACL, ) -from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU +from torch.testing._internal.inductor_utils import ( + _check_has_dynamic_shape, + clone_preserve_strides_offset, + HAS_CPU, +) # The dict value is match_nodes(computation_op+unary_op) @@ -91,7 +96,7 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer -def cal_conv_generated_kernel_number(mod, input, dtype, dim=4): +def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): # this function is to decide how many kernels are generated # while testing conv2d/3d/deconv2d # the assumption is: @@ -103,11 +108,14 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4): # and force the output to have same stride with eager. # So there will be a to_contiguous for output if eager output is contiguouse mod = copy.deepcopy(mod) + mod = mod.to(device=device) input = input.clone() + input = input.to(device) + if dtype == torch.float32: maybe_autocast = contextlib.nullcontext() else: - maybe_autocast = torch.amp.autocast("cpu", dtype=dtype) + maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype) with torch.no_grad(), maybe_autocast: output = mod(input) input_kernel, output_kernel = 0, 0 @@ -155,26 +163,33 @@ def _test_common( quantizer=None, compile_options={}, # noqa: B006 ): + if not hasattr(self, "device"): + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) + device = "xpu" if has_xpu else "cpu" + else: + device = self.device + + mod = mod.to(device=device) + if device != "cpu": + inputs = tuple( + clone_preserve_strides_offset(x, device=device) for x in inputs + ) counters.clear() torch._dynamo.reset() - has_xpu = any( - isinstance(input, torch.Tensor) and input.device.type == "xpu" - for input in inputs - ) - device_type = "xpu" if has_xpu else "cpu" if check_autocast == torch.bfloat16 and ( - torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu + torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu" ): maybe_autocast = torch.amp.autocast( - device_type=device_type, dtype=torch.bfloat16 + device_type=device, dtype=torch.bfloat16 ) atol, rtol = 1e-2, 1e-2 elif check_autocast == torch.float16 and ( - torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu + torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu" ): - maybe_autocast = torch.amp.autocast( - device_type=device_type, dtype=torch.float16 - ) + maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16) atol, rtol = 1e-2, 1e-2 else: assert check_autocast == torch.float32 @@ -233,8 +248,8 @@ def _test_code_common( torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) -class TestPatternMatcher(TestPatternMatcherBase): - def _test_conv_unary_cpu_base(self, dim=4): +class TestPatternMatcherGeneric(TestPatternMatcherBase): + def _test_conv_unary_base(self, dim=4): assert dim == 4 or dim == 5 class M(torch.nn.Module): @@ -304,23 +319,27 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv2d_unary_cpu(self): - self._test_conv_unary_cpu_base(dim=4) + def test_conv2d_unary(self, device): + self.device = device + self._test_conv_unary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv3d_unary_cpu(self): - self._test_conv_unary_cpu_base(dim=5) + def test_conv3d_unary(self, device): + self.device = device + self._test_conv_unary_base(dim=5) + + def test_linear_unary(self, device): + self.device = device - def test_linear_unary(self): class M(torch.nn.Module): def __init__( self, @@ -374,7 +393,9 @@ def matcher_check_fn(): self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_fp32(self): + def test_linear_fp32(self, device): + self.device = device + class M(torch.nn.Module): def __init__(self, bias): super().__init__() @@ -396,7 +417,9 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn) @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_input_non_contiguous_3D_wo_bias(self): + def test_linear_input_non_contiguous_3D_wo_bias(self, device): + self.device = device + # Activation is 3D, non-contiguous and without Bias class M(torch.nn.Module): def __init__(self): @@ -438,17 +461,19 @@ def forward(self, x): ) torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) - def test_linear_add_bias(self): + def test_linear_add_bias(self, device): + self.device = device + class M(torch.nn.Module): - def __init__(self, dtype, unary_fn, cast_bias): + def __init__(self, device, dtype, unary_fn, cast_bias): super().__init__() self.linear1 = torch.nn.Linear(10, 64, bias=False) - self.bias1 = torch.randn(64) + self.bias1 = torch.randn(64, device=device) self.linear2 = torch.nn.Linear(10, 64, bias=False) - self.bias2 = torch.randn(64) + self.bias2 = torch.randn(64, device=device) if cast_bias: - self.bias1 = self.bias1.to(dtype=dtype) - self.bias2 = self.bias2.to(dtype=dtype) + self.bias1 = self.bias1.to(dtype=dtype, device=device) + self.bias2 = self.bias2.to(dtype=dtype, device=device) self.unary_fn = unary_fn def forward(self, x): @@ -464,7 +489,7 @@ def forward(self, x): options = itertools.product(unary_list, dtypes) for unary_fn, dtype in options: metrics.reset() - fold_mod = M(dtype, unary_fn, cast_bias=True).eval() + fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval() v = torch.randn(2, 10) def folder_matcher_check_fn(): @@ -495,7 +520,7 @@ def folder_matcher_check_fn(): # we won't fold the bias if bias is not same dtype with weight # https://github.com/pytorch/pytorch/pull/129138 metrics.reset() - mod = M(dtype, unary_fn, cast_bias=False).eval() + mod = M(self.device, dtype, unary_fn, cast_bias=False).eval() def matcher_check_fn(): self.assertEqual( @@ -575,20 +600,22 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv_transpose2d_unary_cpu(self): + def test_conv_transpose2d_unary(self, device): + self.device = device self._test_conv_transpose_unary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv_transpose3d_unary_cpu(self): + def test_conv_transpose3d_unary(self, device): + self.device = device self._test_conv_transpose_unary_base(dim=5) def _test_conv_binary_base(self, dim=4): @@ -669,20 +696,22 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv2d_binary(self): + def test_conv2d_binary(self, device): + self.device = device self._test_conv_binary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv3d_binary(self): + def test_conv3d_binary(self, device): + self.device = device self._test_conv_binary_base(dim=5) def _test_conv_binary_broadcast_shapes_base(self, dim=4): @@ -788,7 +817,9 @@ def test_conv2d_binary_broadcast_shapes_cpu(self): def test_conv3d_binary_broadcast_shapes_cpu(self): self._test_conv_binary_broadcast_shapes_base(dim=5) - def test_linear_binary(self): + def test_linear_binary(self, device): + self.device = device + class M(torch.nn.Module): def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): super().__init__() @@ -939,7 +970,9 @@ def matcher_check_fn(): self._test_common(mod, (x1, x2), matcher_check_fn) - def test_multi_linear_share_same_input(self): + def test_multi_linear_share_same_input(self, device): + self.device = device + # llama pattern. class M(torch.nn.Module): def __init__( @@ -979,6 +1012,8 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) + +class TestPatternMatcher(TestPatternMatcherBase): def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): def __init__( @@ -1008,14 +1043,14 @@ def matcher_check_fn(): # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 18 if int8_mixed_bf16 else 12, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 3 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 ) self._test_common( @@ -1069,7 +1104,7 @@ def _qconv2d_unary_test_helper( device="cpu", int8_mixed_bf16=False, unary_op=torch.nn.ReLU(), - qconv2d_unary_matcher_nodes=None, + qconv_unary_matcher_nodes=None, ): class M(torch.nn.Module): def __init__( @@ -1098,20 +1133,20 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 2 self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) - if qconv2d_unary_matcher_nodes: + if qconv_unary_matcher_nodes: self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_nodes"], - 0 if TEST_ACL else qconv2d_unary_matcher_nodes, + counters["inductor"]["qconv_unary_matcher_nodes"], + 0 if TEST_ACL else qconv_unary_matcher_nodes, ) self._test_common( @@ -1195,7 +1230,7 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1213,7 +1248,7 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.Hardtanh(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1247,7 +1282,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=17, + qconv_unary_matcher_nodes=17, ) @skipIfNoDynamoSupport @@ -1266,7 +1301,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.Hardswish(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=17, + qconv_unary_matcher_nodes=17, ) @skipIfNoDynamoSupport @@ -1300,7 +1335,7 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1319,7 +1354,7 @@ def test_qconv2d_silu_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.SiLU(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) def _qconv2d_add_test_helper( @@ -1380,7 +1415,7 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 ) # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 self.assertEqual( @@ -1477,7 +1512,7 @@ def forward(self, x, x2, x3): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 self.assertEqual( @@ -1576,7 +1611,7 @@ def forward(self, x1, x2): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 0 self.assertEqual( @@ -1632,14 +1667,14 @@ def forward(self, x: torch.Tensor): def matcher_check_fn(): self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 3, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 4 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 4 ) self._test_common( @@ -1805,23 +1840,23 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 4 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, quantize_per_tensor] self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_nodes"], + counters["inductor"]["qconv_unary_matcher_nodes"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 1 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1 ) self._test_common( @@ -1860,16 +1895,16 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) self._test_common( @@ -1959,10 +1994,10 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 2 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor] @@ -2028,10 +2063,10 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 2 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor] @@ -2100,10 +2135,10 @@ def matcher_check_fn(): # 2. Dequant-conv pattern matched in quantization weight prepack * 3 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 12 ) # 3. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, add_3] @@ -2140,6 +2175,59 @@ def test_qconv2d_dequant_promotion_cpu(self): def test_qconv2d_dequant_promotion_xpu(self): self._test_qconv2d_dequant_promotion_helper(device="xpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv1d_relu_cpu(self): + r""" + This testcase will quantize Conv1d->ReLU pattern. + """ + device = "cpu" + unary_op = torch.nn.ReLU() + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv1d(3, 128, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.conv2 = torch.nn.Conv1d( + 128, 128, kernel_size=3, stride=1, bias=False + ) + self.unary_fn2 = copy.deepcopy(unary_op) + + def forward(self, x): + tmp = self.unary_fn(self.conv(x)) + return self.unary_fn2(self.conv2(tmp)) + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. QConv2D Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + + self._test_common( + mod, + (v,), + check_quantization=True, + matcher_check_fn=matcher_check_fn, + ) + def _qlinear_test_helper( self, inputs, @@ -3176,14 +3264,14 @@ def matcher_check_fn(): 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1, ) @@ -3275,14 +3363,14 @@ def matcher_check_fn(): counters["inductor"]["qcat_matcher_count"], 0 if TEST_ACL else 1 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) self._test_common( @@ -4119,30 +4207,42 @@ def matcher_check_fn(): self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) -# When testing kernel counts, unspecializing float causes wobbling of our tests because -# we end up reusing the same compiled region across tests. Thus we purposely specialize floats -# here since we primarily care about number of kernels generated in the absence of compile -# caching. -@dynamo_config.patch( - { - "dynamic_shapes": True, - "assume_static_by_default": False, - "specialize_float": True, - } -) -class TestDynamicPatternMatcher(TestPatternMatcherBase): - _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base - test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu - test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu - _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base - test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary - test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary - test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary +class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase): + def setUp(self): + TestCase.setUp(self) + self.ctx_stack = contextlib.ExitStack() + self.ctx_stack.enter_context( + # When testing kernel counts, unspecializing float causes wobbling of our tests because + # we end up reusing the same compiled region across tests. Thus we purposely specialize floats + # here since we primarily care about number of kernels generated in the absence of compile + # caching. + dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } + ) + ) + + def tearDown(self): + TestCase.tearDown(self) + self.ctx_stack.close() + + _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base + test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary + test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary + _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base + test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary + test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary + test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = ( - TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias + TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias ) - def test_conv_transpose2d_dynamic_shapes(self): + def test_conv_transpose2d_dynamic_shapes(self, device): + self.device = device + # We don't support conv_transpose2d for now. class M(torch.nn.Module): def __init__(self) -> None: @@ -4163,7 +4263,9 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn) - def test_multi_linear_share_same_input_dynamic(self): + def test_multi_linear_share_same_input_dynamic(self, device): + self.device = device + # llama pattern. class M(torch.nn.Module): def __init__( @@ -4206,6 +4308,15 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) + +@dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } +) +class TestDynamicPatternMatcher(TestPatternMatcherBase): @xfailIfACL def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" @@ -4238,7 +4349,7 @@ def forward(self, x): v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) if include_ops is None: include_ops = [ - "torch.ops.onednn.qconv2d_pointwise", + "torch.ops.onednn.qconv_pointwise", "torch.ops.quantized.max_pool2d", "torch.ops.onednn.qlinear_pointwise", ] @@ -4277,7 +4388,7 @@ def forward(self, x): def matcher_check_fn(): self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self._test_common( @@ -4367,8 +4478,13 @@ def matcher_check_fn(): ) +instantiate_device_type_tests( + TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") +) +instantiate_device_type_tests( + TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") +) instantiate_parametrized_tests(TestPatternMatcher) - if __name__ == "__main__": - if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): + if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available(): run_tests() diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 021ab0440492..d2b1c5c2bec2 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -132,6 +132,7 @@ def test_pointwise_polygamma(self): "chebyshev_polynomial_u", "chebyshev_polynomial_v", "chebyshev_polynomial_w", + "hermite_polynomial_h", ], ) def test_pointwise_binary_op(self, op_name): @@ -162,6 +163,7 @@ def fn(a): # Copy tests for test_name in [ "test_min_max_reduction", + "test_add_complex4", "test_add_const_int", "test_add_inplace_permuted", "test_addmm", @@ -200,13 +202,14 @@ def fn(a): "test_lgamma", "test_linear_float64", "test_log_fp64", - "test_low_memory_max_pool_dilation_1", - "test_low_memory_max_pool_dilation_2", + "test_low_memory_max_pool_dilation_1_dim_2", + "test_low_memory_max_pool_dilation_2_dim_2", "test_max_min", "test_max_pool2d2", "test_multilayer_prime_size", "test_min_max_reduction_nan", "test_nan_to_num", + "test_neg_max_uint8", "test_pow2", "test_prod", "test_randint_int64_mod", @@ -228,6 +231,7 @@ def fn(a): "test_sum_keepdims", "test_tanh", "test_vectorized_ops_masked", + "test_var_mean_tile_reduction_True", "test_view_as_complex", "test_view_on_aliased", "test_views3", diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index b94ba8ef1556..6f4e9fb876d4 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -275,7 +275,8 @@ def test_torchbind_aot_compile(self): "is_hop_single_tensor_return": None, }, }, - ] + ], + "protocol": "json", }, ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index acb7fc2e12ec..8fa9be1966b1 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -142,6 +142,10 @@ HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines +if TEST_WITH_ROCM: + torch._inductor.config.force_layout_optimization = 1 + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + aten = torch.ops.aten requires_multigpu = functools.partial( @@ -1364,6 +1368,8 @@ def fn(a, b): return c + d for dtype in [torch.complex32, torch.complex64, torch.complex128]: + if not self.is_dtype_supported(dtype): + continue x = torch.tensor( [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], dtype=dtype, @@ -1990,6 +1996,8 @@ def fn(a): return torch.max(a), torch.sum(a) # Requires masked loading for the intermediate reduction + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") sample = torch.full((3999971,), 0, dtype=torch.int64) sample[-1] = 1 self.common(fn, (sample,)) @@ -2061,7 +2069,6 @@ def fn(a): self.common(fn, (inp.view(10, -1),), rtol=1e-4, atol=1e-5, check_lowp=False) @skipCUDAIf(not SM80OrLater, "Requires sm80") - @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") @skip_if_gpu_halide # accuracy issue def test_split_cumsum_low_prec(self): if is_cpp_backend(self.device): @@ -2133,7 +2140,6 @@ def fn(a): self.common(fn, (inp,), atol=1e-5, rtol=1e-4, check_lowp=False) @skipCUDAIf(not SM80OrLater, "Requires sm80") - @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") @skip_if_gpu_halide # accuracy issue def test_split_cumprod_low_prec(self): if is_cpp_backend(self.device): @@ -2172,7 +2178,6 @@ def fn(a, b): self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -2232,7 +2237,6 @@ def fn(a, b, dim): r"triton_.*\.run\(arg[01]_1, arg[12]_1, buf1," ).check_not("run(").run(code[0]) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -2260,7 +2264,6 @@ def argmax_combine(a, b): actual = associative_scan(argmax_combine, (a, idx), 0) self.assertEqual(expect, actual) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -2491,6 +2494,8 @@ def fn(x): dtypes = torch.bool, torch.uint8, torch.int inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes] + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") for i in inps: self.common(fn, (i,), check_lowp=False) @@ -2637,6 +2642,17 @@ def fn(a, b): self.common(fn, (torch.randn(4, 4), torch.randn(4, 4))) + @skip_if_halide # different pow accuracies + @xfail_if_triton_cpu + def test_norm_constant_overflow(self): + def fn(a): + return ( + torch.norm(a, p=-41.0, dim=1), + torch.norm(a, p=-41.0, dim=0), + ) + + self.common(fn, (torch.randn(4, 1, 4),)) + @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 def test_dist_bf16(self): @@ -3781,6 +3797,9 @@ def forward(self, x): } ) def test_linear_dynamic_maxautotune(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + @torch.compile(dynamic=True) class Model(torch.nn.Module): def __init__(self) -> None: @@ -4272,34 +4291,35 @@ def fn2(a): ) @parametrize("dilation", (1, 2)) - def test_low_memory_max_pool(self, dilation: int): + @parametrize("dim", (2, 3)) + def test_low_memory_max_pool(self, dilation: int, dim: int): prims = torch.ops.prims def fn(x): - kernel_size = [3, 3] - stride = [2, 2] - padding = [1, 1] + kernel_size = [3, 3] if dim == 2 else [3, 3, 2] + stride = [2] * dim + padding = [1] * dim ceil_mode = False - vals, offsets = prims._low_memory_max_pool2d_with_offsets( + vals, offsets = prims._low_memory_max_pool_with_offsets( x, kernel_size, stride, padding, - [dilation] * 2, + [dilation] * dim, ceil_mode, ) - indices = prims._low_memory_max_pool2d_offsets_to_indices( + indices = prims._low_memory_max_pool_offsets_to_indices( offsets, - kernel_size[1], - x.size(-1), + kernel_size, + x.shape[-dim:], stride, padding, - dilation=[dilation] * 2, + dilation=[dilation] * dim, ) return vals, indices, offsets - self.common(fn, (torch.randn(1, 3, 10, 10),)) + self.common(fn, (torch.randn(1, 3, *[10] * dim),)) @xfail_if_mps def test_to_dtype(self): @@ -5355,6 +5375,19 @@ def test_embedding(self): (torch.randint(10, [2, 8]),), ) + def test_embedding_sparse(self): + # Fix https://github.com/pytorch/pytorch/issues/150656 + def fn(weight, indices): + return F.embedding(indices, weight, sparse=True) + + indices = torch.randint(10, (2, 3)) + weight = torch.randn(10, 3, requires_grad=True) + + self.common( + fn, + (weight, indices), + ) + def test_mean(self): def fn(x): return ( @@ -6565,6 +6598,7 @@ def fn(a, b): ), ) + @skip_if_halide # log2 not implemented for halide def test_log2(self): def fn(x): return torch.log2(x), torch.log2(x + 1) - 2 @@ -6583,6 +6617,7 @@ def fn(x): (torch.randn([8, 8]) + 10,), ) + @skip_if_halide # log2 not implemented for halide def test_log_fp64(self): def fn(x): return torch.log(x), torch.log2(x) @@ -10100,9 +10135,6 @@ def fn(x): for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)): self.common(fn, (x,)) - @skip_if_cpp_wrapper( - "cannot currently handle fallback ops with return types containing list[Tensor]" - ) def test_kwargs(self): if self.device == GPU_TYPE: raise unittest.SkipTest("histogramdd only supports cpu") @@ -10336,6 +10368,15 @@ def fn(arg0_1): [x], ) + @skip_if_halide # log2 not yet implemented + @skip_if_triton_cpu # log2 implemented only in Dec 2024 + def test_pow_by_natural_log2_dynamic_shapes(self): + @torch.compile(dynamic=True) + def fn(x): + return x + 2 ** (math.floor(math.log2(x.shape[0]) + 1)) + + self.common(fn, [torch.randn(5)]) + def test_setitem_with_int_parameter(self): x = torch.zeros(7, device=self.device) @@ -10446,7 +10487,6 @@ def forward(self, arg0_1, arg1_1): eager_out = eager_mod(*eager_args) self.assertEqual(inductor_out, eager_out) - @skipIfRocm def test_require_stride_expanded(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( @@ -12855,6 +12895,68 @@ def test_special_polygamma(self): self.common(fn, (1, x)) self.common(fn, (2, x)) + def test_unaligned_input(self): + def fn(x): + return torch.nn.functional.relu(x) + + x = torch.randn(1024 + 16, device=self.device)[1:-15] + self.common(fn, (x,), check_lowp=False) + + def test_unaligned_input_2d(self): + def fn(x): + return torch.nn.functional.relu(x) + + x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15] + self.common(fn, (x,), check_lowp=False) + + def test_alignment_without_custom_op(self): + def fn(x): + a = torch.nn.functional.relu(x) + b = (3 * a)[1:-15] + c = torch.cos(b) + return c + + x = torch.randn(1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + + @config.patch(implicit_fallbacks=True) + def test_no_align_for_custom_op(self): + def slice1d(x): + return (3 * x)[1:-15] + + def slice1d_meta(x): + return torch.empty_like(x)[1:-15] + + define_custom_op_for_test("slice1d", slice1d, slice1d_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.slice1d(a) + c = torch.cos(b) + return c + + x = torch.randn(1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + + @config.patch(implicit_fallbacks=True) + def test_no_align_for_custom_op_2d(self): + def slice2d(x): + return (3 * x)[..., 1:-15] + + def slice2d_meta(x): + return torch.empty_like(x)[..., 1:-15] + + define_custom_op_for_test("slice2d", slice2d, slice2d_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.slice2d(a) + c = torch.cos(b) + return c + + x = torch.randn(1024, 1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + @dataclasses.dataclass class TestFailure: diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index ee8e22193f41..29d74152bf4e 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -137,6 +137,7 @@ def run(*ex, **kwargs): "test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)), "test_flip_cat_dynamic_shapes": TestFailure(("cpu",)), "test_pad_single_dynamic_shapes": TestFailure(("cpu",)), + "test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), # # Failed to find for loop/triton kernel: # @@ -261,9 +262,6 @@ def run(*ex, **kwargs): ), "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_custom_op_3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), - "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( - ("cuda", "xpu") if IS_LINUX else ("cpu", "cuda", "xpu") - ), "test_cat_uint8_dynamic_shapes": TestFailure( ("cpu",) ), # cat on uint8 input is using aten fallback on cpu @@ -383,11 +381,12 @@ def run(*ex, **kwargs): **dynamic_shapes_test_failures, } -if TEST_WITH_ROCM: +if not TEST_WITH_ROCM: test_failures.update( { - "test_split_cumsum_low_prec_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_split_cumprod_low_prec_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( + ("cuda", "xpu") if IS_LINUX else ("cpu", "cuda", "xpu") + ), } ) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index b84640235739..ac552b312dea 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -45,6 +45,7 @@ GPU_TYPE, HAS_CPU, HAS_CUDA, + has_triton, HAS_XPU, maybe_skip_size_asserts, ) @@ -976,6 +977,8 @@ def test_comprehensive(self, device, dtype, op): "nn.functional.interpolate.bicubic", "nn.functional.upsample_bilinear", "nn.functional.upsample_nearest", + "fill", + "full_like", ): if dtype not in allowed_dtypes: raise unittest.SkipTest("Skipped!") @@ -1103,48 +1106,50 @@ def _get_tolerances(dtype): # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True, file=f) # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True) rtol, atol = _get_tolerances(dtype) - if device_type == GPU_TYPE: - # opinfo test case have already place the input on the correct device - # so we don't need do additional copy by setting copy_to_gpu=False - - no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) - for context_fn, kwarg_overrides in get_contexts(has_rng_op): - with context_fn(): - adjusted_kwargs = { - "check_lowp": False, - "nopython": no_python, - "copy_to_gpu": False, - "reference_in_float": False, - "check_gradient": requires_grad, - "check_has_compiled": no_python, - "output_process_fn_grad": sample_input.output_process_fn_grad, - "atol": atol, - "rtol": rtol, - } - adjusted_kwargs.update(overridden_kwargs) - adjusted_kwargs.update(kwarg_overrides) + no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) + for context_fn, kwarg_overrides in get_contexts(has_rng_op): + with context_fn(): + # Base kwargs + adjusted_kwargs = { + "check_lowp": False, + "nopython": no_python, + "check_has_compiled": no_python, + "atol": atol, + "rtol": rtol, + } + + # Backend-specific adjustments + # Triton + if has_triton(): + adjusted_kwargs.update( + { + "copy_to_gpu": False, + "reference_in_float": False, + "check_gradient": requires_grad, + "output_process_fn_grad": sample_input.output_process_fn_grad, + } + ) + # C++ CPU backend + elif torch._inductor.config.cpu_backend == "cpp": + adjusted_kwargs.update( + { + "check_gradient": False, # Skip checking gradient on CPU for now + } + ) + + # Update with overridden kwargs and context-specific overrides + adjusted_kwargs.update(overridden_kwargs) + adjusted_kwargs.update(kwarg_overrides) + + # Call the appropriate check method based on device type + if device_type == GPU_TYPE: self.check_model_gpu( fn, args, kwargs, **adjusted_kwargs, ) - elif device_type == "cpu": - no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) - for context_fn, kwarg_overrides in get_contexts(has_rng_op): - with context_fn(): - adjusted_kwargs = { - "check_lowp": False, - "nopython": no_python, - "check_has_compiled": no_python, - # skip checking gradient on CPU for now - "check_gradient": False, - "atol": atol, - "rtol": rtol, - } - adjusted_kwargs.update(overridden_kwargs) - adjusted_kwargs.update(kwarg_overrides) - + else: self.check_model( fn, args, diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 895c536ed326..ec6c3dc8a578 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -931,6 +931,40 @@ def foo(x, y, z): # Check for 3D tiling self.assertIn("ZBLOCK", code) + @torch._dynamo.config.patch({"capture_scalar_outputs": True}) + @parametrize("num_tile_candidates", (1, 2)) + def test_unbacked_size_on_non_contig_dim(self, num_tile_candidates: int): + # NUM_REPEAT should determine # of candidate_tilings. + NUM_REPEAT = 2 if num_tile_candidates == 2 else 8 + + def foo(x, length): + unbacked = length.item() + torch._check_is_size(unbacked) + + repeated = x.repeat(1, unbacked, NUM_REPEAT) + # permute creates split in middle with unbacked symint is the first range + # ranges: [33*unbacked, NUM_REPEAT, 64] + permute120 = repeated.permute([1, 2, 0]) + return permute120.cos() + + inps = ( + torch.rand((64, 33, 1), device=self.device, dtype=torch.float32), + torch.scalar_tensor(16, device=self.device, dtype=torch.int32), + ) + + with torch._dynamo.config.patch({"capture_scalar_outputs": True}): + run_and_compare( + self, + foo, + *inps, + expected_num_triton_kernels=1, + expected_num_block_pointers=0, + config_patches={ + "triton.max_tiles": 3, + "triton.prefer_nd_tiling": True, + }, + ) + # block_ptr advancements should also be deferrered conditional # on the associated buffer not being removed # in this case the bernoulli operation is fused with the following sum diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 3f495042a392..4966821120c5 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -25,13 +25,7 @@ from torch._library import capture_triton from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import ( - parametrize, - skipIfRocm, - skipIfWindows, - skipIfXpu, - TEST_WITH_ROCM, -) +from torch.testing._internal.common_utils import parametrize, skipIfWindows, skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU from torch.testing._internal.logging_utils import log_settings, logs_to_string @@ -44,23 +38,22 @@ import triton from triton import language as tl - if not TEST_WITH_ROCM: - if HAS_CUDA: - try: - from triton.language.extra.libdevice import ( # @manual - fast_dividef, - fast_dividef as my_fast_dividef, - ) - except ImportError: - from triton.language.extra.cuda.libdevice import ( # @manual - fast_dividef, - fast_dividef as my_fast_dividef, - ) - elif HAS_XPU: - from triton.language.extra.intel.libdevice import ( # @manual + if HAS_CUDA: + try: + from triton.language.extra.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) + except ImportError: + from triton.language.extra.cuda.libdevice import ( # @manual + fast_dividef, + fast_dividef as my_fast_dividef, + ) + elif HAS_XPU: + from triton.language.extra.intel.libdevice import ( # @manual + fast_dividef, + fast_dividef as my_fast_dividef, + ) def _triton_get_ast_equal_to_str(params): try: @@ -1341,7 +1334,6 @@ def f(x, y): self.assertEqual(compiled_out, eager_out) @requires_gpu - @skipIfRocm def test_triton_kernel_with_imported_symbol(self): @triton.jit def add_kernel_with_imported_symbol( @@ -1373,7 +1365,6 @@ def f(x): self.assertEqual(compiled_out, eager_out) @requires_gpu - @skipIfRocm def test_triton_kernel_with_imported_symbol_with_custom_name(self): @triton.jit def add_kernel_with_imported_symbol( @@ -3385,7 +3376,10 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertEqual(z, (x + y) * 2) @requires_gpu - def test_preserves_strides(self): + @common_utils.parametrize( + "variant", ["triton_kernel", "custom_op", "mutable_custom_op"] + ) + def test_preserves_strides(self, variant): import triton import triton.language as tl @@ -3409,12 +3403,10 @@ def add_kernel( x = torch.randn(4, 4, 2, 2, device=GPU_TYPE) other = torch.randn(4, 4, 2, 2, device=GPU_TYPE) - def f(x, other): - y = x.transpose(2, 3).contiguous().transpose(2, 3) - z = y.sin().transpose(2, 3) + def add_triton(y, z): grid = (z.numel(),) - out = torch.empty_like(other) - add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16) + out = torch.empty_like(z, memory_format=torch.contiguous_format) + add_kernel[grid](y, z, out, z.numel(), BLOCK_SIZE=16) return out class _CustomPass(PatternMatcherPass): @@ -3436,8 +3428,8 @@ def _(match, *args, **kwargs): def decomp(*flat_args): args, kwargs = pytree.tree_unflatten(flat_args, spec) - return torch.ops.aten.permute(*args, **kwargs).clone( - memory_format=torch.channels_last + return torch.ops.mylib.force_channels_last( + torch.ops.aten.permute(*args, **kwargs) ) nonlocal called @@ -3446,12 +3438,63 @@ def decomp(*flat_args): from torch._inductor import config - with config.patch( - post_grad_custom_post_pass=g, - ): - f_compile = torch.compile(f) - self.assertEqual(f(x, other), f_compile(x, other)) - self.assertTrue(called) + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define( + "force_channels_last(Tensor x) -> Tensor", + tags=[torch._C.Tag.flexible_layout], + ) + + def impl2(x): + return x.clone(memory_format=torch.channels_last) + + lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd") + + lib.define( + "add_op(Tensor x, Tensor y) -> Tensor", + tags=[torch._C.Tag.needs_exact_strides], + ) + + def impl(x, y): + return add_triton(x, y) + + def meta(x, y): + return torch.empty_like(y, memory_format=torch.contiguous_format) + + lib.impl("add_op", impl, "CompositeExplicitAutograd") + lib.impl("add_op", meta, "Meta") + + lib.define( + "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()", + tags=[torch._C.Tag.needs_exact_strides], + ) + + def impl_out(x, y, out): + grid = (y.numel(),) + add_kernel[grid](x, y, out, y.numel(), BLOCK_SIZE=16) + + lib.impl("add_out_op", impl_out, "CompositeExplicitAutograd") + lib.impl("add_out_op", lambda x, y, out: None, "Meta") + + def f(x, other): + y = x.transpose(2, 3).contiguous().transpose(2, 3) + z = y.sin().transpose(2, 3) + if variant == "triton_kernel": + return add_triton(y, z) + elif variant == "custom_op": + return torch.ops.mylib.add_op.default(y, z) + elif variant == "mutable_custom_op": + out = torch.empty_like(y, memory_format=torch.contiguous_format) + torch.ops.mylib.add_out_op(y, z, out) + return out + else: + raise AssertionError("should not be hit") + + with config.patch( + post_grad_custom_post_pass=g, + ): + f_compile = torch.compile(f, fullgraph=True) + self.assertEqual(f(x, other), f_compile(x, other)) + self.assertTrue(called) @requires_gpu @common_utils.parametrize("dynamic", [False, True]) diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index a4dc1c97772d..3ebf00eccec0 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -246,6 +246,31 @@ def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(se ) ) + def test_upgraded_torchlib_impl(self): + class GeluModel(torch.nn.Module): + def forward(self, input): + # Use GELU activation function + return torch.nn.functional.gelu(input, approximate="tanh") + + input = torch.randn(1, 3, 4, 4) + onnx_program_op18 = torch.onnx.export( + GeluModel(), + input, + dynamo=True, + ) + all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] + self.assertIn("Tanh", all_nodes_op18) + self.assertNotIn("Gelu", all_nodes_op18) + + onnx_program_op20 = torch.onnx.export( + GeluModel(), + input, + opset_version=20, + dynamo=True, + ) + all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph] + self.assertIn("Gelu", all_nodes_op20) + def test_refine_dynamic_shapes_with_onnx_export(self): # NOTE: From test/export/test_export.py diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py index 73c00de388fa..884b66d4e02f 100644 --- a/test/onnx/torchlib/ops_test_common.py +++ b/test/onnx/torchlib/ops_test_common.py @@ -52,6 +52,7 @@ torch.float64, ) + TEST_OPSET_VERSION = 18 IS_MACOS = sys.platform.startswith("darwin") IS_WINDOWS = os.name == "nt" @@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - def graph_executor( test_name: str, outputs: Sequence[Any], + opset_version: int = TEST_OPSET_VERSION, ) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]: """Eagerly executes a function.""" @@ -500,10 +502,10 @@ def _capture_graph_and_evaluate_torch_script_evaluator( (), (), nodes=(), - opset_imports={"": 18, "pkg.torch.onnx": 1}, + opset_imports={"": opset_version, "pkg.torch.onnx": 1}, name="main_graph", ) - opset = onnxscript.opset18 + opset = onnxscript.values.Opset("", opset_version) tracer = _building.OpRecorder(opset, {}) ort_inputs = {} onnxscript_args: list[Any] = [] @@ -590,7 +592,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator( proto = onnxscript_function.to_function_proto() ir_function = ir.serde.deserialize_function(proto) onnx_model.functions[identifier] = ir_function - _ir_passes.add_torchlib_common_imports(onnx_model) + _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version) _ir_passes.add_opset_imports(onnx_model) # Make sure the model is valid model_proto = ir.to_proto(onnx_model) diff --git a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py index b255f07640b8..a69d7a4bec1e 100644 --- a/test/onnx/torchlib/ops_test_data.py +++ b/test/onnx/torchlib/ops_test_data.py @@ -46,7 +46,7 @@ import ops_test_common import torch -from torch.onnx._internal.exporter._torchlib.ops import core as core_ops +from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops from torch.testing._internal import common_methods_invocations from torch.testing._internal.opinfo import definitions as opinfo_definitions @@ -78,6 +78,12 @@ class TorchLibOpInfo: compare_shape_only_for_output: tuple[int, ...] = () # Whether the function is designed for complex inputs complex: bool = False + # The ONNX opset version in which the function was introduced. + # Its specifies the minimum ONNX opset version required to use the function. + # It ensures that the function is only used when the target ONNX opset version + # is compatible. For example, if `opset_introduced=20`, the function will only + # be used when exporting to ONNX models targeting opset version 20 or higher. + opset_introduced: int = 18 # The acceptable tolerance of the inference result difference between PyTorch and ORT. # Format: {dtype: (rtol, atol)}. # For example: {torch.float16: (1e-3, 1e-3)} @@ -447,8 +453,10 @@ def _where_input_wrangler( TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True), TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), + TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20), ) + ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims")) ops_test_common.duplicate_opinfo( @@ -500,6 +508,7 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.scaled_dot_product_attention", diff --git a/test/onnx/torchlib/test_ops.py b/test/onnx/torchlib/test_ops.py index 74cbeeca3138..a7a52698cd23 100644 --- a/test/onnx/torchlib/test_ops.py +++ b/test/onnx/torchlib/test_ops.py @@ -220,7 +220,9 @@ def run_test_output_match( test_name = test_suite.id() function_output, model_proto = function_executor( - test_name, reference_torch_outputs + test_name, + reference_torch_outputs, + opset_version=torchlib_op_info.opset_introduced, )(onnx_function, input_onnx, kwargs_onnx) # Finally we re-flatten everything # TODO: add pytree structure comparison. diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 7dac5fb70905..48bbbf01727f 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -690,6 +690,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ...""", ) + @skipIfTorchDynamo("segfaults in 3.13+") @unittest.skipIf( TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." ) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 33c0c932ea05..070f341faf13 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -7010,7 +7010,7 @@ def test_qconv2d_pt2e(self): if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise): # Remove some test combination to reduce UT test time continue - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7123,7 +7123,7 @@ def test_qconv2d_relu_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7174,7 +7174,7 @@ def test_qconv2d_hardtanh_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7225,7 +7225,7 @@ def test_qconv2d_silu_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7277,7 +7277,7 @@ def test_qconv2d_hardswish_pt2e(self): options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7480,6 +7480,58 @@ def test_qconv2d_sum_relu_float_output_pt2e(self): qconv_x2_dtype=qconv_x2_dtype, ) + # Test qconv1d with post op relu + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv1d_relu_pt2e(self): + input_channels_per_group = 2 + output_channels_per_group = 2 + groups_list = [1, 10] + input_feature_map_shape = (10,) + kernels = (3,) + strides = (2,) + pads = (1,) + dilations = (1,) + W_scale = [1.5] + W_zero_point = [0] + use_bias_list = [False, True] + use_channelwise_list = [False, True] + output_dtype_list = [None, torch.float32, torch.bfloat16] + options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) + for groups, use_bias, use_channelwise, output_dtype in options: + qconv = torch.ops.onednn.qconv_pointwise + qconv_prepack = torch.ops.onednn.qconv_prepack + conv_op = torch.nn.Conv1d( + input_channels_per_group * groups, + output_channels_per_group * groups, + kernels, + strides, + pads, + dilations, + groups, + ) + pointwise_post_op = PointwisePostOp(unary_attr="relu") + self._test_qconv_impl_cpu_tensor( + qconv, + qconv_prepack, + conv_op, + input_channels_per_group=input_channels_per_group, + input_feature_map_shape=input_feature_map_shape, + output_channels_per_group=output_channels_per_group, + groups=groups, + kernels=kernels, + strides=strides, + pads=pads, + dilations=dilations, + W_scale=W_scale, + W_zero_point=W_zero_point, + use_bias=use_bias, + post_op=pointwise_post_op, + use_channelwise=use_channelwise, + qconv_output_dtype=output_dtype, + ) + + class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), channels=st.integers(1, 64), diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 54456ab37b15..4a5cb6edaeb6 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -101,10 +101,7 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 4f6eb4f56d3a..96eff3a789f2 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -98,10 +98,7 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch.export.export_for_training( - m, - example_inputs, - ).module() + m = torch.export.export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index b5ada0cc3d59..deff8e4987e5 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -81,7 +81,7 @@ def _extract_debug_handles_with_prev_decomp_op_from_node(node): def test_simple(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) debug_handle_map = self._extract_debug_handles(ep) @@ -91,7 +91,7 @@ def test_simple(self): def test_control_flow(self): m = TestHelperModules.ControlFlow() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -102,7 +102,7 @@ def test_control_flow(self): def test_quantize_pt2e_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() @@ -162,14 +162,14 @@ def test_deepcopy_preserve_handle(self): def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() self._assert_each_node_has_debug_handle(ep) debug_handle_map_ref = self._extract_debug_handles(ep) - ep_reexport = export_for_training(m, example_inputs) + ep_reexport = export_for_training(m, example_inputs, strict=True) self._assert_each_node_has_debug_handle(ep_reexport) debug_handle_map = self._extract_debug_handles(ep_reexport) @@ -179,7 +179,7 @@ def test_re_export_preserve_handle(self): def test_run_decompositions_same_handle_id(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -204,7 +204,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): for m in test_models: example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -227,7 +227,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() m_logger = prepare_for_propagation_comparison(m) @@ -244,7 +244,7 @@ def test_prepare_for_propagation_comparison(self): def test_extract_results_from_loggers(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() m_ref_logger = prepare_for_propagation_comparison(m) @@ -269,7 +269,7 @@ def test_extract_results_from_loggers(self): def test_extract_results_from_loggers_list_output(self): m = TestHelperModules.Conv2dWithSplit() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() m_ref_logger = prepare_for_propagation_comparison(m) @@ -299,7 +299,7 @@ def test_extract_results_from_loggers_list_output(self): def test_added_node_gets_unique_id(self) -> None: m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) ref_handles = self._extract_debug_handles(ep) ref_counter = Counter(ref_handles.values()) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 2bc87f72fc25..87ac89fe852c 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -767,10 +767,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, BackendAQuantizer()) # make sure the two observers for input are shared conv_output_obs = [] @@ -830,10 +827,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer): ) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) # make sure the two input observers and output are shared @@ -1152,10 +1146,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = BackendAQuantizer() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -1305,7 +1296,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = export_for_training(m, (example_inputs,)).module() + m = export_for_training(m, (example_inputs,), strict=True).module() with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1428,10 +1419,7 @@ def forward(self, x): quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() weight_meta = None for n in m.graph.nodes: if ( @@ -1518,7 +1506,7 @@ def forward(self, x): m = M().eval() quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1569,7 +1557,7 @@ def forward(self, x, y, z): torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3), ) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1824,7 +1812,7 @@ def forward(self, x): example_inputs = (torch.randn(1),) m = M().train() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() if inplace: target = torch.ops.aten.dropout_.default else: @@ -1889,7 +1877,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() # Assert that batch norm op exists and is in train mode bn_node = self._get_node(m, bn_train_op) @@ -1920,7 +1908,7 @@ def test_disallow_eval_train(self): m.train() # After export: this is not OK - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -1961,7 +1949,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): targets = [n.target for n in m.graph.nodes] @@ -2027,7 +2015,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() torch.ao.quantization.allow_exported_model_train_eval(m) # Mock m.recompile() to count how many times it's been called @@ -2059,7 +2047,7 @@ def _fake_recompile(): def test_model_is_exported(self): m = TestHelperModules.ConvWithBNRelu(relu=True) example_inputs = (torch.rand(3, 3, 5, 5),) - exported_gm = export_for_training(m, example_inputs).module() + exported_gm = export_for_training(m, example_inputs, strict=True).module() fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs) self.assertTrue( torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm) @@ -2077,7 +2065,9 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True, is_qat=True) ) - m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs, strict=True + ).module() m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) m(*example_inputs) m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) @@ -2085,7 +2075,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_module_type( torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) ) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) @@ -2257,7 +2247,7 @@ def test_speed(self): def dynamic_quantize_pt2e(model, example_inputs): torch._dynamo.reset() - model = export_for_training(model, example_inputs).module() + model = export_for_training(model, example_inputs, strict=True).module() # Per channel quantization for weight # Dynamic quantization for activation # Please read a detail: https://fburl.com/code/30zds51q @@ -2360,7 +2350,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 3, 5, 5),) m = M() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(), ) @@ -2442,7 +2432,7 @@ def prepare_obs_or_fq_callback( edge_or_node_to_obs_or_fq[x] = new_observer example_inputs = (torch.rand(1, 32, 16, 16),) - gm = export_for_training(Model().eval(), example_inputs).module() + gm = export_for_training(Model().eval(), example_inputs, strict=True).module() gm = prepare_pt2e(gm, BackendAQuantizer()) gm = convert_pt2e(gm) for n in gm.graph.nodes: @@ -2469,7 +2459,9 @@ def check_nn_module(node): "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] ) - m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs, strict=True + ).module() for node in m.conv_bn_relu.graph.nodes: if node.op not in ["placeholder", "output", "get_attr"]: check_nn_module(node) @@ -2562,5 +2554,188 @@ def forward(self, x): is_debug_mode=True, ) + def test_dynamic_affine_act_per_channel_weights(self): + import operator + + from torch.ao.quantization.observer import ( + MappingType, + PerChannelMinMaxObserver, + PerToken, + ) + from torch.ao.quantization.pt2e._affine_quantization import ( + AffineQuantizedMovingAverageMinMaxObserver, + ) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + activation_dtype = torch.int8 + act_qspec = QuantizationSpec( + dtype=activation_dtype, + quant_min=-128, + quant_max=127, + qscheme=None, + is_dynamic=True, + observer_or_fake_quant_ctr=AffineQuantizedMovingAverageMinMaxObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=activation_dtype, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + averaging_constant=1, + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.pt2e_quant.choose_qparams_affine: 1, + operator.getitem: 2, + torch.ops.pt2e_quant.quantize_affine: 1, + torch.ops.pt2e_quant.dequantize_affine: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.choose_qparams_affine, + operator.getitem, + torch.ops.pt2e_quant.quantize_affine, + torch.ops.pt2e_quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + + def test_dynamic_per_tok_act_per_group_weights(self): + import operator + + from torch.ao.quantization.observer import MappingType, PerGroup, PerToken + from torch.ao.quantization.pt2e._affine_quantization import ( + AffineQuantizedMinMaxObserver, + AffineQuantizedPlaceholderObserver, + ) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + activation_dtype = torch.int8 + act_qspec = QuantizationSpec( + dtype=activation_dtype, + quant_min=-128, + quant_max=127, + qscheme=None, + is_dynamic=True, + observer_or_fake_quant_ctr=AffineQuantizedPlaceholderObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=activation_dtype, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + target_dtype=torch.int8, + mapping_type=MappingType.SYMMETRIC, + granularity=PerGroup(group_size=128), + ), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.pt2e_quant.choose_qparams_affine: 1, + operator.getitem: 2, + torch.ops.pt2e_quant.quantize_affine: 1, + torch.ops.pt2e_quant.dequantize_affine: 2, + } + node_list = [ + torch.ops.pt2e_quant.dequantize_affine, + torch.ops.pt2e_quant.choose_qparams_affine, + operator.getitem, + torch.ops.pt2e_quant.quantize_affine, + torch.ops.pt2e_quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index abc9849aee82..b52f34c68c5b 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -140,8 +140,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper( ) ) model_pt2e = export_for_training( - model_pt2e, - example_inputs, + model_pt2e, example_inputs, strict=True ).module() model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) torch.manual_seed(MANUAL_SEED) @@ -229,10 +228,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( quantizer.set_global( get_symmetric_quantization_config(is_per_channel, is_qat=True) ) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -621,7 +617,7 @@ def forward(self, x): m = M(self.conv_class, self.bn_class, backbone) quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) @@ -679,7 +675,7 @@ def get_source_fn(node: torch.fx.Node): def test_qat_conv_bn_bias_derived_qspec(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -726,7 +722,7 @@ def test_qat_conv_bn_bias_derived_qspec(self): def test_qat_per_channel_weight_custom_dtype(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnInt32WeightQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -780,7 +776,7 @@ def test_qat_conv_transpose_bn_relu(self): def test_qat_conv_bn_per_channel_weight_bias(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -837,7 +833,7 @@ def test_fold_bn_erases_bn_node(self): it into conv in `convert_pt2e` even in train mode. """ m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = export_for_training(m, self.example_inputs).module() + m = export_for_training(m, self.example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True), @@ -1085,7 +1081,9 @@ def _prepare_qat_linears(self, model): in_channels = child.linear1.weight.size(1) example_input = (torch.rand((1, in_channels)),) - traced_child = export_for_training(child, example_input).module() + traced_child = export_for_training( + child, example_input, strict=True + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=True, is_qat=True @@ -1116,10 +1114,7 @@ def test_mixing_qat_ptq(self): self._convert_qat_linears(model) model(*example_inputs) - model_pt2e = export_for_training( - model, - example_inputs, - ).module() + model_pt2e = export_for_training(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_module_type(torch.nn.Linear, None) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index c6eed1ed8260..3648ac352dc4 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -33,10 +33,7 @@ def _test_representation( ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training( - model, - example_inputs, - ).module() + model = export_for_training(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 31cecf9adeda..e0fcbbc9b515 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -7,6 +7,7 @@ import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.nn as nn from torch.ao.quantization import ObserverBase +from torch.ao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86 from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -551,6 +552,102 @@ def forward(self, x): y = torch.cat([y, y], dim=-1) return y.transpose(1, 2) + class MiniResNet(nn.Module): + class BasicBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, downsample=None): + super().__init__() + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(out_channels) + self.downsample = downsample + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + def __init__(self, num_classes=10): + super().__init__() + self.in_channels = 16 + self.conv1 = nn.Conv2d( + 3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(self.in_channels) + self.relu = nn.ReLU() + self.layer1 = self._make_layer(16, 1) + self.layer2 = self._make_layer(32, 1, stride=2) + self.layer3 = self._make_layer(64, 1, stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64, num_classes) + + def _make_layer(self, out_channels, blocks, stride=1): + downsample = None + if stride != 1 or self.in_channels != out_channels: + downsample = nn.Sequential( + nn.Conv2d( + self.in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(out_channels), + ) + + layers = [] + layers.append( + self.BasicBlock(self.in_channels, out_channels, stride, downsample) + ) + self.in_channels = out_channels + for _ in range(1, blocks): + layers.append(self.BasicBlock(self.in_channels, out_channels)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + class X86InductorQuantTestCase(QuantizationTestCase): def _test_quantizer( @@ -562,15 +659,13 @@ def _test_quantizer( expected_node_list=None, is_qat=False, debug=False, + lower=False, ): m_eager = model.train() if is_qat else model.eval() # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -582,6 +677,8 @@ def _test_quantizer( convert_model = copy.deepcopy(m) if debug: convert_model.print_readable(True) + if lower: + m = lower_pt2e_quantized_to_x86(m, example_inputs) m(*example_inputs) node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() @@ -2244,7 +2341,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. @@ -2732,3 +2829,32 @@ def test_attention_block(self): node_occurrence, node_list, ) + + @skipIfNoX86 + def test_lowering_to_x86(self): + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.MiniResNet().eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.onednn.qconv_pointwise.default: 6, + torch.ops.onednn.qconv2d_pointwise.binary: 3, + torch.ops.onednn.qlinear_pointwise.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + lower=True, + ) diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 36209e5aad10..4e14dfd27ae2 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -361,7 +361,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. @@ -497,10 +497,7 @@ def test_propagate_annotation(self): example_inputs = (torch.randn(1, 3, 5, 5),) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -766,8 +763,7 @@ def forward(self, input_tensor, hidden_tensor): with torchdynamo.config.patch(allow_rnn=True): model_graph = export_for_training( - model_graph, - example_inputs, + model_graph, example_inputs, strict=True ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( @@ -829,8 +825,7 @@ def forward(self, input_tensor, hidden_tensor): with torchdynamo.config.patch(allow_rnn=True): model_graph = export_for_training( - model_graph, - example_inputs, + model_graph, example_inputs, strict=True ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( @@ -1039,10 +1034,7 @@ def test_resnet18(self): m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/test/run_test.py b/test/run_test.py index efa7e46554cb..d341a182e29b 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -171,19 +171,10 @@ def __contains__(self, item): "distributed/rpc/test_tensorpipe_agent", "distributed/rpc/test_share_memory", "distributed/rpc/cuda/test_tensorpipe_agent", - "distributed/_shard/checkpoint/test_checkpoint" - "distributed/_shard/checkpoint/test_file_system_checkpoint" - "distributed/_shard/sharding_spec/test_sharding_spec", - "distributed/_shard/sharded_tensor/ops/test_embedding", - "distributed/_shard/sharded_tensor/ops/test_embedding_bag", - "distributed/_shard/sharded_tensor/ops/test_binary_cmp", - "distributed/_shard/sharded_tensor/ops/test_init", - "distributed/_shard/sharded_optim/test_sharded_optim", "test_determination", "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", - "distributed/tensor/test_attention", ] S390X_BLOCKLIST = [ diff --git a/test/slow_tests.json b/test/slow_tests.json index 7434d944c2d0..bbda0f96278f 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,305 +1,295 @@ { - "EndToEndLSTM (__main__.RNNTest)": 187.95632934570312, - "MultiheadAttention (__main__.ModulesTest)": 137.24066670735678, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 216.9356689453125, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 159.3027776082357, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 112.10600026448567, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 65.38766564263238, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 173.56966654459634, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 75.28399658203125, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 163.19466654459634, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 88.1193339029948, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.00295284816197, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 83.75133260091145, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 691.2717827690972, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 117.44299926757813, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 503.3826666937934, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 503.24066840277777, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 126.52850087483723, - "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 61.86766688028971, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 163.50066630045572, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 97.42933400472005, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 308.0576663547092, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 134.46916961669922, - "test_builtin_equivalent_funcs (__main__.TorchFunctionModeTests)": 81.6673030275287, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 355.91133287217883, - "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 73.32400004069011, - "test_comprehensive_constant_pad_nd_cpu_float32 (__main__.TestInductorOpInfoCPU)": 70.80933125813802, - "test_comprehensive_constant_pad_nd_cpu_float64 (__main__.TestInductorOpInfoCPU)": 70.98533376057942, - "test_comprehensive_constant_pad_nd_cpu_int32 (__main__.TestInductorOpInfoCPU)": 67.57033284505208, - "test_comprehensive_constant_pad_nd_cpu_int64 (__main__.TestInductorOpInfoCPU)": 70.75233205159505, - "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 102.2750015258789, - "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 103.07066599527995, - "test_comprehensive_diff_cpu_float64 (__main__.TestInductorOpInfoCPU)": 105.27833557128906, - "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 100.10233561197917, - "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 102.20266977945964, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 93.59800084431966, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 93.51633326212566, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 62.04499944051107, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 63.05183347066244, - "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 86.4076639811198, - "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 81.19499969482422, - "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 86.38233439127605, - "test_comprehensive_eye_cpu_bool (__main__.TestInductorOpInfoCPU)": 124.90833536783855, - "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 123.35333251953125, - "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 121.35933430989583, - "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 123.5403340657552, - "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 120.98033396402995, - "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 124.76566823323567, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 71.77733357747395, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 83.0576655069987, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 83.4250005086263, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 353.14801025390625, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 79.26999918619792, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 329.7780049641927, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 80.16866556803386, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 273.2213312784831, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 249.29500325520834, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 988.9061686197916, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.60549990336101, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1203.5001627604167, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.39716657002766, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.78449948628743, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.64166704813639, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 89.50603711163556, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.10983276367188, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 61.83733304341634, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 203.87232971191406, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 203.09432983398438, - "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 199.30699666341147, - "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.42596266004774, - "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 67.53049977620442, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 423.8486633300781, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 425.5379943847656, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 403.22300211588544, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 409.60033162434894, - "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 93.37733459472656, - "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 99.49733225504558, - "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 94.82899983723958, - "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 89.32633209228516, - "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 90.41433207194011, - "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 95.9903335571289, - "test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 95.3953348795573, - "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 93.07833607991536, - "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 89.55566660563152, - "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 86.22466786702473, - "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 94.80033111572266, - "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 93.42666625976562, - "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 93.45800018310547, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 466.69366455078125, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 464.84532674153644, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 468.4709981282552, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 125.94750086466472, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 120.40383402506511, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 133.67750295003256, - "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 90.84866841634114, - "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 96.20899963378906, - "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 90.58700052897136, - "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 99.9510014851888, - "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 94.47566731770833, - "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 89.86966705322266, - "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 89.43766530354817, - "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 97.86233266194661, - "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 87.95466613769531, - "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 90.6480000813802, - "test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 91.357666015625, - "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 94.107666015625, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 92.32383346557617, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.00616836547852, - "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 71.70499928792317, - "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 72.04166666666667, - "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 74.28933461507161, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 87.73799896240234, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 81.04799906412761, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 242.09933217366537, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 256.25333404541016, - "test_comprehensive_nn_functional_interpolate_bicubic_cpu_uint8 (__main__.TestInductorOpInfoCPU)": 60.534000396728516, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 75.61316553751628, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 76.84416834513347, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 188.01399739583334, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 186.28333536783853, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 185.177001953125, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 974.4946695963541, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 874.4259847005209, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 882.1919962565104, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 836.5886433919271, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 833.1363525390625, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 848.4001770019531, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 855.3283386230469, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 863.9473368326823, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 198.97533671061197, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 199.50466918945312, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 204.54600524902344, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 72.47933260599773, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 126.71599833170573, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 128.18866729736328, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 125.28499857584636, - "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 68.7433344523112, - "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 69.4153315226237, - "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 69.83100128173828, - "test_comprehensive_nn_functional_pad_constant_cpu_int32 (__main__.TestInductorOpInfoCPU)": 67.97833251953125, - "test_comprehensive_nn_functional_pad_constant_cpu_int64 (__main__.TestInductorOpInfoCPU)": 68.58200073242188, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 123.47900136311848, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.12900034586589, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 118.65166473388672, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 115.42100016276042, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 111.11299896240234, - "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 131.9026641845703, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 229.06666564941406, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 230.85599772135416, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 229.9073282877604, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 115.78150049845378, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 109.48800150553386, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 62.58650016784668, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.28583272298177, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.01150004069011, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.31016667683919, - "test_cond_autograd_nested (__main__.TestControlFlow)": 108.9411112467448, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 100.8696657816569, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 91.36616770426433, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 80.7226676940918, - "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 70.30566660563152, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 224.8618867662218, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 521.492443508572, - "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 73.4326680501302, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.06488927205403, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 95.46499888102214, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 83.36849975585938, - "test_count_nonzero_all (__main__.TestBool)": 621.7835659450955, - "test_cusparse_multiple_threads_same_device (__main__.TestCuda)": 88.83855459425185, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 782.4322068956163, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 83.058167775472, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 95.14833323160808, - "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 243.97850879033408, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 66.31777699788411, - "test_fail_creation_ops.py (__main__.TestTyping)": 73.5800605542732, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 80.3489990234375, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 84.08483378092448, - "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 76.93700154622395, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 492.0260009765625, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 327.5421651204427, - "test_forward_ad_svd_lowrank_cpu_float32 (__main__.TestCompositeComplianceCPU)": 68.51100158691406, - "test_fuse_large_params_cpu (__main__.CpuTests)": 78.46166653103299, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 160.96700032552084, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 166.3767784966363, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 85.631165822347, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 99.18250020345052, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 91.73133341471355, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 108.87999979654948, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 118.38499959309895, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 203.54966990152994, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 123.6168327331543, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 140.19833119710287, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 576.1204986572266, - "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 194.1616668701172, - "test_group_norm (__main__.TestQuantizedOps)": 240.9851115544637, - "test_indexing (__main__.TestAutogradWithCompiledAutograd)": 88.37566757202148, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 261.59466552734375, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.35699971516927, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 98.20366668701172, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 133.5326656765408, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.19766489664714, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 597.1708386739095, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 77.99583435058594, - "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 95.48333422342937, - "test_linear (__main__.TestStaticQuantizedModule)": 201.3015539381239, - "test_linear_relu (__main__.TestStaticQuantizedModule)": 198.11822424994574, - "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 111.03733523686726, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 70.34333419799805, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 114.91833411322699, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 577.1563313802084, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 135.72266642252603, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 452.1196695963542, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.14066653781467, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.6434457567003, - "test_memory_format_operators_cpu (__main__.TestTorchDeviceTypeCPU)": 74.79505585485862, - "test_nccl_non_blocking_wait_with_barrier (__main__.NcclErrorHandlingTest)": 69.80233256022136, - "test_proper_exit (__main__.TestDataLoader)": 229.3759969075521, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 254.74083709716797, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 65.72250080108643, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 153.51377783881293, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 60.370178349812825, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 93.32955551147461, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 126.44500223795573, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 87.6626688639323, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 93.46333312988281, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 125.66800181070964, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.86966705322266, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 94.73033396402995, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 123.07366689046223, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.68800099690755, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.31833394368489, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 135.38099670410156, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 93.70433298746745, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.45233154296875, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 132.78799947102866, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 89.87099965413411, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.96466827392578, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 128.1060002644857, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 92.87266794840495, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 95.6653340657552, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 137.1143341064453, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.06833394368489, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.31199900309245, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 125.58800252278645, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 92.6046651204427, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.97366587320964, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 328.73166910807294, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 796.4181518554688, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 544.1849975585938, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1095.7953186035156, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 61.016167958577476, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 72.16200129191081, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 203.41483052571616, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 87.03033192952473, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 174.36033376057944, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 61.354000091552734, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 111.04199981689453, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 75.12533315022786, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 132.99366505940756, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 91.84250005086263, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 119.75666681925456, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.09033457438152, - "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.16733169555664, - "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 67.3231650988261, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 109.39099884033203, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 141.49566650390625, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 146.1365534464518, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.30600229899089, - "test_sparse_gradients (__main__.DistributedDataParallelTest)": 104.54216623306274, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 89.89100138346355, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 182.42533469200134, - "test_terminate_handler_on_crash (__main__.TestTorch)": 100.94433457321591, - "test_terminate_signal (__main__.ForkTest)": 137.2934450134635, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 137.37077751341792, - "test_terminate_signal (__main__.SpawnTest)": 139.78100167380438, - "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 95.95577812194824, - "test_transformer_backend_inductor_fullgraph_True (__main__.TestFullyShardCompile)": 95.92808405558269, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 70.52516492207845, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 67.16916783650716, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 77.5228328704834, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 118.96799850463867, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 117.68416659037273, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 100.82866477966309, - "test_unary_ops (__main__.TestTEFuserDynamic)": 174.203000386556, - "test_unary_ops (__main__.TestTEFuserStatic)": 162.15266492631702, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 96.66299947102864, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 79.73133341471355, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 75.25083287556966, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 86.45800018310547, - "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 63.62757146926153, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 75.84433301289876, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 71.28416697184245, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 71.99233182271321, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.53566614786784, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 91.30016708374023, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 64.31961922418503, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 72.40933481852214, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 80.20533307393391, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 76.08066749572754, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 64.3009055001395, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 80.77966817220052, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 88.6466687520345, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 85.9961675008138, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 139.46716435750326 + "EndToEndLSTM (__main__.RNNTest)": 181.61566162109375, + "MultiheadAttention (__main__.ModulesTest)": 136.4750010172526, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 151.13477834065756, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.89133326212565, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.0672378540039, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 83.74566650390625, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 151.28533426920572, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 76.96799977620442, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 86.31200154622395, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 574.2255004882812, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 112.03270034790039, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.27077229817706, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 485.51055908203125, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 120.31833012898763, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 132.48300425211588, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 94.59216562906902, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 301.2558898925781, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 133.76050313313803, + "test_builtin_equivalent_funcs (__main__.TorchFunctionModeTests)": 82.59212158665513, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 328.1419949001736, + "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 70.02100118001302, + "test_comprehensive_constant_pad_nd_cpu_float32 (__main__.TestInductorOpInfoCPU)": 74.87266540527344, + "test_comprehensive_constant_pad_nd_cpu_float64 (__main__.TestInductorOpInfoCPU)": 69.08433278401692, + "test_comprehensive_constant_pad_nd_cpu_int32 (__main__.TestInductorOpInfoCPU)": 72.38800303141277, + "test_comprehensive_constant_pad_nd_cpu_int64 (__main__.TestInductorOpInfoCPU)": 68.1750005086263, + "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 103.65033467610677, + "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 113.07499694824219, + "test_comprehensive_diff_cpu_float64 (__main__.TestInductorOpInfoCPU)": 102.12066650390625, + "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 102.65233357747395, + "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 110.2530008951823, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 82.65933227539062, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 79.33149973551433, + "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.96466573079427, + "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 82.62900034586589, + "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 84.08733367919922, + "test_comprehensive_eye_cpu_bool (__main__.TestInductorOpInfoCPU)": 126.77433268229167, + "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 129.90166727701822, + "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 130.88333129882812, + "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 128.96799977620444, + "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 124.88400014241536, + "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 125.25633239746094, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 85.94333394368489, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 76.60233306884766, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 82.14366658528645, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 333.54833984375, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 83.21299997965495, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 348.1693420410156, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 86.17266591389973, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 233.37083180745444, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 240.9846674601237, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 922.0073445638021, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.75899823506673, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 957.1233317057291, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.89716720581055, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.13633410135905, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.52150026957194, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.30249913533528, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.91258398691813, + "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 211.29166666666666, + "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 201.2199961344401, + "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 201.05166625976562, + "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 448.63133748372394, + "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 435.1319986979167, + "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 414.0263366699219, + "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 428.4053446451823, + "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 97.50900014241536, + "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 96.23233286539714, + "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 98.6259994506836, + "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 99.11599985758464, + "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 89.52233632405598, + "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 100.05933125813802, + "test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 92.08133188883464, + "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 102.49733479817708, + "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 93.6953353881836, + "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 94.12633260091145, + "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 90.63199869791667, + "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 93.61466471354167, + "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 95.45333353678386, + "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 471.6109924316406, + "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 478.7690022786458, + "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 483.9660135904948, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 104.6216672261556, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 125.5418332417806, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 102.33516438802083, + "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 91.59299977620442, + "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 96.90999857584636, + "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 95.03333282470703, + "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 96.96366628011067, + "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 93.97466532389323, + "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 91.50166829427083, + "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 96.39866892496745, + "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 93.34033457438152, + "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 97.53666687011719, + "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 87.80099995930989, + "test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 98.83033243815105, + "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 98.4626693725586, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 83.55299886067708, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 82.62733205159505, + "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 74.2403335571289, + "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 73.23299916585286, + "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 74.39199829101562, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 88.33433532714844, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 88.76199849446614, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 212.46066538492838, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 215.0308354695638, + "test_comprehensive_nn_functional_interpolate_bicubic_cpu_uint8 (__main__.TestInductorOpInfoCPU)": 63.83266576131185, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 72.45750109354655, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 84.0174986521403, + "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 184.9403330485026, + "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 186.5510050455729, + "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 182.49533081054688, + "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1921.0713297526042, + "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1740.4580078125, + "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1776.2012939453125, + "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 1599.7586263020833, + "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 1617.5953369140625, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1075.434326171875, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1099.4353332519531, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1125.7143249511719, + "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 950.7925381130642, + "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 812.9512176513672, + "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 829.8953365749783, + "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 889.6016608344185, + "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 899.8731655544705, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 464.61108271280926, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 459.0932896931966, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 197.40933227539062, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 200.65933227539062, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 195.45833333333334, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 131.92733510335287, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 133.2663319905599, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 124.51133473714192, + "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 70.27266693115234, + "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 68.51133219401042, + "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 72.49566650390625, + "test_comprehensive_nn_functional_pad_constant_cpu_int32 (__main__.TestInductorOpInfoCPU)": 70.40933481852214, + "test_comprehensive_nn_functional_pad_constant_cpu_int64 (__main__.TestInductorOpInfoCPU)": 69.66466522216797, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 119.83233388264973, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.48733266194661, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 120.08599853515625, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 127.59833017985027, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 120.84366353352864, + "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 123.23733266194661, + "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 231.90233357747397, + "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 225.70599873860678, + "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 237.0050048828125, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 102.30183537801106, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 105.99450047810872, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 64.52433395385742, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 67.21816571553548, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 63.552083015441895, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.47633298238118, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 63.37950070699056, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 86.54149881998698, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 82.44583511352539, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 70.82466634114583, + "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 61.23749987284342, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 174.3193367852105, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 325.76544019911023, + "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 79.62433369954427, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 75.77216720581055, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 71.49933369954427, + "test_count_nonzero_all (__main__.TestBool)": 607.5451117621528, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 795.3888888888889, + "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 185.69566524028778, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 81.4071667989095, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 211.3948280016581, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 67.51188871595595, + "test_fail_random.py (__main__.TestTyping)": 72.58257559574011, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 83.84800211588542, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 103.40933100382487, + "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 84.23000081380208, + "test_fn_gradgrad_map_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 61.798926176848234, + "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 506.09766642252606, + "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 345.16650390625, + "test_forward_ad_svd_lowrank_cpu_float32 (__main__.TestCompositeComplianceCPU)": 87.62466684977214, + "test_fuse_large_params_cpu (__main__.CpuTests)": 74.28099937438965, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 167.2507781982422, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 170.01244269476996, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 85.7643330891927, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 82.86033376057942, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 92.42216618855794, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 93.47633107503255, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 95.30599975585938, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 204.81483459472656, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.41116460164388, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 143.45366795857748, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 545.9098256429037, + "test_group_norm (__main__.TestQuantizedOps)": 106.1533326043023, + "test_indexing (__main__.TestAutogradWithCompiledAutograd)": 73.70611148410373, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 260.4870096842448, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 67.80922275119357, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 104.3433354695638, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 131.30877770317926, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.27200317382812, + "test_large_bmm_bfloat16 (__main__.TestMPS)": 1425.0152994791667, + "test_large_bmm_float16 (__main__.TestMPS)": 1253.7086181640625, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 626.6733360290527, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 83.12500190734863, + "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 95.26766586303711, + "test_linear (__main__.TestStaticQuantizedModule)": 111.07222196790907, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 186.7755593193902, + "test_low_memory_max_pool_dilation_1_dim_2_cpu_halide (__main__.HalideCpuTests)": 60.36066691080729, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 654.1646728515625, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 516.5246785481771, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 78.14400100708008, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 111.50900014241536, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 64.09444597032335, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 61.463999854193794, + "test_proper_exit (__main__.TestDataLoader)": 224.20867156982422, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 223.44366709391275, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 63.79383373260498, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 135.53544277615018, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 126.51955371432834, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 124.93733469645183, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.31800079345703, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.14400227864583, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 131.42633819580078, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 87.89099884033203, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.00499979654948, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 130.41100311279297, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 80.20366668701172, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.31099955240886, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 132.61800130208334, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 89.43633270263672, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 87.20300038655598, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 141.2316640218099, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.5403340657552, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.23933410644531, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 134.7986628214518, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 84.30833180745442, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 98.22066752115886, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 126.87933603922527, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 80.81599934895833, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.3933334350586, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 136.1326649983724, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.34866587320964, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.5760014851888, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 339.4856669108073, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 688.2038370768229, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 563.0550130208334, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1057.2596740722656, + "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 60.16333325703939, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 65.99533589680989, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 225.16200002034506, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 84.37333424886067, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 161.50900268554688, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 72.27966562906902, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 110.45833333333333, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 69.66266632080078, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 133.47400283813477, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 83.216002146403, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 113.03650029500325, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 66.53833262125652, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 118.47366658846538, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 142.11158307393393, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 106.017333984375, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 130.9388910929362, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 147.9407755533854, + "test_sum_all_cpu_float64 (__main__.TestReductionsCPU)": 252.02999792054848, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 232.58900745709738, + "test_terminate_handler_on_crash (__main__.TestTorch)": 100.51200015015073, + "test_terminate_signal (__main__.ForkTest)": 137.1949984199471, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 137.9567770593696, + "test_terminate_signal (__main__.SpawnTest)": 140.8028925259908, + "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.12111282348633, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 124.51886669516874, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 79.41683387756348, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 78.52750078837077, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 83.28249867757161, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 117.90133094787598, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 118.63733228047688, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 100.5381685892741, + "test_unary_ops (__main__.TestTEFuserDynamic)": 173.42911020914713, + "test_unary_ops (__main__.TestTEFuserStatic)": 158.1659984588623, + "test_unwaited (__main__.CommTest)": 60.680667877197266, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 94.42666371663411, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 74.38800048828125, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.0030008951823, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 89.27025032043457, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 67.01800028483073, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 68.22083346048991, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 63.439666112264, + "test_vmapjvpvjp_linalg_solve_triangular_cuda_float32 (__main__.TestOperatorsCUDA)": 63.80483341217041, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 65.49283345540364, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 81.31166585286458, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 67.66283289591472, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.59249941507976, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 69.74266688028972, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 70.75883356730144, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 74.43816630045573, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 72.2706667582194, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 129.9483324686686 } \ No newline at end of file diff --git a/test/test_cuda.py b/test/test_cuda.py index a3cc62c5e1d4..192b41fed324 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,6 +595,65 @@ def test_serialization_array_with_storage(self): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Does not work in fbcode yet") + @setBlasBackendsToDefaultFinally + def test_preferred_blas_library_settings(self): + def _check_default(): + default = torch.backends.cuda.preferred_blas_library() + if torch.version.cuda: + # CUDA logic is easy, it's always cublas + self.assertTrue(default == torch._C._BlasBackend.Cublas) + else: + # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else + gcn_arch = str( + torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] + ) + if gcn_arch in ["gfx90a", "gfx942", "gfx950"]: + self.assertTrue(default == torch._C._BlasBackend.Cublaslt) + else: + self.assertTrue(default == torch._C._BlasBackend.Cublas) + + _check_default() + # "Default" can be set but is immediately reset internally to the actual default value. + self.assertTrue( + torch.backends.cuda.preferred_blas_library("default") + != torch._C._BlasBackend.Default + ) + _check_default() + self.assertTrue( + torch.backends.cuda.preferred_blas_library("cublas") + == torch._C._BlasBackend.Cublas + ) + self.assertTrue( + torch.backends.cuda.preferred_blas_library("hipblas") + == torch._C._BlasBackend.Cublas + ) + # check bad strings + with self.assertRaisesRegex( + RuntimeError, + "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", + ): + torch.backends.cuda.preferred_blas_library("unknown") + # check bad input type + with self.assertRaisesRegex(RuntimeError, "Unknown input value type."): + torch.backends.cuda.preferred_blas_library(1.0) + # check env var override + custom_envs = [ + {"TORCH_BLAS_PREFER_CUBLASLT": "1"}, + {"TORCH_BLAS_PREFER_HIPBLASLT": "1"}, + ] + test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())" + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("_BlasBackend.Cublaslt", r) + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @setBlasBackendsToDefaultFinally def test_cublas_workspace_explicit_allocation(self): diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index c92edc279f55..a2691d5e1cbf 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3910,6 +3910,127 @@ def fvmap2(info, in_dims, x, y): self.assertTrue(called) self.assertEqual(result, x + y) + @skipIfTorchDynamo("Skip due to sys.refcount") + def test_any_requires_grad(self): + test_fn = torch._C._any_requires_grad + # Regression test on not leaking kwargs + t = torch.randn(2, 2) + t_refcount = sys.getrefcount(t) + test_fn(t, a=t) + self.assertEqual(sys.getrefcount(t), t_refcount) + + self.assertTrue( + test_fn( + torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True) + ) + ) + self.assertFalse(test_fn(torch.ones(1), torch.zeros(1))) + self.assertTrue( + test_fn( + [torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)] + ) + ) + # _C_any_requires_grad supports only List[Tensor] in args, not List[List[Tensor]] + self.assertFalse(test_fn([[torch.zeros(1, requires_grad=True)]], torch.ones(1))) + self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)])) + self.assertTrue(test_fn(torch.zeros(1), a=torch.ones(1, requires_grad=True))) + self.assertFalse(test_fn(torch.zeros(1), a=torch.ones(1))) + self.assertTrue( + test_fn([torch.zeros(1, requires_grad=True), torch.ones(1)], torch.zeros(1)) + ) + self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)], torch.zeros(1))) + + @skipIfTorchDynamo("Skip due to sys.refcount") + def test_any_output_is_alias_to_input_or_output(self): + test_fn = torch._C._any_output_is_alias_to_input_or_output + # Regression test on not leaking kwargs + t = torch.randn(2, 2) + t_refcount = sys.getrefcount(t) + test_fn((t,), {"a": t}, ()) + assert sys.getrefcount(t) == t_refcount + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + self.assertTrue( + test_fn( + (x,), + {}, + (x.t(),), + ) + ) + self.assertFalse(test_fn((x,), None, (2 * x,))) + self.assertTrue( + test_fn( + (), + {"a": x.view(-1)}, + (x,), + ) + ) + self.assertTrue( + test_fn( + (), + {"a": x.view(-1)}, + (x.t(),), + ) + ) + self.assertTrue(test_fn((y,), {}, (y[1:],))) + self.assertFalse( + test_fn( + (x,), + {"a": x}, + (), + ) + ) + self.assertFalse( + test_fn( + (torch.tensor([]),), + {}, + (torch.tensor([]),), + ) + ) + self.assertTrue( + test_fn( + ([x], x + 1), + {}, + (x.t(),), + ) + ) + self.assertTrue( + test_fn( + ([x], x + 1), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([x], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([x, 1], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([[x]], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([[1, x], 2], 3), + {}, + ([x.t()], x + 1), + ) + ) + class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index a3458efbe65b..224846681500 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -1329,6 +1329,22 @@ def test_tensor_factory_with_symint(self): res = Tensor(sym_args) self.assertEqual(res, expected, exact_dtype=False) + def test_backed_size_oblivious_01_spec(self): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + @torch.compile(dynamic=True, fullgraph=True) + def f(a, b): + if guard_size_oblivious(a.size(0) == 1): + return b * 10 + else: + return b * 20 + + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + # always go to the >= 2 branch. + self.assertEqual( + f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" @@ -2815,6 +2831,136 @@ def test_guards_float_print(self): guards = shape_env.produce_guards_expression([s0]) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_guard_or_true(self): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + def func(a, b): + x = a.item() + if guard_or_true(x == 1): + return b * 10 + else: + return b * 20 + + # call with guarding. + self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) + self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) + + unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func(a, b) + + # always return b*10 + self.assertEqual( + unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) + ) + self.assertEqual( + unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) + ) + + # Test that statically known true works. + def func2(a, b): + x = a.item() + if guard_or_true(x != x): + return b * 10 + else: + return b * 20 + + unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func2(a, b) + # always return b*20 + self.assertEqual( + unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + self.assertEqual( + unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) + ) + + # Test backed_size_oblivious + with torch.fx.experimental._config.patch("backed_size_oblivious", True): + + def func3(a, b): + if guard_or_true(a.size()[0] != 9): + return b * 10 + else: + return b * 20 + + compiled = torch.compile(func3, dynamic=True, fullgraph=True) + a = torch.rand(9, 2) + b = torch.rand(3, 4) + + self.assertEqual(func3(a, b), b * 20) + self.assertEqual(compiled(a, b), b * 10) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_guard_or_false(self): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + def func(a, b): + x = a.item() + if guard_or_false(x == 1): + return b * 10 + else: + return b * 20 + + # call with guarding. + self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) + self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) + + unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func(a, b) + + # always return b*20 + self.assertEqual( + unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + self.assertEqual( + unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) + ) + + # Test that statically known true works. + def func2(a, b): + x = a.item() + if guard_or_false(x == x): + return b * 10 + else: + return b * 20 + + unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func2(a, b) + # always return b*10 + self.assertEqual( + unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) + ) + self.assertEqual( + unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) + ) + + # Test backed_size_oblivious + with torch.fx.experimental._config.patch("backed_size_oblivious", True): + + def func3(a, b): + if guard_or_false(a.size()[0] == 9): + return b * 10 + else: + return b * 20 + + compiled = torch.compile(func3, dynamic=True, fullgraph=True) + a = torch.rand(9, 2) + b = torch.rand(3, 4) + + self.assertEqual(func3(a, b), b * 10) + self.assertEqual(compiled(a, b), b * 20) + def test_guards_float_div(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 8) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 1b99bd94061e..7dad38355e20 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -972,6 +972,14 @@ def add(x, y): self.assertIsInstance(r[0], FakeTensor) self.assertIsInstance(r[1], FakeTensor) + def test_fast_div(self): + mode = FakeTensorMode() + with mode: + x = torch.empty(2, 2, device="cpu", dtype=torch.int32) + from torch._subclasses.fake_impls import get_fast_op_impls + fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor] + y = fast_div(mode, x, 2) + self.assertEqual(y.dtype, torch.float32) instantiate_parametrized_tests(FakeTensorTest) diff --git a/test/test_fx.py b/test/test_fx.py index 5b54025d8d32..58e925f8633e 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1271,6 +1271,23 @@ def forward(self, x: torch.Tensor, y: int = 2): "call_module" ).check("clamp").check("call_method").run(all_formatted) + def test_print_graph(self): + op: torch._ops.OpOverload = torch.ops.aten.relu.default + type_name: str = torch.typename(op) + + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node("call_function", op, (a,), type_expr=type_name) + c: torch.fx.Node = graph.create_node("call_function", op, (b,), type_expr=type_name) + graph.output((b, c)) + + gm: torch.fx.GraphModule = torch.fx.GraphModule( + torch.nn.Module(), graph + ) + gm.graph.lint() + text = gm.print_readable(False) + assert 2 == text.count("_torch__ops_aten_aten_relu_") + def test_script_tensor_constant(self): # TorchScript seems to ignore attributes that start with `__`. # We used to call anonymous Tensor values `__tensor_constant*`, but diff --git a/test/test_linalg.py b/test/test_linalg.py index b49bed2a2e93..649c46b5404c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -65,22 +65,7 @@ def blaslt_supported_device(): return True return False -def set_tunableop_defaults(): - if not torch.cuda.is_available(): - # TunableOp not supported on CPU at this time. - return - - # disable TunableOp and restore to default values - torch.cuda.tunable.enable(False) - torch.cuda.tunable.record_untuned_enable(False) - torch.cuda.tunable.tuning_enable(True) - torch.cuda.tunable.set_max_tuning_duration(30) - torch.cuda.tunable.set_max_tuning_iterations(100) - torch.cuda.tunable.set_rotating_buffer_size(-1) - ordinal = torch.cuda.current_device() - torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv") - -def tunableop_matmul(device, dtype, offline=False): +def tunableop_matmul(device, dtype, result_filename=None, offline=False): # Helper function to test TunableOp in a subprocess # requires helper function since lambda function # not supported by multiprocessing module @@ -90,6 +75,9 @@ def tunableop_matmul(device, dtype, offline=False): if offline: torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) + else: + if result_filename is not None: + torch.cuda.tunable.set_filename(result_filename) torch.cuda.tunable.set_max_tuning_duration(1) A = torch.randn((17, 17), device=device, dtype=dtype) @@ -109,31 +97,13 @@ def find_tunableop_result(results, OpSig, ParamSig): return inner_tuple return None -def compare_untuned_tuned_entries(untuned_filename, tuned_filename): - # Compare the entries of untuned and tuned Tunableop results - # file. Verify that for each Op+Param Signature in the untuned file - # there is a matching one in the tuned results file. - import csv - ok = False - with open(untuned_filename) as file1: - with open(tuned_filename) as file2: - untuned_reader = csv.reader(file1) - untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader} - - tuned_reader = csv.reader(file2) - for _ in range(5): # Skip the first 5 lines for the validator - next(tuned_reader, None) - - result_csv_entries = {(row[0], row[1]) for row in tuned_reader} - - missing = untuned_csv_entries - result_csv_entries - - if missing: - ok = False - else: - ok = True - - return ok +def get_tunableop_untuned_filename(): + import os + ordinal = torch.cuda.current_device() + untuned_filename_env = os.getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME") + untuned_filename_base, _, _ = untuned_filename_env.rpartition('.') + untuned_filename = f"{untuned_filename_base}{ordinal}.csv" + return untuned_filename class TestLinalg(TestCase): @contextlib.contextmanager @@ -165,7 +135,7 @@ def _tunableop_ctx(self): # Inialize and then tear down TunableOp import glob import os - set_tunableop_defaults() + self._set_tunableop_defaults() torch.cuda.tunable.enable(True) try: @@ -175,7 +145,13 @@ def _tunableop_ctx(self): torch.cuda.tunable.enable(False) # clean up, remove any files that were generated - for file in glob.glob("tunableop*.csv"): + results_filename = torch.cuda.tunable.get_filename() + results_filename_pattern, _, _ = results_filename.rpartition('.') + untuned_filename = get_tunableop_untuned_filename() + untuned_filename_pattern, _, _ = untuned_filename.rpartition('.') + patterns = [f"{results_filename_pattern[:-1]}*.csv", f"{untuned_filename_pattern[:-1]}*.csv"] + files = [f for pattern in patterns for f in glob.glob(pattern)] + for file in files: try: os.remove(file) # NB: The file is locked on Windows @@ -194,6 +170,59 @@ def _tunableop_ctx(self): except KeyError: pass + def _set_tunableop_defaults(self): + if not torch.cuda.is_available(): + # TunableOp not supported on CPU at this time. + return + + # disable TunableOp and restore to default values + torch.cuda.tunable.enable(False) + torch.cuda.tunable.record_untuned_enable(False) + torch.cuda.tunable.tuning_enable(True) + torch.cuda.tunable.set_max_tuning_duration(30) + torch.cuda.tunable.set_max_tuning_iterations(100) + torch.cuda.tunable.set_rotating_buffer_size(-1) + ordinal = torch.cuda.current_device() + + # Set filenames to be unique on a per test basis + import os + unique_id = self.id().split(".")[-1] + torch.cuda.tunable.set_filename(f"tunableop_results_{unique_id}_{ordinal}.csv") + # ordinal gets automatically appended + os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"] = f"tunableop_untuned_{unique_id}_.csv" + + def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=None): + # Compare the entries of untuned and tuned Tunableop results + # file. Verify that for each Op+Param Signature in the untuned file + # there is a matching one in the tuned results file. + import csv + ok = False + ordinal = torch.cuda.current_device() + if untuned_filename is None: + untuned_filename = get_tunableop_untuned_filename() + if tuned_filename is None: + tuned_filename = torch.cuda.tunable.get_filename() + + with open(untuned_filename) as file1: + with open(tuned_filename) as file2: + untuned_reader = csv.reader(file1) + untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader} + + tuned_reader = csv.reader(file2) + for _ in range(5): # Skip the first 5 lines for the validator + next(tuned_reader, None) + + result_csv_entries = {(row[0], row[1]) for row in tuned_reader} + + missing = untuned_csv_entries - result_csv_entries + + if missing: + ok = False + else: + ok = True + + return ok + exact_dtype = True @dtypes(torch.float, torch.cfloat) @@ -4693,16 +4722,18 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): make_arg = partial(make_tensor, device=device, dtype=dtype) # Using gen_sizes_matmul(2) to ensure we cover # 'NN', 'TN', 'TT', and 'NN' cases. - for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)): + for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2, y_dim=3), + (True, False), (True, False)): x = make_arg(size_x, noncontiguous=nctg_x) y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) filename1 = torch.cuda.tunable.get_filename() - filename2 = "tunableop_results_tmp1.csv" - filename3 = "tunableop_results_tmp2.csv" + unique_id = self.id().split(".")[-1] + filename2 = f"{filename1}_tmp1.csv" + filename3 = f"{filename1}_tmp2.csv" ordinal = torch.cuda.current_device() - assert filename1 == f"tunableop_results{ordinal}.csv" + assert filename1 == f"tunableop_results_{unique_id}_{ordinal}.csv" assert len(torch.cuda.tunable.get_results()) > 0 assert torch.cuda.tunable.write_file() # use default filename @@ -4720,6 +4751,10 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): assert file1_contents == file2_contents assert file1_contents == file3_contents + # We need to reset the filename to the default value so we can properly + # clean up intermediate files + self._set_tunableop_defaults() + @onlyCUDA @skipCUDAIfNotRocm @dtypes(torch.half) @@ -4728,7 +4763,6 @@ def test_matmul_offline_tunableop(self, device, dtype): # NOTE: The offline tuning does not support certain tensor # shapes as noted below. Submatrics / matrix slices are # not supported at all. - import os def has_any_dim_size_one(tensor: torch.Tensor): """Check if any dimension of a PyTorch tensor has size 1.""" @@ -4750,7 +4784,6 @@ def is_bmm_compatible(A, B): torch.cuda.tunable.set_rotating_buffer_size(0) ordinal = torch.cuda.current_device() - result_filename = f"tunableop_results{ordinal}.csv" # record GEMM torch.cuda.tunable.tuning_enable(False) @@ -4821,8 +4854,7 @@ def is_bmm_compatible(A, B): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -4839,12 +4871,8 @@ def is_bmm_compatible(A, B): self.assertGreater(new_results - ref_results, 0) self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @onlyCUDA @@ -4853,14 +4881,11 @@ def is_bmm_compatible(A, B): @dtypes(torch.torch.float8_e4m3fnuz, torch.float8_e5m2fnuz) def test_scaled_gemm_offline_tunableop(self, device, dtype): # This test is the offline version of test_scaled_gemm_tunableop - import os with self._tunableop_ctx(): ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -4910,8 +4935,7 @@ def test_scaled_gemm_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -4937,12 +4961,8 @@ def test_scaled_gemm_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs") @@ -4960,7 +4980,11 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): total_gpus = torch.cuda.device_count() ordinal = torch.cuda.current_device() - untuned_filename = f"tunableop_untuned{ordinal}.csv" + + # Untuned filename has unique id, but results file + # does not because it is executed in a subprocess + untuned_filename = get_tunableop_untuned_filename() + torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv") # turn on untuned GEMM recording and turn off tuning torch.cuda.tunable.tuning_enable(False) @@ -4985,19 +5009,14 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): torch.cuda.tunable.mgpu_tune_gemm_in_file(untuned_filename, total_gpus) # check the results files where written, one per gpu - # get the size of the first result and make sure it - # greater than 100. Since the validator text should - # be at least that much. - # The other results file will have - # at least the size of the first results file - 80 + # Check that the results file is not empty and store + # that in a local variable for the next loop. for i in range(total_gpus): result_filename = f"tunableop_results{i}.csv" self.assertTrue(os.path.exists(result_filename)) + self.assertGreater(os.path.getsize(result_filename), 0) if i == 0: # Store for next loop result_size = os.path.getsize(result_filename) - self.assertGreater(os.path.getsize(result_filename), 0) - self.assertGreater(os.path.getsize(result_filename), result_size - 80) - # Check the full results files was written, one per gpu # check that the size of the full results file for @@ -5018,6 +5037,7 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): def test_rotating_buffer_tunableop(self, device, dtype): # Test the TunableOp rotating buffer API # Test the default value, will return the l2_cache_size + self._set_tunableop_defaults() l2_cache_size = torch.cuda.tunable.get_rotating_buffer_size() self.assertGreater(l2_cache_size, 0) # Test zero @@ -5038,6 +5058,9 @@ def test_bmm_tunableop_rocm(self, device, dtype): # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault with self._tunableop_ctx(): torch.cuda.tunable.set_max_tuning_iterations(10) + # Make sure the rotating buffer is not zero, otherwise this test does nothing useful. + rotating_buffer = torch.cuda.tunable.get_rotating_buffer_size() + self.assertGreater(rotating_buffer, 0) # the following 3 cases cover all previous failure cases and are here to catch regressions B = 16 N = M = K = 256 @@ -5082,21 +5105,21 @@ def test_bmm_tunableop_rocm(self, device, dtype): @onlyCUDA @skipCUDAIfNotRocm - @dtypes(torch.float) + @dtypes(torch.bfloat16) def test_numeric_check_leak_tunableop_rocm(self, device, dtype): import os from torch.testing._internal.common_utils import CudaMemoryLeakCheck # run operator first without tuning to ensure all rocm libs are loaded, # otherwise false positive mem leak - B = 16 - N = M = K = 256 - dtype = torch.bfloat16 + B = 5 + N = M = K = 29 device = torch.device("cuda:0") i1 = torch.randn((B, N, M), device=device, dtype=dtype) i2 = torch.randn((B, M, K), device=device, dtype=dtype) out = torch.bmm(i1, i2) with self._tunableop_ctx(): + torch.cuda.tunable.set_rotating_buffer_size(0) # enable tunableop numeric check via env variable. os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1" @@ -5213,9 +5236,9 @@ def test_disable_tuning_tunableop(self, device, dtype): ref_num_results = len(torch.cuda.tunable.get_results()) # Tune one GEMMs to make sure TunableOp is enabled - M = 3 - N = 3 - K = 3 + M = 11 + N = 13 + K = 17 A = torch.randn(N, K, device=device, dtype=dtype) B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) @@ -5234,9 +5257,9 @@ def test_disable_tuning_tunableop(self, device, dtype): torch.cuda.tunable.tuning_enable(False) # Try to tune one more GEMM - M = 3 - N = 3 - K = 4 + M = 11 + N = 13 + K = 18 A = torch.randn(N, K, device=device, dtype=dtype) B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) @@ -5257,8 +5280,7 @@ def test_dump_results_on_exit_tunableop(self, device, dtype): import multiprocessing as mp with self._tunableop_ctx(): - ordinal = torch.cuda.current_device() - filename = f"tunableop_results{ordinal}.csv" + filename = torch.cuda.tunable.get_filename() # force=True needed according to: # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method @@ -5266,7 +5288,7 @@ def test_dump_results_on_exit_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, filename, False)) p.start() p.join() @@ -5305,14 +5327,11 @@ def test_gemm_bias_tunableop(self, device, dtype): @dtypes(torch.bfloat16) def test_gemm_bias_offline_tunableop(self, device, dtype): # This test is the offline version of test_gemm_bias_tunableop - import os ordinal = torch.cuda.current_device() with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -5330,8 +5349,7 @@ def test_gemm_bias_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -5353,12 +5371,8 @@ def test_gemm_bias_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @onlyCUDA @@ -5378,6 +5392,7 @@ def test_scaled_gemm_tunableop(self, device, dtype): # tested by PyTorch with self._tunableop_ctx(): # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_rotating_buffer_size(0) torch.cuda.tunable.set_max_tuning_iterations(1) # Reference number of results @@ -5386,9 +5401,9 @@ def test_scaled_gemm_tunableop(self, device, dtype): # Scaled GEMM parameters fillA = 0.25 fillB = 0.75 - n = 32 - m = 64 - k = 128 + n = 64 + m = 16 + k = 32 scaleA = torch.tensor(0.8, device=device) scaleB = torch.tensor(0.9, device=device) @@ -5519,8 +5534,6 @@ def test_tf32_offline_tunableop(self, device, dtype): ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -5535,7 +5548,7 @@ def test_tf32_offline_tunableop(self, device, dtype): torch.backends.cuda.matmul.allow_tf32 = False C = torch.matmul(A, B) - untuned_filename = f"tunableop_untuned{ordinal}.csv" + untuned_filename = get_tunableop_untuned_filename() self.assertTrue(os.path.exists(untuned_filename)) # tuning the untuned GEMMs in file @@ -5569,12 +5582,8 @@ def test_tf32_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) finally: @@ -5606,10 +5615,11 @@ def test_blaslog_tunableop(self, device, dtype): with self._tunableop_ctx(): os.putenv("PYTORCH_TUNABLEOP_BLAS_LOG", "1") - ordinal = torch.cuda.current_device() - - result_filename = f"tunableop_results{ordinal}.csv" - untuned_filename = f"tunableop_untuned{ordinal}.csv" + # TunableOp is running in a subprocess + # online tuning needs filename set through API + # offline tuning needs filename set through environment variableq + result_filename = torch.cuda.tunable.get_filename() + untuned_filename = get_tunableop_untuned_filename() # Offline Tuning case in a subprocess @@ -5619,7 +5629,7 @@ def test_blaslog_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype, True)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, None, True)) p.start() p.join() @@ -5646,7 +5656,7 @@ def test_blaslog_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype, False)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, result_filename, False)) p.start() p.join() @@ -5835,20 +5845,21 @@ def run_test(batch, m, n, fortran_contiguous): @dtypes(*floating_and_complex_types()) def test_ormqr_errors_and_warnings(self, device, dtype): test_cases = [ - # input1 size, input2 size, input3 size, error regex - ((10,), (2,), (2,), r"input must have at least 2 dimensions"), - ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"), - ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"), - ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"), - ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"), - ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"), + # input1 size, input2 size, input3 size, left, error regex + ((10,), (2,), (2,), True, r"input must have at least 2 dimensions"), + ((2, 2), (2,), (2,), True, r"other must have at least 2 dimensions"), + ((6, 6), (5,), (5, 5), True, r"other.shape\[-2\] must be equal to input.shape\[-2\]"), + ((1, 2, 2), (2, 2), (1, 2, 2), True, r"batch dimensions of tau to be equal to input.shape\[:-2\]"), + ((1, 2, 2), (1, 2), (2, 2, 2), True, r"batch dimensions of other to be equal to input.shape\[:-2\]"), + ((2, 4, 3), (2, 2), (2, 3, 10), True, r"torch.ormqr: other.shape\[-2\] must be equal to input.shape\[-2\]"), + ((2, 4, 3), (2, 2), (2, 3, 10), False, r"torch.ormqr: other.shape\[-1\] must be equal to input.shape\[-2\]") ] - for a_size, tau_size, c_size, error_regex in test_cases: + for a_size, tau_size, c_size, left, error_regex in test_cases: a = make_tensor(a_size, dtype=dtype, device=device) tau = make_tensor(tau_size, dtype=dtype, device=device) c = make_tensor(c_size, dtype=dtype, device=device) with self.assertRaisesRegex(RuntimeError, error_regex): - torch.ormqr(a, tau, c) + torch.ormqr(a, tau, c, left) def test_blas_empty(self, device): def fn(torchfn, *args, test_out=False, **kwargs): @@ -6868,7 +6879,8 @@ def test_addmm_relu(self, device, dtype): @bf32_on_and_off(0.05) def test_addmm_relu_tunableop_rocm(self, device, dtype): with self._tunableop_ctx(): - torch.cuda.tunable.set_max_tuning_iterations(10) + torch.cuda.tunable.set_rotating_buffer_size(0) + torch.cuda.tunable.set_max_tuning_iterations(1) self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 611a2f943f67..7de6f3d725ed 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -24,7 +24,7 @@ SM90OrLater, _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, - PLATFORM_SUPPORTS_MX_GEMM + PLATFORM_SUPPORTS_MX_GEMM, ) from torch.testing._internal.common_device_type import ( dtypes, @@ -32,6 +32,10 @@ onlyCUDA, tol as xtol, toleranceOverride, + e4m3_type, + e5m2_type, + E4M3_MAX_POS, + E5M2_MAX_POS, ) from torch.testing._internal.common_utils import ( @@ -254,21 +258,179 @@ def _expand_to_batch(t: torch.Tensor): # cross comparison self.assertEqual(out1_gpu, out2_gpu[0]) + def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist): + for a, b, gO, agrad, bgrad, out in zip(alist, blist, gOlist, agradlist, bgradlist, outlist): + a = a.clone().detach().requires_grad_() + b = b.clone().detach().requires_grad_() + out_ref = torch.mm(a, b.t()) + out_ref.backward(gO) + self.assertEqual(out, out_ref) + self.assertEqual(agrad, a.grad) + self.assertEqual(bgrad, b.grad) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 + if a_row_major: + a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + a = torch.randn(k * n_groups + k * int(strided), m, device=device, dtype=dtype).t()[:, :k * n_groups] + + if b_row_major: + b = torch.randn(n, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + b = torch.randn(k * n_groups + k * int(strided), n, device=device, dtype=dtype).t()[:, :k * n_groups] + + a.requires_grad_(True) + b.requires_grad_(True) + offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) + out = torch._grouped_mm(a, b.t(), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, blist, agradlist, bgradlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[:, start:offs_cpu[i]]) + blist.append(b[:, start:offs_cpu[i]]) + agradlist.append(a.grad[:, start:offs_cpu[i]]) + bgradlist.append(b.grad[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + a = torch.randn(k, (m + 2 * s_int) * n_groups, device=device, dtype=dtype).t()[:m * n_groups, :] + + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.t() + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, agradlist, gOlist, outlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[start:offs_cpu[i]]) + agradlist.append(a.grad[start:offs_cpu[i]]) + outlist.append(out[start:offs_cpu[i]]) + gOlist.append(gO[start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, b, gOlist, agradlist, b.grad, outlist) + + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + + out = torch._grouped_mm(a, b.transpose(-2, -1), out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + b = torch.randn(k, n * (n_groups + s_int), device=device, dtype=dtype).transpose(-2, -1)[:n * n_groups, :] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) + out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + blist, outlist, bgradlist, gOlist = [], [], [], [] + start = 0 + for i in range(n_groups): + blist.append(b[start:offs_cpu[i]]) + bgradlist.append(b.grad[start:offs_cpu[i]]) + outlist.append(out[:, start:offs_cpu[i]]) + gOlist.append(gO[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(a, blist, gOlist, a.grad, bgradlist, outlist) + f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" -if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName: - e4m3_type = torch.float8_e4m3fnuz - e5m2_type = torch.float8_e5m2fnuz - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max -else: - e4m3_type = torch.float8_e4m3fn - e5m2_type = torch.float8_e5m2 - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max - # avoid division by zero when calculating scale EPS = 1e-12 @@ -1265,7 +1427,7 @@ def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: out_dtype=torch.bfloat16, ) - def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum): + def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum): for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist): out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1), out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum) @@ -1275,7 +1437,7 @@ def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_f @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_2d_2d(self, fast_accum, strided): + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): device = "cuda" m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] @@ -1294,14 +1456,14 @@ def test_grouped_gemm_2d_2d(self, fast_accum, strided): ascalelist.append(scale_a[i * m : (i + 1) * m]) bscalelist.append(scale_b[i * n : (i + 1) * n]) start = offs_cpu[i] - self.grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) + self.scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_2d_3d(self, fast_accum, strided): + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1324,14 +1486,14 @@ def test_grouped_gemm_2d_3d(self, fast_accum, strided): ascalelist.append(scale_a[start:offs_cpu[i]]) outlist.append(out[start:offs_cpu[i]]) start = offs_cpu[i] - self.grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) + self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_3d_3d(self, fast_accum, strided): + def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1345,14 +1507,14 @@ def test_grouped_gemm_3d_3d(self, fast_accum, strided): out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) - self.grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) + self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_3d_2d(self, fast_accum, strided): + def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1374,7 +1536,8 @@ def test_grouped_gemm_3d_2d(self, fast_accum, strided): bscalelist.append(scale_b[start:offs_cpu[i]]) outlist.append(out[:, start:offs_cpu[i]]) start = offs_cpu[i] - self.grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) + self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) def test_blockwise_mxfp8_compile(self) -> None: @@ -1404,6 +1567,35 @@ def test_blockwise_mxfp8_compile(self) -> None: ) torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) + def test_blockwise_nvfp4_compile(self) -> None: + + device = "cuda" + M, K, N = 128, 128, 128 + BLOCK_SIZE = 16 + + A_ref = torch.eye(M, device=device, dtype=torch.bfloat16) + B_ref = torch.eye(M, device=device, dtype=torch.bfloat16) + + A = _bfloat16_to_float4_e2m1fn_x2(A_ref) + B = _bfloat16_to_float4_e2m1fn_x2(B_ref) + + A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn) + B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn) + C_ref = A_ref @ B_ref.t() + + compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor") + # C = torch._scaled_mm( + C = compiled_scaled_mm( + A, + B.t(), + A_scale, + B_scale, + out_dtype=torch.bfloat16, + use_fast_accum=False, + ) + torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 93858f10b5c9..19772c4adaa5 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -4,6 +4,7 @@ import itertools import functools import unittest +import warnings from contextlib import nullcontext try: @@ -1612,6 +1613,16 @@ def common(self, shape1, shape2, op, dtype): ]: common(self, shape1, shape2, op, dtype) + def test_mkldnn_setflags_nowarn(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/149829 + with warnings.catch_warnings(record=True) as w: + rc = torch.backends.mkldnn.set_flags() + # torch.backends.mkldnn. returns previously set flags + # That one should be able to set back without cauinsg a warning + torch.backends.mkldnn.set_flags(*rc) + # Above should trigger no warnings regardless of configuration + self.assertEqual(len(w), 0) + instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/test/test_model_exports_to_core_aten.py b/test/test_model_exports_to_core_aten.py index aae14c28b8d6..3d1c25939ec4 100644 --- a/test/test_model_exports_to_core_aten.py +++ b/test/test_model_exports_to_core_aten.py @@ -27,7 +27,9 @@ def test_vit_aten_export(self): m = m.eval() input_shape = (1, 3, 224, 224) example_inputs = (torch.randn(input_shape),) - m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() + m = torch.export.export_for_training( + m, copy.deepcopy(example_inputs), strict=True + ).module() m(*example_inputs) m = export.export(m, copy.deepcopy(example_inputs)) ops = _get_ops_list(m.graph_module) diff --git a/test/test_mps.py b/test/test_mps.py index f5c9befc492a..0f1790a1808d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -657,7 +657,6 @@ def mps_ops_modifier(ops): 'sparse.mmreduce': None, 'special.airy_ai': None, 'special.erfcx': None, - 'special.hermite_polynomial_h': None, 'special.hermite_polynomial_he': None, 'special.laguerre_polynomial_l': None, 'special.log_ndtr': None, @@ -714,6 +713,7 @@ def mps_ops_modifier(ops): 'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], 'special.chebyshev_polynomial_t': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], 'special.chebyshev_polynomial_u': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'special.hermite_polynomial_h': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # entr does not support boolean types 'special.entr': [torch.bool], @@ -1089,6 +1089,125 @@ def test_scaled_dot_product_attention_autocast(self): y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) self.assertEqual(y.to(y_autocast.dtype), y_autocast) + def test_gradscaler_mps(self): + # big model to force chunking/depth in the gradscaler dispatch + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 2048) + self.fc2 = nn.Linear(2048, 2048) + self.fc3 = nn.Linear(2048, 2048) + self.fc4 = nn.Linear(2048, 2048) + self.fc5 = nn.Linear(2048, 5) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.relu(self.fc3(x)) + x = self.relu(self.fc4(x)) + return self.fc5(x) + torch.manual_seed(42) + + def helper(model_cpu, model_mps, dtype, iterations, batch_size, atol=3e-4, rtol=1e-5): + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") + optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01) + optimizer_mps = torch.optim.SGD(model_mps.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + input_cpu = torch.randn(batch_size, 10) + target_cpu = torch.randn(batch_size, 5) + input_mps = input_cpu.to('mps') + target_mps = target_cpu.to('mps') + + scaler_cpu = torch.amp.GradScaler(device="cpu") + scaler_mps = torch.amp.GradScaler(device="mps") + for _ in range(iterations): + optimizer_cpu.zero_grad() + optimizer_mps.zero_grad() + + with torch.amp.autocast(device_type="cpu", dtype=dtype): + output_cpu = model_cpu(input_cpu) + loss_cpu = loss_fn(output_cpu, target_cpu) + scaler_cpu.scale(loss_cpu).backward() + scaler_cpu.step(optimizer_cpu) + scaler_cpu.update() + + with torch.autocast(device_type="mps", dtype=dtype): + output_mps = model_mps(input_mps) + loss_mps = loss_fn(output_mps, target_mps) + scaler_mps.scale(loss_mps).backward() + scaler_mps.step(optimizer_mps) + scaler_mps.update() + + for p_cpu, p_mps in zip(model_cpu.parameters(), model_mps.parameters()): + self.assertEqual(p_mps.cpu(), p_cpu, rtol=rtol, atol=atol) + + model_cpu = Model().to('cpu') + model_mps = Model().to('mps') + model_mps.load_state_dict(model_cpu.state_dict()) + + helper(model_cpu, model_mps, torch.float16, iterations=5, batch_size=4) + helper(model_cpu, model_mps, torch.bfloat16, iterations=5, batch_size=4) + + def test_non_fast_path_amp_unscale(self): + torch.manual_seed(42) + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x): + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + x = x.mean(dim=1) + return x + + cpu_model = Model().to("cpu") + mps_model = copy.deepcopy(cpu_model).to("mps") + + cpu_optimizer = torch.optim.SGD(cpu_model.parameters(), lr=0.01) + mps_optimizer = torch.optim.SGD(mps_model.parameters(), lr=0.01) + cpu_scaler = torch.amp.GradScaler(device="cpu") + mps_scaler = torch.amp.GradScaler(device="mps") + + def helper(model, optimizer, scaler, device, input, target, apply_grad_transform=False): + optimizer.zero_grad() + with torch.autocast(device_type=device, dtype=torch.bfloat16): + output = model(input) + loss = nn.MSELoss()(output, target) + scaler.scale(loss).backward() + + if apply_grad_transform: + for p in model.parameters(): + if p.grad is not None and p.grad.dim() >= 2: + p.grad = p.grad.as_strided(p.grad.size(), (1,) * p.grad.dim()) + + scaler.unscale_(optimizer) + scaler.step(optimizer) + scaler.update() + + # CPU forward/backward pass + input_cpu = torch.randn(32, 10, device="cpu") + target_cpu = torch.randn(32, device="cpu") + helper(cpu_model, cpu_optimizer, cpu_scaler, "cpu", input_cpu, target_cpu) + + # MPS forward/backward pass + input_mps = input_cpu.to("mps") + target_mps = target_cpu.to("mps") + helper(mps_model, mps_optimizer, mps_scaler, "mps", input_mps, target_mps, apply_grad_transform=True) + + updated_linear1_weight_cpu = cpu_model.linear1.weight.detach() + updated_linear2_weight_cpu = cpu_model.linear2.weight.detach() + updated_linear1_weight_mps = mps_model.linear1.weight.detach().cpu() + updated_linear2_weight_mps = mps_model.linear2.weight.detach().cpu() + + self.assertEqual(updated_linear1_weight_cpu, updated_linear1_weight_mps, atol=6e-4, rtol=1e-6) + self.assertEqual(updated_linear2_weight_cpu, updated_linear2_weight_mps, atol=6e-4, rtol=1e-6) # Expand TestCase class with Memory Leak Detection on MPS device class TestCaseMPS(TestCase): @@ -7476,6 +7595,7 @@ def compare_mm(m, n, k, dtype=torch.float): @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test") @unittest.skipIf(MACOS_VERSION < 14.0, "Can't allocate 4Gb tensor on MacOS 13") + @unittest.skipIf(IS_CI, "May be fixes https://github.com/pytorch/pytorch/issues/149999") def test_copy_large(self): """ Test that copy of 4Gb+ tensors works """ x = torch.ones((2**30 + 11,), dtype=torch.float32) @@ -7814,18 +7934,19 @@ def helper(shape, diag=0): # Test inverse def test_inverse(self): - def helper(n): + def helper(n, atol=1e-5, rtol=1e-6): cpu_input = torch.randn(n, n, device='cpu') mps_input = cpu_input.to('mps') cpu_result = torch.linalg.inv(cpu_input) mps_result = torch.linalg.inv(mps_input) - self.assertEqual(cpu_result, mps_result) + self.assertEqual(cpu_result, mps_result, atol=atol, rtol=rtol) helper(2) helper(6) helper(3) helper(8) + helper(1025, atol=1e-4) # Test tril def test_tril(self): @@ -12936,6 +13057,50 @@ def test_metal_include(self): lib = torch.mps.compile_shader("#include ") self.assertIsNotNone(lib) + @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_reduction_utils(self, dtype): + if dtype == torch.int64 and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Using simd_shuffle_down_and_fill results in ICE on MacOS-13") + from torch._inductor.codegen.mps import DTYPE_TO_METAL + lib = torch.mps.compile_shader(f""" + #include + kernel void do_sum(device {DTYPE_TO_METAL[dtype]}* out, + constant {DTYPE_TO_METAL[dtype]}* inp, + uint idx [[thread_position_in_grid]]) {{ + out[idx] = c10::metal::simd_sum(inp[idx]); + }} + """) + x = torch.testing.make_tensor(28, device="mps", dtype=dtype) + y = torch.empty_like(x) + lib.do_sum(y, x) + x_sum = x.sum() + max_err = (y - x_sum).abs().max().item() + self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5, + f"results are {y}, but all elements should have been {x_sum.item()}") + + def test_argument_buffers(self): + lib = torch.mps.compile_shader(""" + constant constexpr auto nbuffers = 64; + struct Inputs { + metal::array args; + }; + + kernel void sum_all(device float* output, constant Inputs& inputs, uint idx [[thread_position_in_grid]]) { + auto rc = inputs.args[0][idx]; + for(auto i = 1; i < nbuffers; ++i) { + rc += inputs.args[i][idx]; + } + output[idx] = rc; + } + """) + inputs = torch.rand(64, 32, device="mps").unbind(0) + output = torch.empty_like(inputs[0]) + lib.sum_all(output, inputs) + correct = torch.zeros_like(inputs[0]) + for inp in inputs: + correct += inp + self.assertEqual(correct, output) + @unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again") def test_metal_capture(self): lib = torch.mps.compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; }") @@ -12967,6 +13132,7 @@ def test_metal_capture(self): instantiate_parametrized_tests(TestMPS) instantiate_parametrized_tests(TestSDPA) instantiate_parametrized_tests(TestSmoothL1Loss) +instantiate_parametrized_tests(TestMetalLibrary) if __name__ == "__main__": run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 72c440ca5ec5..32b0efd40aff 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1812,7 +1812,7 @@ def check_weight_norm(l, name, num_params): def test_weight_norm(self): - for dtype in [torch.float, torch.bfloat16]: + for dtype in [torch.float, torch.bfloat16, torch.float16]: input = torch.randn(3, 4, dtype=dtype) m = nn.Linear(4, 5).to(dtype=dtype) expected_output = m(input) diff --git a/test/test_ops.py b/test/test_ops.py index 871b643568eb..c8079ea71255 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -118,8 +118,6 @@ def reduction_dtype_filter(op): aten = torch.ops.aten meta_consistency_out_dtype_mismatch_xfails = { - xfail("addbmm"), - xfail("addmv"), xfail("alias_copy"), xfail("all"), xfail("amax"), @@ -127,7 +125,6 @@ def reduction_dtype_filter(op): xfail("aminmax"), xfail("any"), xfail("as_strided_copy"), - xfail("baddbmm"), xfail("bucketize"), xfail("conj_physical"), xfail("cross"), @@ -135,7 +132,6 @@ def reduction_dtype_filter(op): xfail("cummin"), xfail("diag"), xfail("diagonal_copy"), - xfail("dot"), xfail("expand_copy"), xfail("fft.ihfft2"), xfail("fft.ihfftn"), @@ -159,7 +155,6 @@ def reduction_dtype_filter(op): xfail("linalg.lu_factor"), xfail("linalg.lu_factor_ex"), xfail("linalg.lu_solve"), - xfail("linalg.matrix_power"), xfail("linalg.qr"), xfail("linalg.slogdet"), xfail("linalg.solve"), @@ -168,12 +163,9 @@ def reduction_dtype_filter(op): xfail("logcumsumexp"), xfail("lu_solve"), xfail("lu_unpack"), - xfail("matmul"), - xfail("mm"), xfail("mode"), xfail("msort"), xfail("multinomial"), - xfail("mv"), xfail("nan_to_num"), xfail("nanmean"), xfail("narrow_copy"), @@ -182,7 +174,6 @@ def reduction_dtype_filter(op): xfail("nn.functional.avg_pool3d"), xfail("nn.functional.gelu"), xfail("nn.functional.hardshrink"), - xfail("nn.functional.linear"), xfail("nn.functional.logsigmoid"), xfail("nn.functional.softplus"), xfail("nn.functional.softshrink"), @@ -210,7 +201,6 @@ def reduction_dtype_filter(op): xfail("triu"), xfail("unfold_copy"), xfail("unsqueeze_copy"), - xfail("vdot"), xfail("view_copy"), xfail("where"), # Output has dynamic shape. @@ -1825,6 +1815,7 @@ def check_ignore_materialize(idx_or_kw, allow_list): def check_cow_input( arg, arg_copy, + arg_raw, idx_or_kw, backward_or_forward="forward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward, @@ -1837,6 +1828,13 @@ def check_cow_input( ) + f" during {backward_or_forward} call" if is_strided_tensor(arg): + self.assertTrue( + torch._C._is_cow_tensor(arg_raw), + msg=( + f"{arg_name} raw input should remain COW, but it " + "unexpectedly materialized." + ), + ) is_cow = torch._C._is_cow_tensor(arg) if supports_cow_input_no_materialize and not check_ignore_materialize( @@ -1861,6 +1859,17 @@ def check_cow_input( "but the operation mutated its data." ), ) + else: + self.assertTrue( + torch.allclose( + arg_raw, arg_copy, rtol=0, atol=0, equal_nan=True + ), + msg=( + f"{arg_name} materialized, which is allowed in this " + "case, but the COW input data was mutated, which is " + "not allowed." + ), + ) for sample in samples: args_raw = [sample.input] + list(sample.args) @@ -1901,10 +1910,10 @@ def check_cow_input( # Check that COW inputs remain COW after the forward op is executed for idx, arg in enumerate(args): - check_cow_input(arg, args_copy[idx], idx) + check_cow_input(arg, args_copy[idx], args_raw[idx], idx) for kw, arg in kwargs.items(): - check_cow_input(arg, kwargs_copy[kw], kw) + check_cow_input(arg, kwargs_copy[kw], kwargs_raw[kw], kw) # Call backward op if it is supported. This part of the test is # based on `composite_compliance.check_backward_formula` @@ -1954,6 +1963,7 @@ def check_cow_input( check_cow_input( arg, args_copy[idx], + args_raw[idx], idx, backward_or_forward="backward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, @@ -1965,6 +1975,7 @@ def check_cow_input( check_cow_input( output_grad, output_grads_copy[idx], + output_grads_raw[idx], f"output grad {idx}", backward_or_forward="backward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, diff --git a/test/test_pytree.py b/test/test_pytree.py index 4560ac6e69ed..82665854c2b1 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -6,11 +6,12 @@ import re import subprocess import sys +import time import unittest from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import auto -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional import torch import torch.utils._pytree as py_pytree @@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl): with self.assertRaises(TypeError): pytree_impl.treespec_dumps("random_blurb") + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_is_namedtuple(self, pytree): + DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) + + class DirectNamedTuple2(NamedTuple): + x: int + y: int + + class IndirectNamedTuple1(DirectNamedTuple1): + pass + + class IndirectNamedTuple2(DirectNamedTuple2): + pass + + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1))) + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1))) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1))) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1))) + self.assertFalse(pytree.is_namedtuple(time.gmtime())) + self.assertFalse(pytree.is_namedtuple((0, 1))) + self.assertFalse(pytree.is_namedtuple([0, 1])) + self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2})) + self.assertFalse(pytree.is_namedtuple({0, 1})) + self.assertFalse(pytree.is_namedtuple(1)) + + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2)) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2)) + self.assertFalse(pytree.is_namedtuple(time.struct_time)) + self.assertFalse(pytree.is_namedtuple(tuple)) + self.assertFalse(pytree.is_namedtuple(list)) + + self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2)) + self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2)) + self.assertFalse(pytree.is_namedtuple_class(time.struct_time)) + self.assertFalse(pytree.is_namedtuple_class(tuple)) + self.assertFalse(pytree.is_namedtuple_class(list)) + + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_is_structseq(self, pytree): + class FakeStructSeq(tuple): + n_fields = 2 + n_sequence_fields = 2 + n_unnamed_fields = 0 + + __slots__ = () + __match_args__ = ("x", "y") + + def __new__(cls, sequence): + return super().__new__(cls, sequence) + + @property + def x(self): + return self[0] + + @property + def y(self): + return self[1] + + DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) + + class DirectNamedTuple2(NamedTuple): + x: int + y: int + + self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1)))) + self.assertTrue(pytree.is_structseq(time.gmtime())) + self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1))) + self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1))) + self.assertFalse(pytree.is_structseq((0, 1))) + self.assertFalse(pytree.is_structseq([0, 1])) + self.assertFalse(pytree.is_structseq({0: 1, 1: 2})) + self.assertFalse(pytree.is_structseq({0, 1})) + self.assertFalse(pytree.is_structseq(1)) + + self.assertFalse(pytree.is_structseq(FakeStructSeq)) + self.assertTrue(pytree.is_structseq(time.struct_time)) + self.assertFalse(pytree.is_structseq(DirectNamedTuple1)) + self.assertFalse(pytree.is_structseq(DirectNamedTuple2)) + self.assertFalse(pytree.is_structseq(tuple)) + self.assertFalse(pytree.is_structseq(list)) + + self.assertFalse(pytree.is_structseq_class(FakeStructSeq)) + self.assertTrue( + pytree.is_structseq_class(time.struct_time), + ) + self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1)) + self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2)) + self.assertFalse(pytree.is_structseq_class(tuple)) + self.assertFalse(pytree.is_structseq_class(list)) + + # torch.return_types.* are all PyStructSequence types + for cls in vars(torch.return_types).values(): + if isinstance(cls, type) and issubclass(cls, tuple): + self.assertTrue(pytree.is_structseq(cls)) + self.assertTrue(pytree.is_structseq_class(cls)) + self.assertFalse(pytree.is_namedtuple(cls)) + self.assertFalse(pytree.is_namedtuple_class(cls)) + + inst = cls(range(cls.n_sequence_fields)) + self.assertTrue(pytree.is_structseq(inst)) + self.assertTrue(pytree.is_structseq(type(inst))) + self.assertFalse(pytree.is_structseq_class(inst)) + self.assertTrue(pytree.is_structseq_class(type(inst))) + self.assertFalse(pytree.is_namedtuple(inst)) + self.assertFalse(pytree.is_namedtuple_class(inst)) + else: + self.assertFalse(pytree.is_structseq(cls)) + self.assertFalse(pytree.is_structseq_class(cls)) + self.assertFalse(pytree.is_namedtuple(cls)) + self.assertFalse(pytree.is_namedtuple_class(cls)) + class TestPythonPytree(TestCase): def test_deprecated_register_pytree_node(self): @@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self): serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", ) - spec = py_pytree.TreeSpec( - namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(Point1(1, 2)) + self.assertIs(spec.type, namedtuple) roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) self.assertEqual(spec, roundtrip_spec) @@ -990,18 +1117,28 @@ class Point2(NamedTuple): serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", ) - spec = py_pytree.TreeSpec( - namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] + spec = py_pytree.tree_structure(Point2(1, 2)) + self.assertIs(spec.type, namedtuple) + roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) + self.assertEqual(spec, roundtrip_spec) + + class Point3(Point2): + pass + + py_pytree._register_namedtuple( + Point3, + serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3", ) + + spec = py_pytree.tree_structure(Point3(1, 2)) + self.assertIs(spec.type, namedtuple) roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) self.assertEqual(spec, roundtrip_spec) def test_pytree_serialize_namedtuple_bad(self): DummyType = namedtuple("DummyType", ["x", "y"]) - spec = py_pytree.TreeSpec( - namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "Please register using `_register_namedtuple`" @@ -1020,9 +1157,7 @@ def __init__(self, x, y): lambda xs, _: DummyType(*xs), ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "No registered serialization name" ): @@ -1042,9 +1177,7 @@ def __init__(self, x, y): to_dumpable_context=lambda context: "moo", from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) serialized_spec = py_pytree.treespec_dumps(spec, 1) self.assertIn("moo", serialized_spec) roundtrip_spec = py_pytree.treespec_loads(serialized_spec) @@ -1082,9 +1215,7 @@ def __init__(self, x, y): from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( TypeError, "Object of type type is not JSON serializable" @@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self): import json Point = namedtuple("Point", ["x", "y"]) - spec = py_pytree.TreeSpec( - namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(Point(1, 2)) py_pytree._register_namedtuple( Point, serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", @@ -1168,16 +1297,55 @@ def test_tree_map_with_path(self): def test_dataclass(self): @dataclass - class Point: - x: torch.Tensor - y: torch.Tensor + class Data: + a: torch.Tensor + b: str = "moo" + c: Optional[str] = None + d: str = field(init=False, default="") + + py_pytree.register_dataclass(Data) + old_data = Data(torch.tensor(3), "b", "c") + old_data.d = "d" + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + self.assertEqual(new_data.a, torch.tensor(3)) + self.assertEqual(new_data.b, "b") + self.assertEqual(new_data.c, "c") + self.assertEqual(new_data.d, "") + py_pytree._deregister_pytree_node(Data) + + with self.assertRaisesRegex(ValueError, "Missing fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "e"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"]) + + py_pytree.register_dataclass( + Data, field_names=["a"], drop_field_names=["b", "c"] + ) + old_data = Data(torch.tensor(3), "b", "c") + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + self.assertEqual(new_data.a, torch.tensor(3)) + self.assertEqual(new_data.b, "moo") + self.assertEqual(new_data.c, None) + py_pytree._deregister_pytree_node(Data) + + def test_register_dataclass_class(self): + class CustomClass: + def __init__(self, x, y): + self.x = x + self.y = y - py_pytree.register_dataclass(Point) + with self.assertRaisesRegex(ValueError, "field_names must be specified"): + py_pytree.register_dataclass(CustomClass) - point = Point(torch.tensor(0), torch.tensor(1)) - point = py_pytree.tree_map(lambda x: x + 1, point) - self.assertEqual(point.x, torch.tensor(1)) - self.assertEqual(point.y, torch.tensor(2)) + py_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) + c = CustomClass(torch.tensor(0), torch.tensor(1)) + mapped = py_pytree.tree_map(lambda x: x + 1, c) + self.assertEqual(mapped.x, torch.tensor(1)) + self.assertEqual(mapped.y, torch.tensor(2)) def test_constant(self): # Either use `frozen=True` or `unsafe_hash=True` so we have a diff --git a/test/test_transformers.py b/test/test_transformers.py index 3a22d382d3c5..42950a84f154 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2479,7 +2479,8 @@ def test_cudnn_attention_gqa(self, device): # Sample call to SDPA - GQ query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16) key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) - value = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) + # cuDNN supports h_k != h_v + value = torch.rand(batch, 4, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) with sdpa_kernel([SDPBackend.MATH]): output_math = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True) diff --git a/test/test_xpu.py b/test/test_xpu.py index 4208bf6daa5e..1647ad24a75a 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -136,6 +136,7 @@ def test_get_device_properties(self): device_capability["architecture"], ) + @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)") def test_wrong_xpu_fork(self): stderr = TestCase.runWithPytorchAPIUsageStderr( """\ @@ -192,9 +193,11 @@ def test_multi_process(model, input): torch.nn.ReLU(), torch.nn.MaxPool2d(2, 2), ) -test_multi_process(model, input) -test_multi_process(model, input) -print(torch.xpu.device_count()) + +if __name__ == "__main__": + test_multi_process(model, input) + test_multi_process(model, input) + print(torch.xpu.device_count()) """ rc = check_output(test_script) self.assertEqual(rc, str(torch.xpu.device_count())) diff --git a/test/torch_np/numpy_tests/core/test_dtype.py b/test/torch_np/numpy_tests/core/test_dtype.py index aeb9710832f9..d548f49b4cc4 100644 --- a/test/torch_np/numpy_tests/core/test_dtype.py +++ b/test/torch_np/numpy_tests/core/test_dtype.py @@ -3,7 +3,6 @@ import functools import operator import pickle -import sys import types from itertools import permutations from typing import Any @@ -325,11 +324,6 @@ def test_keyword_argument(self): # test for https://github.com/numpy/numpy/pull/16574#issuecomment-642660971 assert np.dtype(dtype=np.float64) == np.dtype(np.float64) - @skipif(sys.version_info >= (3, 9), reason="Requires python 3.9") - def test_class_getitem_38(self) -> None: - with pytest.raises(TypeError): - np.dtype[Any] - class TestFromDTypeAttribute(TestCase): def test_simple(self): diff --git a/test/torch_np/numpy_tests/core/test_scalar_methods.py b/test/torch_np/numpy_tests/core/test_scalar_methods.py index e1e92de7d6c6..36ac89a02c29 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_methods.py +++ b/test/torch_np/numpy_tests/core/test_scalar_methods.py @@ -5,7 +5,6 @@ """ import fractions import functools -import sys import types from typing import Any from unittest import skipIf as skipif, SkipTest @@ -222,15 +221,6 @@ def test_subscript_scalar(self) -> None: assert np.number[Any] -@instantiate_parametrized_tests -class TestClassGetitemMisc(TestCase): - @skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") - @parametrize("cls", [np.number, np.complexfloating, np.int64]) - def test_class_getitem_38(self, cls: type[np.number]) -> None: - with pytest.raises(TypeError): - cls[Any] - - @skip(reason="scalartype(...).bit_count() not implemented") @instantiate_parametrized_tests class TestBitCount(TestCase): diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index cf3d68add29e..138729261652 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -15,7 +15,12 @@ instantiate_device_type_tests, precisionOverride, ) -from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase +from torch.testing._internal.common_utils import ( + iter_indices, + parametrize, + run_tests, + TestCase, +) class TestBasicGEMM(TestCase): @@ -1119,6 +1124,84 @@ def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): with torch.no_grad(): torch.matmul(a, b, out=c) + def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16): + # w [k, n] = [32, 48] + assert w.dim() == 2 + # w [n, k] = [48, 32] + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + # to_quant: [n * k / group_size, group_size] + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_int - min_val.div(scales).round() + zeros = torch.clamp(zeros, min_int, max_int) + zeros = zeros.to(torch.int8) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + # [n, k] + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device != torch.device("cpu"): + out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous() + zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous() + + return out, scales, zeros + + @parametrize("m", [128]) + @parametrize("k", [512, 1024]) + @parametrize("n", [512, 1024]) + def test__int4_mm(self, device, m, k, n): + q_group = 32 + inner_k_tiles = 2 + + torch.manual_seed(1) + a_bf16 = torch.rand((m, k), dtype=torch.float32, device=device) + b_bf16 = torch.rand((k, n), dtype=torch.float32, device=device) + + def convert_weight_to_int4pack(b): + # b_uint8 [n, k //2] + b_uint8, scales, zeros = self._group_quantize_tensor( + b, n_bit=4, q_group_size=q_group + ) + # b_int4pack [k//8, n] + b_int4pack = torch._convert_weight_to_int4pack(b_uint8, inner_k_tiles) + + return b_int4pack, scales, zeros + + def weight_int4pack_mm(a, b_int4pack, qscale, qzeros): + return torch._weight_int4pack_mm_with_scales_and_zeros( + a, b_int4pack, q_group, qscale, qzeros + ) + + b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b_bf16) + + for dtype in [torch.bfloat16, torch.float16]: + a = a_bf16.to(dtype=dtype) + b = b_bf16.to(dtype=dtype) + b_scales = b_scales.to(dtype=dtype) + ref = torch.mm(a, b) + + res = weight_int4pack_mm(a, b_int4pack, b_scales, zeros_int8) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) diff --git a/third_party/gloo b/third_party/gloo index 95ca2af4e4c7..c61070427610 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 95ca2af4e4c76433fac8911525d8a0142b7a5289 +Subproject commit c61070427610ccd923efe3e7f8b3eca12bbcc31a diff --git a/third_party/kineto b/third_party/kineto index 2859721fd9e7..d6796921fdde 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 2859721fd9e73d3ca1c56f827dbc64e6d68f78a2 +Subproject commit d6796921fdde135cb94d2dd04fe2071a5424a321 diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b20a7be4ed1a..231384bd859a 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -2249,6 +2249,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_deps = [ ":subgraph", ], + compiler_flags = select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_NO_Q8_OPERATORS", diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 239a4b8aeb93..5bdc7353dfe7 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -026b2c8c7c92a7b2cec5d26334006e3423251cc6 +98c808dea6de7330c415aa777d6921944cf79887 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 25749372a1f3..6a42e26d7618 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1206,6 +1206,10 @@ mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) +- name: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor + self: _grouped_mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), offs, 1) + mat2: _grouped_mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), offs, 1) + - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index bb61ac3e8216..d239ab1d43a0 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -16,8 +16,7 @@ Database, EntryState, Group, - MatchInfo, - MatchState, + MatchStateRecord, Membership, NCCLCall, Op, @@ -25,15 +24,14 @@ ) from tools.flight_recorder.components.utils import ( align_trace_from_beginning, + check_current_entry_match, check_no_missing_dump_files, - check_size_alltoall, check_version, + error_analysis, find_coalesced_group, - format_frames, get_version_detail, just_print_entries, match_coalesced_groups, - match_one_event, ) @@ -161,7 +159,6 @@ def build_collectives( ] } """ - major_v, minor_v = get_version_detail(version) tracebacks: list[Traceback] = [] collectives: list[Collective] = [] @@ -194,17 +191,23 @@ def build_collectives( # lets match the first collective! we need to know which ranks are involved, and ensure that this same # collective is also the first one on those ranks within that group entries = all_entries[first_rank] - desc = entries[0]["process_group"][1] + current_entry = entries[0] + desc = current_entry["process_group"][1] # For db build and logs printing, we want to use the original pg_name, not the hash one. - original_pg_name = entries[0]["process_group"][0] + original_pg_name = current_entry["process_group"][0] pg_name = _pg_guids[(original_pg_name, first_rank)] expected_ranks = set(_memberships[pg_name]) - entry_state = EntryState(entries[0], expected_ranks) - candidate_ranks = {first_rank} - candidate_idx = {} - found_ranks = set() - found_idx = {} - errors = set() + entry_state = EntryState(current_entry, expected_ranks) + match_record = MatchStateRecord( + expected_ranks=expected_ranks, + other_ranks=other_ranks, + entry_state=entry_state, + candidate_ranks={first_rank}, + candidate_idx={}, + found_ranks=set(), + found_idx={}, + errors=set(), + ) if find_coalesced_group(pg_name, entries, _pg_guids, first_rank): expected_ranks.add(first_rank) @@ -256,137 +259,42 @@ def build_collectives( ) ) else: - has_undecided_case = False - for o in expected_ranks.intersection(set(other_ranks)): - for i, e in enumerate(all_entries[o]): # type: ignore[index] - # step over ops from other PGs - # only check match state when seq_id matches - if ( - _pg_guids[(e["process_group"][0], o)] == pg_name - and e["process_group"][1] == desc - and e["collective_seq_id"] == entry_state.collective_seq_id - ): - match_info = match_one_event( - entries[0], e, _memberships, pg_name - ) - if ( - match_info.state - in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] - and mismatch[pg_name] == 0 - ): - found_ranks.add(o) - found_idx[o] = i - has_undecided_case = ( - match_info.state == MatchState.UNDECIDED - ) - else: - candidate_ranks.add(o) - candidate_idx[o] = i - if match_info.state not in [ - MatchState.FULLY_MATCHED, - MatchState.UNDECIDED, - ]: - # Here we assume the current rank is not the source of the error. - # But it's possible that the current rank is the culprit, then users will - # see lots of normal ranks reported as culprit. - # TODO: we need to figure out a better way to handle the case mentioned above. - errors.add((o, match_info)) - break - - # case one: not every rank join the collective or in the flight recorder. - if (candidate_ranks | found_ranks) != expected_ranks and expected_ranks - ( - candidate_ranks | found_ranks - ) <= dumps_ranks: - mismatch[pg_name] += 1 - logger_msg = "Not all ranks joining collective, sequence number: %s" - missing_ranks = expected_ranks - (candidate_ranks | found_ranks) - entry_state.log( - logger, logger_msg, format_frames, missing_ranks=missing_ranks - ) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - elif len(candidate_ranks) == 1 and dumps_ranks == expected_ranks: - # case two: alltoall or alltoall_base case. - if has_undecided_case: - alltoall_cases = [entries[0]] + [ - all_entries[o][found_idx[o]] for o in found_ranks - ] - fail_check, total_input_numel, total_output_numel = ( - check_size_alltoall(alltoall_cases) - ) - if major_v <= 2 and minor_v <= 3: - # We don't log the input/output sizes for alltoall before v2.4, - # so we don't consider the size mismatch as an error for now. - fail_check = False - if fail_check: - # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. - mismatch[pg_name] += 1 - logger_msg = "Input/output mismatch in the collective sequence number: %s" - entry_state.log( - logger, - logger_msg, - format_frames, - total_numel=(total_input_numel, total_output_numel), - ) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - errors.add( - (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)) - ) - else: - found_ranks.update(candidate_ranks) - found_idx.update(candidate_idx) - candidate_idx.clear() - candidate_ranks.clear() - # case three: all joined and everything matches on all ranks. - else: - found_ranks.update(candidate_ranks) - found_idx.update(candidate_idx) - candidate_idx.clear() - candidate_ranks.clear() - # case four: mismatch cases due to not same type, size mismatch or state mismatch. - elif len(errors) > 0: - mismatch[pg_name] += 1 - logger_msg = "Collective sequence number: %s has errors" - entry_state.log(logger, logger_msg, format_frames, errors=errors) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - # partial analysis case when we cannot decide what's wrong with this collective entry. - else: - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - if expected_ranks - dumps_ranks: - mismatch[pg_name] += 1 - logger.info( - "We cannot decide what's wrong with this collective entry " - "because we missed FR dumps from ranks (%s) so we don't have enough " - "information. If you want to debug further use -j to dump all raw trace", - str(expected_ranks - dumps_ranks), - ) - else: - logger.info( - "No errors found for this collective entry, There could be some " - "other reasons why we see collective timeout." - ) + # Iterate through all the ranks and check if there is a mis-match for the current entry. + check_current_entry_match( + all_entries, + _pg_guids, + (pg_name, desc), + current_entry, + _memberships, + mismatch, + match_record, + ) + + # Use heuristics to decide what type of errors and error messages we should print. + error_analysis( + all_entries, + match_record, + dumps_ranks, + first_rank, + current_entry, + mismatch, + get_version_detail(version), + pg_name, + ) # at this point there are 3 possibilities # 1. we found a match on all the ranks that are members of the group # -> we create a Collective and remove the individual entries from their original lists - if found_ranks == expected_ranks and mismatch[pg_name] == 0: - collectives.append(entry_state.to_collective(len(collectives))) + if match_record.found_ranks == expected_ranks and mismatch[pg_name] == 0: + collectives.append( + match_record.entry_state.to_collective(len(collectives)) + ) idx_map = { - r: found_idx[r] if r != first_rank else 0 for r in found_ranks + r: match_record.found_idx[r] if r != first_rank else 0 + for r in match_record.found_ranks } nccl_calls.extend( - entry_state.to_nccl_call( + match_record.entry_state.to_nccl_call( all_entries, idx_map, len(nccl_calls), collectives[-1].id ) ) @@ -398,19 +306,19 @@ def build_collectives( else: logger.debug("appending a non-matching collective") idx_map = { - r: candidate_idx[r] if r != first_rank else 0 - for r in candidate_ranks + r: match_record.candidate_idx[r] if r != first_rank else 0 + for r in match_record.candidate_ranks } collectives.append( - entry_state.to_collective( + match_record.entry_state.to_collective( len(collectives), - errors=errors, + errors=match_record.errors, idx_map=idx_map, all_entries=all_entries, ) ) nccl_calls.extend( - entry_state.to_nccl_call( + match_record.entry_state.to_nccl_call( all_entries, idx_map, len(nccl_calls), None ) ) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index d396551f7cdf..5587e7179c77 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -224,7 +224,7 @@ def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None: self.input_sizes = entry["input_sizes"] self.output_sizes = entry["output_sizes"] self.collective_state = entry["state"] - self.collective_frames = entry["frames"] + self.collective_frames = entry.get("frames", []) self.expected_ranks = expected_ranks self.missing_ranks: set[int] self.input_numel: int @@ -316,7 +316,7 @@ def to_collective( output_sizes=entry["output_sizes"], expected_ranks=self.expected_ranks, collective_state=entry["state"], - collective_frames=entry["frames"], + collective_frames=entry.get("frames", []), type_of_mismatch=error, ) return Collective( @@ -560,3 +560,26 @@ def match(self, other: "Op") -> MatchInfo: else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) ) return MatchInfo(MatchState.FULLY_MATCHED) + + +class MatchStateRecord: + def __init__( + self, + expected_ranks: set[int], + other_ranks: list[int], + entry_state: EntryState, + candidate_ranks: set[int], + candidate_idx: dict[int, int], + found_ranks: set[int], + found_idx: dict[int, int], + errors: set[tuple[int, MatchInfo]], + ) -> None: + self.expected_ranks = expected_ranks + self.other_ranks = other_ranks + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = candidate_idx + self.found_ranks = found_ranks + self.found_idx = found_idx + self.errors = errors + self.has_undecided_case = False diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 02787d3e43c6..0973ec1c17bb 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -13,6 +13,7 @@ Group, MatchInfo, MatchState, + MatchStateRecord, Membership, Op, P2P, @@ -184,6 +185,158 @@ def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int return input_numel != output_numel, input_numel, output_numel +def check_current_entry_match( + all_entries: dict[int, list[dict[str, Any]]], + _pg_guids: dict[tuple[str, int], str], + pg_info: tuple[str, str], + current_entry: dict[str, Any], + _memberships: dict[str, set[Any]], + mismatch: dict[str, int], + match_record: MatchStateRecord, +) -> None: + pg_name, desc = pg_info[0], pg_info[1] + for o in match_record.expected_ranks.intersection(set(match_record.other_ranks)): + for i, e in enumerate(all_entries[o]): # type: ignore[index] + # step over ops from other PGs + # only check match state when seq_id matches + if ( + _pg_guids[(e["process_group"][0], o)] == pg_name + and e["process_group"][1] == desc + and e["collective_seq_id"] == match_record.entry_state.collective_seq_id + ): + match_info = match_one_event(current_entry, e, _memberships, pg_name) + if ( + match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] + and mismatch[pg_name] == 0 + ): + match_record.found_ranks.add(o) + match_record.found_idx[o] = i + match_record.has_undecided_case = ( + match_info.state == MatchState.UNDECIDED + ) + else: + match_record.candidate_ranks.add(o) + match_record.candidate_idx[o] = i + if match_info.state not in [ + MatchState.FULLY_MATCHED, + MatchState.UNDECIDED, + ]: + # Here we assume the current rank is not the source of the error. + # But it's possible that the current rank is the culprit, then users will + # see lots of normal ranks reported as culprit. + # TODO: we need to figure out a better way to handle the case mentioned above. + match_record.errors.add((o, match_info)) + break + + +def error_analysis( + all_entries: dict[int, list[dict[str, Any]]], + match_record: MatchStateRecord, + dumps_ranks: set[int], + first_rank: int, + current_entry: dict[str, Any], + mismatch: dict[str, int], + version: tuple[int, int], + pg_name: str, +) -> None: + major_v, minor_v = version[0], version[1] + # case one: not every rank join the collective or in the flight recorder. + if ( + match_record.candidate_ranks | match_record.found_ranks + ) != match_record.expected_ranks and match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) <= dumps_ranks: + mismatch[pg_name] += 1 + logger_msg = "Not all ranks joining collective, sequence number: %s" + missing_ranks = match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) + match_record.entry_state.log( + logger, logger_msg, format_frames, missing_ranks=missing_ranks + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + elif ( + len(match_record.candidate_ranks) == 1 + and dumps_ranks == match_record.expected_ranks + ): + # case two: alltoall or alltoall_base case. + if match_record.has_undecided_case: + alltoall_cases = [current_entry] + [ + all_entries[o][match_record.found_idx[o]] + for o in match_record.found_ranks + ] + fail_check, total_input_numel, total_output_numel = check_size_alltoall( + alltoall_cases + ) + if major_v <= 2 and minor_v <= 3: + # We don't log the input/output sizes for alltoall before v2.4, + # so we don't consider the size mismatch as an error for now. + fail_check = False + if fail_check: + # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. + mismatch[pg_name] += 1 + logger_msg = ( + "Input/output mismatch in the collective sequence number: %s" + ) + match_record.entry_state.log( + logger, + logger_msg, + format_frames, + total_numel=(total_input_numel, total_output_numel), + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + match_record.errors.add( + (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)) + ) + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case three: all joined and everything matches on all ranks. + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case four: mismatch cases due to not same type, size mismatch or state mismatch. + elif len(match_record.errors) > 0: + mismatch[pg_name] += 1 + logger_msg = "Collective sequence number: %s has errors" + match_record.entry_state.log( + logger, logger_msg, format_frames, errors=match_record.errors + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + # partial analysis case when we cannot decide what's wrong with this collective entry. + else: + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + if match_record.expected_ranks - dumps_ranks: + mismatch[pg_name] += 1 + logger.info( + "We cannot decide what's wrong with this collective entry " + "because we missed FR dumps from ranks (%s) so we don't have enough " + "information. If you want to debug further use -j to dump all raw trace", + str(match_record.expected_ranks - dumps_ranks), + ) + else: + logger.info( + "No errors found for this collective entry, There could be some " + "other reasons why we see collective timeout." + ) + + def find_coalesced_group( pg_name: str, entries: list[dict[str, Any]], diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index a33ea171edbb..a10d87faf938 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -97,7 +97,9 @@ def get_torch_version(sha: str | None = None) -> str: with open(version_path, "w") as f: f.write("from typing import Optional\n\n") - f.write("__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip']\n") + f.write( + "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n" + ) f.write(f"__version__ = '{version}'\n") # NB: This is not 100% accurate, because you could have built the # library code with DEBUG, but csrc without DEBUG (in which case diff --git a/tools/linter/adapters/docstring_linter.py b/tools/linter/adapters/docstring_linter.py index cb9b4ebe9881..cd67243d1ac9 100644 --- a/tools/linter/adapters/docstring_linter.py +++ b/tools/linter/adapters/docstring_linter.py @@ -1,16 +1,21 @@ from __future__ import annotations +import dataclasses as dc +import itertools +import json import sys import token -from functools import cached_property +from enum import Enum +from functools import cached_property, total_ordering from pathlib import Path -from typing import TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Self -_PARENT = Path(__file__).parent.absolute() +_FILE = Path(__file__).absolute() _PATH = [Path(p).absolute() for p in sys.path] -if TYPE_CHECKING or _PARENT not in _PATH: +if TYPE_CHECKING or _FILE.parent not in _PATH: from . import _linter else: import _linter @@ -20,149 +25,485 @@ from tokenize import TokenInfo +GRANDFATHER_LIST = Path(str(_FILE).replace(".py", "-grandfather.json")) + +# We tolerate a 10% increase in block size before demanding a docstring +TOLERANCE_PERCENT = 10 + MAX_LINES = {"class": 100, "def": 80} -MIN_DOCSTRING = 16 # docstrings shorter than this are ignored -IGNORE_PROTECTED = True # If True, ignore classes and files whose names start with _. +MIN_DOCSTRING = 50 # docstrings shorter than this are too short ERROR_FMT = "Every {type} with more than {length} lines needs a docstring" DESCRIPTION = """`docstring_linter` reports on long functions, methods or classes without docstrings""" -# How many top violations to report? -REPORT_TOP_RESULTS = 3 +@total_ordering +@dc.dataclass +class Block: + """A block of Python code starting with either `def` or `class`""" + + class Category(str, Enum): + CLASS = "class" + DEF = "def" + + category: Category + + # The sequence of tokens that contains this Block. + # Tokens are represented in `Block` as indexes into `self.tokens` + tokens: Sequence[TokenInfo] = dc.field(repr=False) + + # The name of the function or class being defined + name: str + + # The index of the very first token in the block (the "class" or "def" keyword) + begin: int + + # The index of the first INDENT token for this block + indent: int + + # The index of the DEDENT token for this end of this block + dedent: int + + # The docstring for the block + docstring: str + + # These next members only get filled in after all blocks have been constructed + # and figure out family ties + + # The full qualified name of the block within the file. + # This is the name of this block and all its parents, joined with `.`. + full_name: str = "" + + # The index of this block within the full list of blocks in the file + index: int = 0 + + # Is this block contained within a function definition? + is_local: bool = dc.field(default=False, repr=False) + + # Is this block a function definition in a class definition? + is_method: bool = dc.field(default=False, repr=False) + + # A block index to the parent of this block, or None for a top-level block. + parent: int | None = None + + # A list of block indexes for the children + children: list[int] = dc.field(default_factory=list) + + @property + def start_line(self) -> int: + return self.tokens[max(self.indent, self.index)].start[0] + + @property + def end_line(self) -> int: + return self.tokens[max(self.dedent, self.index)].start[0] + + @property + def line_count(self) -> int: + return self.end_line - self.start_line + + @property + def is_class(self) -> bool: + return self.category == Block.Category.CLASS + + @property + def display_name(self) -> str: + """A user-friendly name like 'class One' or 'def One.method()'""" + ending = "" if self.is_class else "()" + return f"{self.category.value} {self.full_name}{ending}" + + DATA_FIELDS = ( + "category", + "children", + "display_name", + "docstring", + "full_name", + "index", + "is_local", + "is_method", + "line_count", + "parent", + "start_line", + ) + + def as_data(self) -> dict[str, Any]: + d = {i: getattr(self, i) for i in self.DATA_FIELDS} + d["category"] = d["category"].value + return d + + @property + def is_init(self) -> bool: + return not self.is_class and self.name == "__init__" + + def contains(self, b: Block) -> bool: + return self.start_line < b.start_line and self.end_line >= b.end_line + + def __eq__(self, o: object) -> bool: + assert isinstance(o, Block) + return o.tokens is self.tokens and o.index == self.index + + def __hash__(self) -> int: + return super().__hash__() + + def __lt__(self, o: Self) -> bool: + assert isinstance(o, Block) and o.tokens is self.tokens + return o.index < self.index + + +class DocstringFile(_linter.PythonFile): + def __getitem__(self, i: int | slice) -> TokenInfo | Sequence[TokenInfo]: + return self.tokens[i] + + def next_token(self, start: int, token_type: int, error: str) -> int: + for i in range(start, len(self.tokens)): + if self.tokens[i].type == token_type: + return i + raise _linter.ParseError(self.tokens[-1], error) + + def docstring(self, start: int) -> str: + for i in range(start + 1, len(self.tokens)): + tk = self.tokens[i] + if tk.type == token.STRING: + return tk.string + if tk.type not in _linter.EMPTY_TOKENS: + return "" + return "" -def _is_def(t: TokenInfo) -> bool: - return t.type == token.NAME and t.string in ("class", "def") + @cached_property + def indent_to_dedent(self) -> dict[int, int]: + dedents = dict[int, int]() + stack = list[int]() + + for i, t in enumerate(self.tokens): + if t.type == token.INDENT: + stack.append(i) + elif t.type == token.DEDENT: + dedents[stack.pop()] = i + + return dedents + + @cached_property + def errors(self) -> dict[str, str]: + return {} + @cached_property + def blocks(self) -> list[Block]: + blocks: list[Block] = [] + + for i in range(len(self.tokens)): + try: + if (b := self.block(i)) is not None: + blocks.append(b) + except _linter.ParseError as e: + self.errors[e.token.line] = " ".join(e.args) + + for i, parent in enumerate(blocks): + for j in range(i + 1, len(blocks)): + if parent.contains(child := blocks[j]): + child.parent = i + parent.children.append(j) + else: + break -class DocstringLinter(_linter.FileLinter[_linter.PythonFile]): + for i, b in enumerate(blocks): + b.index = i + + parents = [b] + while (p := parents[-1].parent) is not None: + parents.append(blocks[p]) + parents = parents[1:] + + b.is_local = not all(p.is_class for p in parents) + b.is_method = not b.is_class and bool(parents) and parents[0].is_class + + def add_full_names(children: Sequence[Block], prefix: str = "") -> None: + dupes: dict[str, list[Block]] = {} + for b in children: + dupes.setdefault(b.name, []).append(b) + + for dl in dupes.values(): + for i, b in enumerate(dl): + suffix = f"[{i + 1}]" if len(dl) > 1 else "" + b.full_name = prefix + b.name + suffix + + for b in children: + if kids := [blocks[i] for i in b.children]: + add_full_names(kids, b.full_name + ".") + + add_full_names([b for b in blocks if b.parent is None]) + return blocks + + def block(self, begin: int) -> Block | None: + t = self.tokens[begin] + if not (t.type == token.NAME and t.string in ("class", "def")): + return None + + category = Block.Category[t.string.upper()] + try: + ni = self.next_token(begin + 1, token.NAME, "Definition but no name") + name = self.tokens[ni].string + indent = self.next_token(ni + 1, token.INDENT, "Definition but no indent") + dedent = self.indent_to_dedent[indent] + docstring = self.docstring(indent) + except _linter.ParseError: + name = "(ParseError)" + indent = -1 + dedent = -1 + docstring = "" + + return Block( + begin=begin, + category=category, + dedent=dedent, + docstring=docstring, + indent=indent, + name=name, + tokens=self.tokens, + ) + + +class DocstringLinter(_linter.FileLinter[DocstringFile]): linter_name = "docstring_linter" description = DESCRIPTION is_fixer = False - results: dict[str, list[tuple[int, Path, str]]] - def __init__(self, argv: list[str] | None = None) -> None: + path_to_blocks: dict[str, list[dict[str, Any]]] + path_to_errors: dict[str, list[dict[str, Any]]] + + def __init__(self, argv: Sequence[str] | None = None) -> None: super().__init__(argv) - self.results = {} + add_arguments(self.parser.add_argument) + self.path_to_blocks = {} + self.path_to_errors = {} - help = "Maximum number of lines for an undocumented class" - self.parser.add_argument( - "--max-class", "-c", default=MAX_LINES["class"], type=int, help=help - ) + def lint_all(self) -> bool: + success = super().lint_all() + self._report() + self._write_grandfather() + return success - help = "Maximum number of lines for an undocumented function" - self.parser.add_argument( - "--max-def", "-d", default=MAX_LINES["def"], type=int, help=help - ) + def _lint(self, df: DocstringFile) -> Iterator[_linter.LintResult]: + if (p := str(df.path)) in self.path_to_blocks: + print("Repeated file", p, file=sys.stderr) + return - help = "Minimum number of characters for a docstring" - self.parser.add_argument( - "--min-docstring", "-m", default=MIN_DOCSTRING, type=int, help=help - ) + blocks = df.blocks + bad = {b for b in blocks if self._is_bad_block(b, df)} + bad = self._dont_require_constructor_and_class_docs(blocks, bad) + gf = self._grandfathered(df.path, bad) - help = "Lint functions, methods and classes that start with _" - self.parser.add_argument( - "--lint-protected", "-p", action="store_true", help=help - ) + yield from (self._block_result(b, df) for b in sorted(bad - gf)) + + def as_data(b: Block) -> dict[str, Any]: + status = "grandfather" if b in gf else "bad" if b in bad else "good" + return {"status": status, **b.as_data()} + + self.path_to_blocks[p] = [as_data(b) for b in blocks] + + def _error(self, df: DocstringFile, result: _linter.LintResult) -> None: + self.path_to_errors[str(df.path)] = [{str(result.line): result.name}] @cached_property - def max_lines(self) -> dict[str, int]: - return {"class": self.args.max_class, "def": self.args.max_def} + def _grandfather(self) -> dict[str, dict[str, Any]]: + try: + with open(self.args.grandfather) as fp: + return json.load(fp) # type: ignore[no-any-return] + except FileNotFoundError: + return {} + except Exception as e: + print("ERROR:", e, "in", GRANDFATHER_LIST, file=sys.stderr) + raise - def lint_all(self) -> bool: - success = super().lint_all() - if not self.args.lintrunner and self.results: - self._report_results() - return success + @cached_property + def _max_lines(self) -> dict[str, int]: + return {"class": self.args.max_class, "def": self.args.max_def} - def _lint(self, pf: _linter.PythonFile) -> Iterator[_linter.LintResult]: - tokens = pf.tokens - indents = indent_to_dedent(tokens) - defs = [i for i, t in enumerate(tokens) if _is_def(t)] - - def next_token(start: int, token_type: int, error: str) -> int: # type: ignore[return] - for i in range(start, len(tokens)): - if tokens[i].type == token_type: - return i - raise _linter.ParseError(tokens[-1], error) - - for i in defs: - name = next_token(i + 1, token.NAME, "Definition with no name") - if not self.args.lint_protected and tokens[name].string.startswith("_"): - continue - - indent = next_token(name + 1, token.INDENT, "Definition with no indent") - dedent = indents[indent] - - lines = tokens[dedent].start[0] - tokens[indent].start[0] - max_lines = self.max_lines[tokens[i].string] - if lines <= max_lines: - continue - - # Now search for a docstring - docstring_len = -1 - for k in range(indent + 1, len(tokens)): - tk = tokens[k] - if tk.type == token.STRING: - docstring_len = len(tk.string) - break - if tk.type not in _linter.EMPTY_TOKENS: - break + def _grandfathered(self, path: Path | None, bad: set[Block]) -> set[Block]: + if path is None or self.args.no_grandfather or self.args.write_grandfather: + return set() + + grand: dict[str, int] = self._grandfather.get(str(path), {}) + tolerance_ratio = 1 + self.args.grandfather_tolerance / 100.0 + + def grandfathered(b: Block) -> bool: + lines = int(grand.get(b.display_name, 0) * tolerance_ratio) + return b.line_count <= lines + + return {b for b in bad if grandfathered(b)} + + def _block_result(self, b: Block, df: DocstringFile) -> _linter.LintResult: + def_name = "function" if b.category == "def" else "class" + msg = f"docstring found for {def_name} '{b.name}' ({b.line_count} lines)" + if len(b.docstring): + msg = msg + f" was too short ({len(b.docstring)} characters)" + else: + msg = "No " + msg + return _linter.LintResult(msg, *df.tokens[b.begin].start) + + def _display( + self, df: DocstringFile, results: list[_linter.LintResult] + ) -> Iterator[str]: + if not self.args.report: + yield from super()._display(df, results) + + def _dont_require_constructor_and_class_docs( + self, blocks: Sequence[Block], bad: set[Block] + ) -> set[Block]: + if self.args.lint_init: + return bad + + good = {b for b in blocks if len(b.docstring) >= self.args.min_docstring} + + def has_class_init_doc(b: Block) -> bool: + if b.is_class: + # Is it a class whose constructor is documented? + children = (blocks[i] for i in b.children) + return any(b.is_init and b in good for b in children) + + # Is it a constructor whose class is documented? + return b.is_init and b.parent is not None and blocks[b.parent] in good + + return {b for b in bad if not has_class_init_doc(b)} + + def _is_bad_block(self, b: Block, df: DocstringFile) -> bool: + max_lines = self._max_lines[b.category] + return ( + not df.omitted(df.tokens, b.begin, b.dedent) + and b.line_count > max_lines + and len(b.docstring) < self.args.min_docstring + and (self.args.lint_local or not b.is_local) + and (self.args.lint_protected or not b.name.startswith("_")) + ) - if docstring_len >= self.args.min_docstring: - continue - - # Now check if it's omitted - if pf.omitted(pf.tokens[i:indent]): - continue - - t = tokens[i] - def_name = "function" if t.string == "def" else t.string - tname = tokens[name].string - msg = f"docstring found for {def_name} '{tname}' ({lines} lines)" - if docstring_len < 0: - msg = "No " + msg - else: - msg = msg + f" was too short ({docstring_len} characters)" - yield _linter.LintResult(msg, *t.start) - if pf.path is not None: - self.results.setdefault(def_name, []).append((lines, pf.path, tname)) - - def _report_results(self) -> None: - print() - for i, (k, v) in enumerate(sorted(self.results.items())): - if i: - print() - top = sorted(v, reverse=True)[:REPORT_TOP_RESULTS] - if len(top) == 1: - s = "" - t = f"{len(top)} " - else: - s = "es" if k.endswith("s") else "s" - t = "" - print(f"Top {t}undocumented {k}{s}:") - for lines, path, tname in top: - print(f" {lines} lines: {path}:{tname}") - - -def indent_to_dedent(tokens: Sequence[TokenInfo]) -> dict[int, int]: - indent_to_dedent: dict[int, int] = {} - stack: list[int] = [] - - for i, t in enumerate(tokens): - if t.type == token.INDENT: - stack.append(i) - elif t.type == token.DEDENT: - assert stack - indent_to_dedent[stack.pop()] = i - - assert not stack - # Can't happen: the tokenization process would already have failed on a bad indent - - return indent_to_dedent + def _report(self) -> None: + if not self.args.lintrunner and self.path_to_blocks and self.args.report: + report = { + k: s for k, v in self.path_to_blocks.items() if (s := file_summary(v)) + } | self.path_to_errors + print(json.dumps(report, sort_keys=True, indent=2)) + + def _write_grandfather(self) -> None: + if self.args.write_grandfather: + results: dict[str, dict[str, int]] = {} + + for path, blocks in self.path_to_blocks.items(): + for block in blocks: + if block["status"] == "bad": + d = results.setdefault(path, {}) + d[block["display_name"]] = block["line_count"] + + with open(self.args.grandfather, "w") as fp: + json.dump(results, fp, sort_keys=True, indent=2) + + +def make_recursive(blocks: list[dict[str, Any]]) -> list[dict[str, Any]]: + def rec(i: int) -> dict[str, Any]: + d = dict(blocks[i]) + d["children"] = [rec(c) for c in d["children"]] + return d + + return [rec(i) for i, b in enumerate(blocks) if b["parent"] is None] + + +def make_terse( + blocks: Sequence[dict[str, Any]], + index_by_line: bool = True, +) -> dict[str, dict[str, Any]]: + result: dict[str, dict[str, Any]] = {} + + max_line = max(b["start_line"] for b in blocks) if blocks else 0 + line_field_width = len(str(max_line)) + + for b in blocks: + root = f"{b['category']} {b['full_name']}" + for i in itertools.count(): + name = root + bool(i) * f"[{i + 1}]" + if name not in result: + break + + d = { + "docstring_len": len(b["docstring"]), + "lines": b["line_count"], + "status": b.get("status", "good"), + } + + start_line = b["start_line"] + if index_by_line: + d["name"] = name + result[f"{start_line:>{line_field_width}}"] = d + else: + d["line"] = start_line + result[name] = d + + if kids := b["children"]: + if not all(isinstance(k, int) for k in kids): + assert all(isinstance(k, dict) for k in kids) + d["children"] = make_terse(kids) + + return result + + +def file_summary( + blocks: Sequence[dict[str, Any]], report_all: bool = False +) -> dict[str, str]: + def to_line(v: dict[str, Any]) -> str | None: + if (status := v["status"]) == "good": + if not report_all: + return None + fail = "" + elif status == "grandfather": + fail = ": (grandfathered)" + else: + assert status == "bad" + fail = ": FAIL" + name = v["name"] + lines = v["lines"] + docs = v["docstring_len"] + parens = "()" if name.startswith("def ") else "" + return f"{name}{parens}: {lines=}, {docs=}{fail}" + + t = make_terse(blocks) + r = {k: line for k, v in t.items() if (line := to_line(v))} + while r and all(k.startswith(" ") for k in r): + r = {k[1:]: v for k, v in r.items()} + return r + + +def add_arguments(add: Callable[..., Any]) -> None: + h = "Set the grandfather list" + add("--grandfather", "-g", default=str(GRANDFATHER_LIST), type=str, help=h) + + h = "Tolerance for grandfather sizes, in percent" + add("--grandfather-tolerance", "-t", default=TOLERANCE_PERCENT, type=float, help=h) + + h = "Lint __init__ and class separately" + add("--lint-init", "-i", action="store_true", help=h) + + h = "Lint definitions inside other functions" + add("--lint-local", "-o", action="store_true", help=h) + + h = "Lint functions, methods and classes that start with _" + add("--lint-protected", "-p", action="store_true", help=h) + + h = "Maximum number of lines for an undocumented class" + add("--max-class", "-c", default=MAX_LINES["class"], type=int, help=h) + + h = "Maximum number of lines for an undocumented function" + add("--max-def", "-d", default=MAX_LINES["def"], type=int, help=h) + + h = "Minimum number of characters for a docstring" + add("--min-docstring", "-s", default=MIN_DOCSTRING, type=int, help=h) + + h = "Disable the grandfather list" + add("--no-grandfather", "-n", action="store_true", help=h) + + h = "Print a report on all classes and defs" + add("--report", "-r", action="store_true", help=h) + + h = "Rewrite the grandfather list" + add("--write-grandfather", "-w", action="store_true", help=h) if __name__ == "__main__": diff --git a/tools/nightly.py b/tools/nightly.py index 45ca897cbe55..9fa1dcba9f51 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -50,6 +50,7 @@ import subprocess import sys import tempfile +import textwrap import time import uuid from ast import literal_eval @@ -340,6 +341,44 @@ def create(self, *, remove_if_exists: bool = False) -> Path: self.base_python("-m", "venv", str(self.prefix)) assert self.is_venv(), "Failed to create virtual environment." (self.prefix / ".gitignore").write_text("*\n", encoding="utf-8") + + if LINUX: + activate_script = self.activate_script + st_mode = activate_script.stat().st_mode + # The activate script may be read-only and we need to add write permissions + activate_script.chmod(st_mode | 0o200) + with activate_script.open(mode="a", encoding="utf-8") as f: + f.write( + "\n" + + textwrap.dedent( + f""" + # Add NVIDIA PyPI packages to LD_LIBRARY_PATH + export LD_LIBRARY_PATH="$( + {self.executable.name} - < Path: diff --git a/tools/test/docstring_linter_testdata/block_names.py.txt b/tools/test/docstring_linter_testdata/block_names.py.txt new file mode 100644 index 000000000000..a3a41ec9cb46 --- /dev/null +++ b/tools/test/docstring_linter_testdata/block_names.py.txt @@ -0,0 +1,44 @@ +def top(number): + if number == 0: + + def fun(): + if number == 10: + def sab(): + return 1 + else: + def sub(): + return 2 + return sub + + elif number == 1: + + def fun(): + if number == 11: + def sub(): + return 3 + else: + def sub(): + return 4 + return sub + + elif number == 2: + + def fun(): + if number == 12: + def sub(): + return 5 + else: + def sab(): + return 6 + return sub + + elif number == 3: + + def run(): + if number == 12: + def sub(): + return 5 + else: + def sub(): + return 6 + return sub diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.json b/tools/test/docstring_linter_testdata/python_code.py.txt.json index eebee3718730..5efc13550f3d 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.json @@ -32,39 +32,6 @@ "replacement": null, "severity": "error" }, - { - "char": 8, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 72, - "name": "No docstring found for function 'not_short' (11 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, - { - "char": 12, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 73, - "name": "No docstring found for class 'Long' (6 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, - { - "char": 0, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 84, - "name": "No docstring found for class 'NotDocstring' (12 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, { "char": null, "code": "DOCSTRING_LINTER", diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner index 07adffee6d84..a787cb1ecb32 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner @@ -21,36 +21,3 @@ tools/test/docstring_linter_testdata/python_code.py.txt:71: No docstring found f ^ 72 | def not_short(): 73 | class Long: - -tools/test/docstring_linter_testdata/python_code.py.txt:72: No docstring found for function 'not_short' (11 lines) - 70 | - 71 | def needs_docs(self): - 72 | def not_short(): - ^ - 73 | class Long: - 74 | a = 1 - -tools/test/docstring_linter_testdata/python_code.py.txt:73: No docstring found for class 'Long' (6 lines) - 71 | def needs_docs(self): - 72 | def not_short(): - 73 | class Long: - ^ - 74 | a = 1 - 75 | b = 1 - -tools/test/docstring_linter_testdata/python_code.py.txt:84: No docstring found for class 'NotDocstring' (12 lines) - 82 | - 83 | - 84 | class NotDocstring: - ^ - 85 | def short1(self): - 86 | pass - -Top undocumented classes: - 12 lines: tools/test/docstring_linter_testdata/python_code.py.txt:NotDocstring - 6 lines: tools/test/docstring_linter_testdata/python_code.py.txt:LongWithShortDocstring - 6 lines: tools/test/docstring_linter_testdata/python_code.py.txt:Long - -Top undocumented functions: - 12 lines: tools/test/docstring_linter_testdata/python_code.py.txt:needs_docs - 11 lines: tools/test/docstring_linter_testdata/python_code.py.txt:not_short diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.report.json b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json new file mode 100644 index 000000000000..2ccc6f05703d --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json @@ -0,0 +1,325 @@ +[ + { + "category": "class", + "children": [], + "display_name": "class ShortWithDocstring", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "ShortWithDocstring", + "index": 0, + "is_local": false, + "is_method": false, + "line_count": 4, + "parent": null, + "start_line": 2 + }, + { + "category": "class", + "children": [], + "display_name": "class Short", + "docstring": "", + "full_name": "Short", + "index": 1, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 7 + }, + { + "category": "class", + "children": [ + 3 + ], + "display_name": "class LongWithDocstring", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "LongWithDocstring", + "index": 2, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 11 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithDocstring.short1()", + "docstring": "", + "full_name": "LongWithDocstring.short1", + "index": 3, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 2, + "start_line": 14 + }, + { + "category": "class", + "children": [ + 5 + ], + "display_name": "class LongWithoutDocstring", + "docstring": "", + "full_name": "LongWithoutDocstring", + "index": 4, + "is_local": false, + "is_method": false, + "line_count": 4, + "parent": null, + "start_line": 20 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithoutDocstring.short1()", + "docstring": "", + "full_name": "LongWithoutDocstring.short1", + "index": 5, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 4, + "start_line": 21 + }, + { + "category": "class", + "children": [ + 7 + ], + "display_name": "class LongWithShortDocstring", + "docstring": "\"\"\"TODO\"\"\"", + "full_name": "LongWithShortDocstring", + "index": 6, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 25 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithShortDocstring.short1()", + "docstring": "", + "full_name": "LongWithShortDocstring.short1", + "index": 7, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 6, + "start_line": 28 + }, + { + "category": "class", + "children": [ + 9 + ], + "display_name": "class _Protected", + "docstring": "\"\"\"TODO\"\"\"", + "full_name": "_Protected", + "index": 8, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 32 + }, + { + "category": "def", + "children": [], + "display_name": "def _Protected.short1()", + "docstring": "", + "full_name": "_Protected.short1", + "index": 9, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 8, + "start_line": 35 + }, + { + "category": "def", + "children": [], + "display_name": "def short()", + "docstring": "", + "full_name": "short", + "index": 10, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 42 + }, + { + "category": "def", + "children": [], + "display_name": "def long()", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "long", + "index": 11, + "is_local": false, + "is_method": false, + "line_count": 8, + "parent": null, + "start_line": 46 + }, + { + "category": "def", + "children": [], + "display_name": "def long_without_docstring()", + "docstring": "", + "full_name": "long_without_docstring", + "index": 12, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 59 + }, + { + "category": "class", + "children": [ + 14, + 15, + 16, + 17 + ], + "display_name": "class ImpossibleCombo", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "ImpossibleCombo", + "index": 13, + "is_local": false, + "is_method": false, + "line_count": 15, + "parent": null, + "start_line": 69 + }, + { + "category": "def", + "children": [ + 15, + 16, + 17 + ], + "display_name": "def ImpossibleCombo.needs_docs()", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs", + "index": 14, + "is_local": false, + "is_method": true, + "line_count": 12, + "parent": 13, + "start_line": 72 + }, + { + "category": "def", + "children": [ + 16, + 17 + ], + "display_name": "def ImpossibleCombo.needs_docs.not_short()", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short", + "index": 15, + "is_local": true, + "is_method": false, + "line_count": 11, + "parent": 14, + "start_line": 73 + }, + { + "category": "class", + "children": [], + "display_name": "class ImpossibleCombo.needs_docs.not_short.Long", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short.Long", + "index": 16, + "is_local": true, + "is_method": false, + "line_count": 6, + "parent": 15, + "start_line": 74 + }, + { + "category": "class", + "children": [], + "display_name": "class ImpossibleCombo.needs_docs.not_short.Short", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short.Short", + "index": 17, + "is_local": true, + "is_method": false, + "line_count": 3, + "parent": 15, + "start_line": 81 + }, + { + "category": "class", + "children": [ + 19, + 20, + 21 + ], + "display_name": "class NotDocstring", + "docstring": "", + "full_name": "NotDocstring", + "index": 18, + "is_local": false, + "is_method": false, + "line_count": 12, + "parent": null, + "start_line": 85 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short1()", + "docstring": "", + "full_name": "NotDocstring.short1", + "index": 19, + "is_local": false, + "is_method": true, + "line_count": 2, + "parent": 18, + "start_line": 86 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short2()", + "docstring": "", + "full_name": "NotDocstring.short2", + "index": 20, + "is_local": false, + "is_method": true, + "line_count": 2, + "parent": 18, + "start_line": 91 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short3()", + "docstring": "", + "full_name": "NotDocstring.short3", + "index": 21, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 18, + "start_line": 94 + }, + { + "category": "def", + "children": [], + "display_name": "def long_with_omit()", + "docstring": "", + "full_name": "long_with_omit", + "index": 22, + "is_local": false, + "is_method": false, + "line_count": 1, + "parent": null, + "start_line": 102 + } +] diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json new file mode 100644 index 000000000000..bbf71643c76a --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json @@ -0,0 +1,25 @@ +{ + " 2": "class ShortWithDocstring: lines=4, docs=44", + " 7": "class Short: lines=3, docs=0", + " 11": "class LongWithDocstring: lines=6, docs=44", + " 14": "def LongWithDocstring.short1(): lines=3, docs=0", + " 20": "class LongWithoutDocstring: lines=4, docs=0", + " 21": "def LongWithoutDocstring.short1(): lines=3, docs=0", + " 25": "class LongWithShortDocstring: lines=6, docs=10", + " 28": "def LongWithShortDocstring.short1(): lines=3, docs=0", + " 32": "class _Protected: lines=6, docs=10", + " 35": "def _Protected.short1(): lines=3, docs=0", + " 42": "def short(): lines=3, docs=0", + " 46": "def long(): lines=8, docs=44", + " 59": "def long_without_docstring(): lines=3, docs=0", + " 69": "class ImpossibleCombo: lines=15, docs=44", + " 72": "def ImpossibleCombo.needs_docs(): lines=12, docs=0", + " 73": "def ImpossibleCombo.needs_docs.not_short(): lines=11, docs=0", + " 74": "class ImpossibleCombo.needs_docs.not_short.Long: lines=6, docs=0", + " 81": "class ImpossibleCombo.needs_docs.not_short.Short: lines=3, docs=0", + " 85": "class NotDocstring: lines=12, docs=0", + " 86": "def NotDocstring.short1(): lines=2, docs=0", + " 91": "def NotDocstring.short2(): lines=2, docs=0", + " 94": "def NotDocstring.short3(): lines=3, docs=0", + "102": "def long_with_omit(): lines=1, docs=0" +} diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json new file mode 100644 index 000000000000..0b86e9e6ba1e --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json @@ -0,0 +1,140 @@ +{ + "class ImpossibleCombo": { + "docstring_len": 44, + "line": 69, + "lines": 15, + "status": "good" + }, + "class ImpossibleCombo.needs_docs.not_short.Long": { + "docstring_len": 0, + "line": 74, + "lines": 6, + "status": "good" + }, + "class ImpossibleCombo.needs_docs.not_short.Short": { + "docstring_len": 0, + "line": 81, + "lines": 3, + "status": "good" + }, + "class LongWithDocstring": { + "docstring_len": 44, + "line": 11, + "lines": 6, + "status": "good" + }, + "class LongWithShortDocstring": { + "docstring_len": 10, + "line": 25, + "lines": 6, + "status": "good" + }, + "class LongWithoutDocstring": { + "docstring_len": 0, + "line": 20, + "lines": 4, + "status": "good" + }, + "class NotDocstring": { + "docstring_len": 0, + "line": 85, + "lines": 12, + "status": "good" + }, + "class Short": { + "docstring_len": 0, + "line": 7, + "lines": 3, + "status": "good" + }, + "class ShortWithDocstring": { + "docstring_len": 44, + "line": 2, + "lines": 4, + "status": "good" + }, + "class _Protected": { + "docstring_len": 10, + "line": 32, + "lines": 6, + "status": "good" + }, + "def ImpossibleCombo.needs_docs": { + "docstring_len": 0, + "line": 72, + "lines": 12, + "status": "good" + }, + "def ImpossibleCombo.needs_docs.not_short": { + "docstring_len": 0, + "line": 73, + "lines": 11, + "status": "good" + }, + "def LongWithDocstring.short1": { + "docstring_len": 0, + "line": 14, + "lines": 3, + "status": "good" + }, + "def LongWithShortDocstring.short1": { + "docstring_len": 0, + "line": 28, + "lines": 3, + "status": "good" + }, + "def LongWithoutDocstring.short1": { + "docstring_len": 0, + "line": 21, + "lines": 3, + "status": "good" + }, + "def NotDocstring.short1": { + "docstring_len": 0, + "line": 86, + "lines": 2, + "status": "good" + }, + "def NotDocstring.short2": { + "docstring_len": 0, + "line": 91, + "lines": 2, + "status": "good" + }, + "def NotDocstring.short3": { + "docstring_len": 0, + "line": 94, + "lines": 3, + "status": "good" + }, + "def _Protected.short1": { + "docstring_len": 0, + "line": 35, + "lines": 3, + "status": "good" + }, + "def long": { + "docstring_len": 44, + "line": 46, + "lines": 8, + "status": "good" + }, + "def long_with_omit": { + "docstring_len": 0, + "line": 102, + "lines": 1, + "status": "good" + }, + "def long_without_docstring": { + "docstring_len": 0, + "line": 59, + "lines": 3, + "status": "good" + }, + "def short": { + "docstring_len": 0, + "line": 42, + "lines": 3, + "status": "good" + } +} diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json new file mode 100644 index 000000000000..ee2facfc6b5d --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json @@ -0,0 +1,140 @@ +{ + " 2": { + "docstring_len": 44, + "lines": 4, + "name": "class ShortWithDocstring", + "status": "good" + }, + " 7": { + "docstring_len": 0, + "lines": 3, + "name": "class Short", + "status": "good" + }, + " 11": { + "docstring_len": 44, + "lines": 6, + "name": "class LongWithDocstring", + "status": "good" + }, + " 14": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithDocstring.short1", + "status": "good" + }, + " 20": { + "docstring_len": 0, + "lines": 4, + "name": "class LongWithoutDocstring", + "status": "good" + }, + " 21": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithoutDocstring.short1", + "status": "good" + }, + " 25": { + "docstring_len": 10, + "lines": 6, + "name": "class LongWithShortDocstring", + "status": "good" + }, + " 28": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithShortDocstring.short1", + "status": "good" + }, + " 32": { + "docstring_len": 10, + "lines": 6, + "name": "class _Protected", + "status": "good" + }, + " 35": { + "docstring_len": 0, + "lines": 3, + "name": "def _Protected.short1", + "status": "good" + }, + " 42": { + "docstring_len": 0, + "lines": 3, + "name": "def short", + "status": "good" + }, + " 46": { + "docstring_len": 44, + "lines": 8, + "name": "def long", + "status": "good" + }, + " 59": { + "docstring_len": 0, + "lines": 3, + "name": "def long_without_docstring", + "status": "good" + }, + " 69": { + "docstring_len": 44, + "lines": 15, + "name": "class ImpossibleCombo", + "status": "good" + }, + " 72": { + "docstring_len": 0, + "lines": 12, + "name": "def ImpossibleCombo.needs_docs", + "status": "good" + }, + " 73": { + "docstring_len": 0, + "lines": 11, + "name": "def ImpossibleCombo.needs_docs.not_short", + "status": "good" + }, + " 74": { + "docstring_len": 0, + "lines": 6, + "name": "class ImpossibleCombo.needs_docs.not_short.Long", + "status": "good" + }, + " 81": { + "docstring_len": 0, + "lines": 3, + "name": "class ImpossibleCombo.needs_docs.not_short.Short", + "status": "good" + }, + " 85": { + "docstring_len": 0, + "lines": 12, + "name": "class NotDocstring", + "status": "good" + }, + " 86": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.short1", + "status": "good" + }, + " 91": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.short2", + "status": "good" + }, + " 94": { + "docstring_len": 0, + "lines": 3, + "name": "def NotDocstring.short3", + "status": "good" + }, + "102": { + "docstring_len": 0, + "lines": 1, + "name": "def long_with_omit", + "status": "good" + } +} diff --git a/tools/test/test_docstring_linter.py b/tools/test/test_docstring_linter.py index 85ea26de4e77..d09c84de131a 100644 --- a/tools/test/test_docstring_linter.py +++ b/tools/test/test_docstring_linter.py @@ -1,10 +1,14 @@ # mypy: ignore-errors -from __future__ import annotations +import json import sys from pathlib import Path -from tools.linter.adapters.docstring_linter import DocstringLinter +from tools.linter.adapters.docstring_linter import ( + DocstringLinter, + file_summary, + make_terse, +) _PARENT = Path(__file__).parent.absolute() @@ -16,11 +20,69 @@ from .linter_test_case import LinterTestCase TEST_FILE = Path("tools/test/docstring_linter_testdata/python_code.py.txt") +TEST_FILE2 = Path("tools/test/docstring_linter_testdata/more_python_code.py.txt") +TEST_BLOCK_NAMES = Path("tools/test/docstring_linter_testdata/block_names.py.txt") +ARGS = "--max-class=3", "--max-def=4", "--min-docstring=16" class TestDocstringLinter(LinterTestCase): LinterClass = DocstringLinter + maxDiff = 10_240 def test_python_code(self): - args = "--max-class=3 --max-def=4".split() - self.lint_test(TEST_FILE, args) + self.lint_test(TEST_FILE, ARGS) + + def test_report(self): + actual = _dumps(_data()) + self.assertExpected(TEST_FILE, actual, "report.json") + + def test_terse(self): + terse = make_terse(_data(), index_by_line=False) + actual = _dumps(terse) + self.assertExpected(TEST_FILE, actual, "terse.json") + + def test_terse_line(self): + terse = make_terse(_data(), index_by_line=True) + actual = _dumps(terse) + self.assertExpected(TEST_FILE, actual, "terse.line.json") + + def test_file_summary(self): + actual = _dumps(file_summary(_data(), report_all=True)) + self.assertExpected(TEST_FILE, actual, "single.line.json") + + def test_file_names(self): + f = DocstringLinter.make_file(TEST_BLOCK_NAMES) + actual = [b.full_name for b in f.blocks] + expected = [ + "top", + "top.fun[1]", + "top.fun[1].sab", + "top.fun[1].sub", + "top.fun[2]", + "top.fun[2].sub[1]", + "top.fun[2].sub[2]", + "top.fun[3]", + "top.fun[3].sub", + "top.fun[3].sab", + "top.run", + "top.run.sub[1]", + "top.run.sub[2]", + ] + self.assertEqual(actual, expected) + + +def _dumps(d: dict) -> str: + return json.dumps(d, sort_keys=True, indent=2) + "\n" + + +def _data(): + docstring_file = DocstringLinter.make_file(TEST_FILE) + return [b.as_data() for b in docstring_file.blocks] + + +def _next_stdout(mock_stdout): + length = 0 + while True: + s = mock_stdout.getvalue() + yield s[length:] + length = len(s) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8b8ebdc6e976..67fe1df8ca87 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -251,6 +251,8 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml" "${TORCH_ROOT}/aten/src/ATen/native/tags.yaml" "${TORCH_ROOT}/tools/autograd/deprecated.yaml" + "${TORCH_ROOT}/torch/_torch_docs.py" + "${TORCH_ROOT}/torch/_tensor_docs.py" ${pyi_python} ${autograd_python} ${torchgen_python} diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 09744f2b043d..3c487e321c8c 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -913,6 +913,12 @@ class AliasInfo: is_write: _bool before_set: Set[str] after_set: Set[str] + def __init__( + self, + is_write: _bool, + before_set: Set[str], + after_set: Set[str] + ) -> None: ... # Defined in torch/aten/src/ATen/core/function_schema.h class Argument: @@ -925,6 +931,15 @@ class Argument: alias_info: Optional[AliasInfo] is_write: _bool real_type: JitType + def __init__( + self, + name: str, + type: JitType, + N: Optional[_int], + defualt_value: Optional[Any], + kwarg_only: _bool, + alias_info: Optional[AliasInfo] + ) -> None: ... class FunctionSchema: arguments: List[Argument] @@ -932,6 +947,15 @@ class FunctionSchema: name: str overload_name: str is_mutable: _bool + def __init__( + self, + name: str, + overload_name: str, + arguments: List[Argument], + returns: List[Argument], + is_vararg: _bool, + is_varret: _bool + ) -> None: ... class _UpgraderEntry: bumped_at_version: _int @@ -1358,6 +1382,8 @@ def _set_grad_enabled(enabled: _bool) -> None: ... def is_grad_enabled() -> _bool: ... def _set_fwd_grad_enabled(enabled: _bool) -> None: ... def _is_fwd_grad_enabled() -> _bool: ... +def _any_requires_grad(*args, **kwargs) -> _bool: ... +def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ... def is_inference_mode_enabled() -> _bool: ... @overload def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ... @@ -2537,12 +2563,6 @@ class _NodeBase: return_type: Any, ) -> None: ... def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... - def _prepend(self, n: FxNode) -> None: ... - def _remove_from_list(self) -> None: ... - def __lt__(self, n: Self) -> _bool: ... - def __gt__(self, n: Self) -> _bool: ... - def __le__(self, n: Self) -> _bool: ... - def __ge__(self, n: Self) -> _bool: ... class _NodeIter(Iterator): def __init__(self, root: FxNode, reversed: _bool) -> None: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 77a8f9c33e04..0487eb7c924a 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -2,7 +2,7 @@ # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum -from typing import Any, overload +from typing import Any, Optional, overload import torch from torch import Tensor @@ -139,6 +139,8 @@ class BroadcastOptions: class AllreduceOptions: reduceOp: ReduceOp timeout: timedelta + asyncOp: bool + sparseIndices: Optional[Tensor] class AllreduceCoalescedOptions(AllreduceOptions): ... @@ -147,6 +149,7 @@ class ReduceOptions: rootRank: int rootTensor: int timeout: timedelta + asyncOp: bool class AllgatherOptions: timeout: timedelta @@ -155,6 +158,7 @@ class AllgatherOptions: class GatherOptions: rootRank: int timeout: timedelta + asyncOp: bool class ScatterOptions: rootRank: int @@ -170,9 +174,11 @@ class BarrierOptions: device_ids: list[int] device: torch.device timeout: timedelta + asyncOp: bool class AllToAllOptions: timeout: timedelta + asyncOp: bool class Store: def set(self, key: str, value: str): ... @@ -564,9 +570,9 @@ class ProcessGroupGloo(Backend): timeout: timedelta, ) -> None: ... @staticmethod - def create_device(hostname="", interface="") -> Device: ... + def create_device(hostname="", interface="", lazy_init=None) -> Device: ... @staticmethod - def create_default_device() -> Device: ... + def create_default_device(lazy_init=None) -> Device: ... def _set_default_timeout(self, timeout) -> None: ... class _ProcessGroupWrapper(Backend): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 6be4f1d276ef..94cf3aeeb1d2 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1482,7 +1482,7 @@ def _addmm_activation( @register_decomposition(aten.addmv) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): if not self.is_floating_point() and not self.is_complex(): @@ -4338,7 +4338,7 @@ def grid_sampler_2d( @register_decomposition(aten.mv) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def mv(self, vec): torch._check( @@ -5031,7 +5031,7 @@ def inplace_op(*args, **kwargs): @register_decomposition([aten.baddbmm]) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def baddbmm(self, batch1, batch2, beta=1, alpha=1): if not self.is_floating_point() and not self.is_complex(): diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index f92d16bf2b30..246596bcbcab 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -69,7 +69,12 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs): def wrap_bw_compiler(bw_compiler_fn): def _wrapped_bw_compiler(*args, **kwargs): # stop TorchDynamo from trying to compile our generated backwards pass - return disable(disable(bw_compiler_fn)(*args, **kwargs)) + return disable( + disable( + bw_compiler_fn, reason="do not trace backward compiler function" + )(*args, **kwargs), + reason="do not trace generated backwards pass", + ) return _wrapped_bw_compiler @@ -100,7 +105,7 @@ def _wrapped_bw_compiler(*args, **kwargs): with enable_aot_logging(), patch_config: cg = aot_module_simplified(gm, example_inputs, **self.kwargs) counters["aot_autograd"]["ok"] += 1 - return disable(cg) + return disable(cg, reason="do not trace AOT-compiled graph") except TensorifyScalarRestartAnalysis: raise except Exception: diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 05dd42866e81..1a1f44609112 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -18,7 +18,7 @@ import sys import types from collections import Counter -from typing import Optional, Union +from typing import Optional, TYPE_CHECKING, Union import torch.nn from torch.utils._ordered_set import OrderedSet @@ -54,6 +54,10 @@ from .variables.torch_function import TensorWithTFOverrideVariable +if TYPE_CHECKING: + from .symbolic_convert import InstructionTranslatorBase + + @dataclasses.dataclass class GraphOutputEntry: index: int @@ -67,7 +71,7 @@ class PyCodegen: def __init__( self, - tx=None, + tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, tempvars=None, @@ -345,10 +349,10 @@ def gen_fn(): context=str(value), explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", hints=[ - "If Dynamo attempting to trace a return statement and your code is attempting to return a variable " + "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable " "that Dynamo cannot reconstruct, then remove it from the return statement.", *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK, - "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have" + "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have " "reconstruction rules may be fundamentally unreconstructable.", ], ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5d58efdeed09..870291a43785 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -152,26 +152,16 @@ # Non-Inductor backends can use this list for graph freezing. prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1" - -# This feature doesn't really work. We offer this flag for experimental -# purposes / if you want to help us build out support. -# -# torchdynamo has limited support for tensor subclasses that implement -# __torch_function__ see [Note: __torch_function__] in torch_function.py. -# Our current support is limited to tensor subclasses -# that DO NOT store metadata on the tensor (in general, dynamo does not -# support Python code that stores extra attributes on tensors at present). -# If your tensor subclass purely changes function call behavior via -# __torch_function__, you can allow torchdynamo to trace into it by -# adding it to traceable_tensor_subclasses. We don't do any safety checks, -# so it is up to you to ensure that your subclass is well behaved. See also -# https://github.com/pytorch/torchdynamo/issues/1948 -# -# We do NOT currently support __torch_dispatch__. The implementation is -# currently buggy, the main show stopper for nontrivial use is -# https://github.com/pytorch/torchdynamo/issues/1952 +# NOTE this has been deprecated, it does nothing now. traceable_tensor_subclasses: set[type[Any]] = set() +# If a tensor subclass is put into this set, Dynamo will model its instasnces in +# a very conservative and limited way (most likely causing lots of graph breaks +# if one apply tensor ops on these instances). This is useful if you encounter +# internal compiler errors from Dynamo which are caused by tensor subclasses, +# and you are willing to tolerate potential graph breaks rather than hard error. +nontraceable_tensor_subclasses: set[type[Any]] = set() + # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing @@ -411,6 +401,9 @@ # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True +# Enable tracing through unittest +enable_trace_unittest = False + # Enable tracing generator functions lazily. If False, Dynamo will exhaust # generators upon first execution. And if True, the generator will be accessed lazily enable_faithful_generator_behavior = True diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index a31b1f7e59c3..44d19986707d 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -774,9 +774,6 @@ def compile_inner( dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ) ) - stack.enter_context( - _WaitCounter("pytorch.wait_counter.dynamo_compile").guard() - ) stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) stack.enter_context(CompileTimeInstructionCounter.record()) return _compile_inner(code, one_graph, hooks, transform) @@ -957,7 +954,9 @@ def count_args(code: CodeType) -> int: chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), + _WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(), metrics_context, + _WaitCounter("pytorch.wait_counter.dynamo_compile").guard(), ): restart_reasons: set[str] = set() # This is shared across restarts diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index def6c5fd2919..5d966c5d1f64 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -65,7 +65,7 @@ def run(fn=None): return RunOnlyContext() -def disable(fn=None, recursive=True): +def disable(fn=None, recursive=True, *, reason=None): """ Decorator to disable TorchDynamo @@ -74,13 +74,15 @@ def disable(fn=None, recursive=True): If recursive=False, Dynamo skips frames associated with the function code, but still process recursively invoked frames. + + If reason is provided, it will be printed when Dynamo attempts to trace the disabled function. """ if recursive: if fn is not None: fn = innermost_fn(fn) assert callable(fn) - return DisableContext()(fn) - return DisableContext() + return DisableContext(msg=reason)(fn) + return DisableContext(msg=reason) else: def wrap(fn): @@ -89,6 +91,7 @@ def wrap(fn): nonrecursive_disable_wrapper = get_nonrecursive_disable_wrapper(fn) nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined] + nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined] nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined] return nonrecursive_disable_wrapper diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index d8610915ec3a..b24a94ea7cd5 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -376,7 +376,7 @@ def is_bf16_supported(including_emulation: bool = False) -> bool: def is_dtype_supported( cls, dtype: torch.dtype, including_emulation: bool = False ) -> bool: - if dtype == torch.float64: + if dtype in [torch.float64, torch.complex128]: return False return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 18450464197b..1bcad8ef5b5f 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -597,7 +597,7 @@ def get_compiler_config(): filename = inspect.getsourcefile(fn) except TypeError: filename = None - if ( + if config.wrap_top_frame or ( (filename is None or trace_rules.check(fn)) and ( getattr(fn, "__name__", "") @@ -805,8 +805,9 @@ def __reduce__(self): class DisableContext(_TorchDynamoContext): - def __init__(self) -> None: + def __init__(self, msg: Optional[str] = None) -> None: super().__init__(callback=None) + self.msg = msg def __call__(self, fn): # Earlier this code was in the base class _TorchDynamoContext. But we @@ -854,6 +855,7 @@ def _fn(*args, **kwargs): _maybe_set_eval_frame(prior) _fn._torchdynamo_disable = True # type: ignore[attr-defined] + _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined] # Save the function pointer to find the original callable while nesting # of decorators. @@ -1899,11 +1901,20 @@ def patch(): # with torch.deploy internally. from .decorators import disable - torch.jit.trace = disable(torch.jit.trace) - torch.jit.trace_module = disable(torch.jit.trace_module) - torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph) + torch.jit.trace = disable( + torch.jit.trace, reason="tracing into TorchScript not fully supported" + ) + torch.jit.trace_module = disable( + torch.jit.trace_module, + reason="tracing into TorchScript not fully supported", + ) + torch.jit._get_trace_graph = disable( + torch.jit._get_trace_graph, + reason="tracing into TorchScript not fully supported", + ) torch.fx._symbolic_trace.Tracer.trace = disable( - torch.fx._symbolic_trace.Tracer.trace + torch.fx._symbolic_trace.Tracer.trace, + reason="tracing into FX not fully supported", ) torch.distributions.Distribution.set_default_validate_args(False) @@ -1945,7 +1956,12 @@ def patch(): if hasattr(opt_mod, fused_fn_name): setattr( - opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name)) + opt_mod, + fused_fn_name, + disable( + getattr(opt_mod, fused_fn_name), + reason="don't trace into fused optimizer", + ), ) optimizer_classes = [ @@ -1962,10 +1978,14 @@ def patch(): for opt in optimizer_classes: if opt in excluded_optimizer_classes: - opt.step = disable(opt.step) + opt.step = disable( + opt.step, reason=f"optimizer {opt} step not supported" + ) if hasattr(opt, "_init_group"): - opt._init_group = disable(opt._init_group) + opt._init_group = disable( + opt._init_group, reason=f"optimizer {opt} _init_group not supported" + ) @staticmethod def suppress_torch_distributed_warnings(fn): diff --git a/torch/_dynamo/graph_deduplication.py b/torch/_dynamo/graph_deduplication.py index c9ee689e3da5..3a3f7e65491a 100644 --- a/torch/_dynamo/graph_deduplication.py +++ b/torch/_dynamo/graph_deduplication.py @@ -12,18 +12,19 @@ from collections.abc import Iterable from typing import Any +import torch import torch.fx from torch._dynamo import config from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation -from torch.utils._pytree import tree_flatten from .graph_region_tracker import Node, Region +from .graph_utils import _detect_cycles, _flatten_args_kwargs log = logging.getLogger(__name__) -def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore[no-untyped-def] +def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def] """ This is the main entry point for applying the graph deduplication pass. \ Deduplication occurs in two phases: @@ -50,15 +51,14 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore Returns a mapping of nodes to their subgraph output replacement node to remap outputs when they are created in output_graph. """ + from torch._inductor.pattern_matcher import stable_topological_sort + duplicated_region_groups = output_graph.region_tracker.get_identical_regions( output_graph.graph ) - # Used to track which nodes were replaced with subgraph outputs - # today, we have to register the new subgraph submodules before the - # graph outputs have been created, so we pass the replacement mapping - # back to output graph to do the replacements at the site of output creation - output_replacements: dict[Node, Node] = {} + sub_gms: dict[str, torch.fx.GraphModule] = {} + for region_group in duplicated_region_groups: inds_with_external_users = _get_all_output_indices(region_group) region = region_group[0] @@ -66,8 +66,14 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore subgraph, node_ind_arg_inds, ) = _create_subgraph(region, inds_with_external_users) + + # Ignore regions with no args for now, could they possibly be evaluated at compile time? + if not list(node_ind_arg_inds): + continue + sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph) subgraph_name = output_graph.install_subgraph("subgraph", sub_gm) + sub_gms[subgraph_name] = sub_gm with output_graph.graph.inserting_before(): get_subgraph_node = output_graph.graph.create_node( "get_attr", subgraph_name, (), {} @@ -81,34 +87,10 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore inds_with_external_users, sub_gm, subgraph_name, - output_replacements, ) - return output_replacements - - -# flattens with support for slices -# Note: a better way to do this would -# be register/unregister slices as pytree nodes -# but there is no unregister API in the pytorch -# pytree impl -def _flatten_args_kwargs(args: Any) -> list[Node]: - fully_flattened = [] - - def flatten(args: Any) -> None: - flattened, _ = tree_flatten(args) - for arg in flattened: - if isinstance(arg, slice): - start = arg.start - stop = arg.stop - step = arg.step - flatten((start, stop, step)) - else: - fully_flattened.append(arg) - - flatten(args) - - return fully_flattened + stable_topological_sort(output_graph.graph) + return sub_gms def _replace_region_with_subgraph( @@ -119,7 +101,6 @@ def _replace_region_with_subgraph( inds_with_external_users: list[int], sub_gm: torch.fx.GraphModule, subgraph_name: str, - output_replacements: dict[Node, Node], ) -> None: sub_args = [] for node_ind, arg_ind in node_ind_arg_ind: @@ -137,23 +118,26 @@ def _replace_region_with_subgraph( ) return - latest_region_node = region[-1] - with graph.inserting_after(latest_region_node): - invoke_subgraph_node = graph.create_node( - "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {} + from torch._inductor.pattern_matcher import stable_topological_sort + + invoke_subgraph_node = graph.create_node( + "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {} + ) + for ind, external_user_ind in enumerate(inds_with_external_users): + node = region[external_user_ind] + subgraph_output = graph.create_node( + "call_function", operator.getitem, (invoke_subgraph_node, ind), {} ) - with graph.inserting_after(invoke_subgraph_node): - for ind, external_user_ind in enumerate(inds_with_external_users): - node = region[external_user_ind] - subgraph_output = graph.create_node( - "call_function", operator.getitem, (invoke_subgraph_node, ind), {} - ) - output_replacements[node] = subgraph_output - node.replace_all_uses_with(subgraph_output, propagate_meta=True) - - # Erase in reverse topological order - for node in reversed(region): - graph.erase_node(node) + node.replace_all_uses_with(subgraph_output, propagate_meta=True) + + # Erase in reverse topological order + for node in reversed(region): + graph.erase_node(node) + + if config.graph_deduplication_lint: + _detect_cycles(graph) + stable_topological_sort(graph) + graph.lint() if config.graph_deduplication_lint: graph.lint() diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 1be528a7ed72..272eeff54f44 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -27,6 +27,8 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.utils._pytree import tree_flatten +from .graph_utils import _flatten_args_kwargs + T = TypeVar("T") @@ -253,6 +255,8 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: """ topological_ranking = {node: i for i, node in enumerate(graph.nodes)} region_groups_with_rank = [] + # needed to detect if replacing a region will create cycles + node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph) # Create region groups; a region group is a group # of regions that are all identical. In this initial state @@ -281,7 +285,12 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: # overlap. seen_nodes: set[Node] = set() for region_group in region_groups: - fully_expand_region_group(region_group, seen_nodes, self._is_identical) + fully_expand_region_group( + region_group, + seen_nodes, + node_to_recursive_ancestors, + self._is_identical, + ) # sort topologically for region in region_group: region.sort(key=lambda n: topological_ranking[n]) @@ -297,6 +306,7 @@ def __str__(self) -> str: def fully_expand_region_group( regions: list[Region], seen_nodes: set[Node], + node_to_recursive_ancestors: dict[Node, set[Node]], is_identical_fn: Callable[[Node, Node], bool], ) -> None: debug_log("--------------------------------------------------") @@ -327,17 +337,19 @@ def fully_expand_region_group( # regions are only expanded if the node to add is valid # for ALL regions while current_node: - add_node = True + add_node = not _will_create_cycle( + current_node, regions[0], node_to_recursive_ancestors + ) nodes_to_add.clear() nodes_to_add.append(current_node) nodes_to_add_set = set(nodes_to_add) - for region_it in region_iters[1:]: + for ind, region_it in enumerate(region_iters[1:]): + ind += 1 # compensate for the 0th region node = region_it.next() debug_log("--------------------") debug_log("considering adding: %s, cur_node: %s", node, current_node) debug_log("previously claimed nodes: %s", node in seen_nodes) - debug_log("%s", seen_nodes) if node: debug_log("is_identical: %s", is_identical_fn(node, current_node)) add_node &= ( @@ -345,6 +357,9 @@ def fully_expand_region_group( and node not in nodes_to_add_set and node.op != "placeholder" and is_identical_fn(node, current_node) + and not _will_create_cycle( + node, regions[ind], node_to_recursive_ancestors + ) ) nodes_to_add.append(node) nodes_to_add_set.add(node) @@ -369,3 +384,35 @@ def fully_expand_region_group( debug_log("end expand new region group: %s", regions) debug_log("--------------------------------------------------") + + +def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]: + node_to_recursive_ancestors: dict[Node, set[Node]] = {} + for node in graph.nodes: + node_to_recursive_ancestors[node] = set() + for node in graph.nodes: + all_args = _flatten_args_kwargs((node.args, node.kwargs)) + for arg in all_args: + if isinstance(arg, Node): + node_to_recursive_ancestors[node].update( + node_to_recursive_ancestors[arg] + ) + node_to_recursive_ancestors[node].add(node) + return node_to_recursive_ancestors + + +def _will_create_cycle( + node_to_add: Node, + region: Region, + node_to_recursive_ancestors: dict[Node, set[Node]], +) -> bool: + region_set: set[Node] = set(region) + region_ancestors: set[Node] = set( + tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0] + ) + external_users = [user for user in node_to_add.users if user not in region_set] + for user in external_users: + if user in region_ancestors: + return True + + return False diff --git a/torch/_dynamo/graph_utils.py b/torch/_dynamo/graph_utils.py new file mode 100644 index 000000000000..cde627f244e8 --- /dev/null +++ b/torch/_dynamo/graph_utils.py @@ -0,0 +1,69 @@ +from collections import deque +from typing import Any + +from torch.fx import Graph, Node +from torch.utils._pytree import tree_flatten + + +# flattens with support for slices +# Note: a better way to do this would +# be register/unregister slices as pytree nodes +# but there is no unregister API in the pytorch +# pytree impl +def _flatten_args_kwargs(args: Any) -> list[Node]: + fully_flattened = [] + + def flatten(args: Any) -> None: + flattened, _ = tree_flatten(args) + for arg in flattened: + if isinstance(arg, slice): + start = arg.start + stop = arg.stop + step = arg.step + flatten((start, stop, step)) + else: + fully_flattened.append(arg) + + flatten(args) + + return fully_flattened + + +def _detect_cycles(graph: Graph) -> str: + current_path: deque[Node] = deque() + current_path_set: set[Node] = set() + pending: deque[tuple[Node, Node]] = deque() + + def add_to_current_path(node: Node) -> None: + current_path.append(node) + current_path_set.add(node) + + def pop_current_path() -> None: + node = current_path.pop() + current_path_set.remove(node) + + def current_path_head() -> Node: + return current_path[-1] + + for origin in graph.find_nodes(op="placeholder"): + current_path.clear() + current_path_set.clear() + add_to_current_path(origin) + for child in origin.users: + pending.append((child, origin)) + + while pending: + cur_node, parent = pending.pop() + + while current_path_head() != parent: + pop_current_path() + + if cur_node in current_path_set: + current_path.append(cur_node) + return f"cycle detected in path: {current_path}" + + add_to_current_path(cur_node) + for child in cur_node.users: + pending.append((child, cur_node)) + + return "no cycle detected" diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 48631a7021f9..fe3a93be8644 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1984,11 +1984,20 @@ def _get_code_parts(langs): ) if config.enable_cpp_symbolic_shape_guards: - # For exporting we need the python code parts - python_code_parts, verbose_code_parts, cpp_code_parts = _get_code_parts( - ("python", "verbose_python", "cpp") - ) + try: + # For exporting we need the python code parts + python_code_parts, verbose_code_parts, cpp_code_parts = _get_code_parts( + ("python", "verbose_python", "cpp") + ) + python_fallback = False + except OverflowError: + # Cannot use int64_t + python_fallback = True + python_code_parts, verbose_code_parts = _get_code_parts( + ("python", "verbose_python") + ) else: + python_fallback = True python_code_parts, verbose_code_parts = _get_code_parts( ("python", "verbose_python") ) @@ -2004,11 +2013,10 @@ def _get_code_parts(langs): if compile_context := CompileContext.try_get(): compile_context.shape_env_guards.extend(verbose_code_parts.exprs) - if config.enable_cpp_symbolic_shape_guards: - import ctypes - - from torch._inductor.codecache import CppCodeCache + int_source_to_symbol = [] + float_source_to_symbol = [] + if not python_fallback: assert cpp_code_parts # type: ignore[possibly-undefined] code_parts, source_to_symbol = ( cpp_code_parts.exprs, @@ -2018,10 +2026,6 @@ def _get_code_parts(langs): if not code_parts: return - int_source_to_symbol = [] - float_source_to_symbol = [] - - python_fallback = False for source, symbol in source_to_symbol.items(): if isinstance(source, ConstantSource): python_fallback = True @@ -2039,62 +2043,78 @@ def _get_code_parts(langs): # int64_t/double in C++ guards for now. python_fallback = True - if not python_fallback: - source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol) - try: - guard_managers = [ - self.get_guard_manager_from_source(IndexedSource(source, i)) - for i, source in enumerate(source_to_symbol) - ] - - int_symbols_str = ", ".join( - f"{symbol} = int_values[{i}]" - for i, (_, symbol) in enumerate(int_source_to_symbol) - ) - float_symbols_str = ", ".join( - f"{symbol} = float_values[{i}]" - for i, (_, symbol) in enumerate(float_source_to_symbol) - ) + if not python_fallback: + import ctypes - if int_symbols_str: - int_symbols_str = f"int64_t {int_symbols_str};" - if float_symbols_str: - float_symbols_str = f"double {float_symbols_str};" - - func_str = textwrap.dedent( - f""" - #include - #include - #include - - extern "C" int8_t guard(int64_t *int_values, double *float_values) {{ - {int_symbols_str} - {float_symbols_str} - return ({") && (".join(code_parts)}); - }} - """ - ) - guards_log.debug( - "C++ shape guard function: %s %s", - func_str, - verbose_code_parts.exprs, - ) - clib = CppCodeCache.load(func_str) - cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value - assert cguard - except torch._inductor.exc.InvalidCxxCompiler: - # No valid C++ compiler to compile the shape guard - pass - else: - install_symbolic_shape_guard( - guard_managers, - len(int_source_to_symbol), - len(float_source_to_symbol), - cguard, - clib, - verbose_code_parts.exprs, - ) - return + from torch._inductor.codecache import CppCodeCache + + assert cpp_code_parts # type: ignore[possibly-undefined] + code_parts, source_to_symbol = ( + cpp_code_parts.exprs, + cpp_code_parts.source_to_symbol, + ) + + source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol) + try: + guard_managers = [ + self.get_guard_manager_from_source(IndexedSource(source, i)) + for i, source in enumerate(source_to_symbol) + ] + + int_symbols_str = ", ".join( + f"{symbol} = int_values[{i}]" + for i, (_, symbol) in enumerate(int_source_to_symbol) + ) + float_symbols_str = ", ".join( + f"{symbol} = float_values[{i}]" + for i, (_, symbol) in enumerate(float_source_to_symbol) + ) + + if int_symbols_str: + int_symbols_str = f"int64_t {int_symbols_str};" + if float_symbols_str: + float_symbols_str = f"double {float_symbols_str};" + + func_str = textwrap.dedent( + f""" + #include + #include + #include + + #if defined(_MSC_VER) + # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport) + #else + # define EXTERN_DLL_EXPORT extern "C" + #endif + + EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{ + {int_symbols_str} + {float_symbols_str} + return ({") && (".join(code_parts)}); + }} + """ + ) + guards_log.debug( + "C++ shape guard function: %s %s", + func_str, + verbose_code_parts.exprs, + ) + clib = CppCodeCache.load(func_str) + cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value + assert cguard + except torch._inductor.exc.InvalidCxxCompiler: + # No valid C++ compiler to compile the shape guard + pass + else: + install_symbolic_shape_guard( + guard_managers, + len(int_source_to_symbol), + len(float_source_to_symbol), + cguard, + clib, + verbose_code_parts.exprs, + ) + return # Install all the symbolic guards in one python lambda guard. These are run # at the very end of the RootGuardManager via epilogue guards. diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba3dea42864d..856ae4e32973 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -79,6 +79,7 @@ from .code_context import code_context from .codegen import PyCodegen from .current_scope_id import enter_new_scope +from .device_interface import get_interface_for_device from .exc import ( BackendCompilerFailed, exceptions_allowed_to_be_fallback, @@ -240,6 +241,10 @@ def __init__(self, nn_modules: dict[str, torch.nn.Module]): def __repr__(self) -> str: return "FakeRootModule(...)" + def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]): + for k, v in nn_modules.items(): + setattr(self, k, v) + class WrapperBackend: def __init__(self, backend: CompilerFn): @@ -386,7 +391,7 @@ def __init__( # and LOAD_ATTR for same python objects free. self.variable_tracker_cache = VariableTrackerCache() self.unique_var_id = itertools.count() - self.code_options = dict(code_options) + self.code_options: dict[str, Any] = dict(code_options) self.output_instructions: list[Instruction] = [] # used to track nodes that are added between calls of copy_graphstate # and restore_graphstate @@ -397,7 +402,7 @@ def __init__( # Not checkpointed self.compiler_fn: Optional[CompilerFn] = compiler_fn - self.global_scope = global_scope + self.global_scope: Scope = global_scope self.local_scope = local_scope self.root_tx = root_tx @@ -458,7 +463,7 @@ def __init__( self.random_calls: list[ tuple[Callable[..., object], tuple[object, ...], dict[str, object]] ] = [] - self.random_values_var = None + self.random_values_var: Any = None # Bytecode to insert right before we call the graph self.pregraph_bytecode: list[Instruction] = [] @@ -884,7 +889,9 @@ def wrap_name(module_key): self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableTracker.build( - self, target, ConstantSource(source_name=module_key) + self, # type: ignore[arg-type] + target, + ConstantSource(source_name=module_key), ) for k, v in self.nn_modules.items(): @@ -1070,8 +1077,6 @@ def append_prefix_insts(): for value in stack_values: value.realize() - output_replacements = self.dedup_pass() - # Use nn.Module "proxies" in the constructed GraphModule so that # the resulting GM does not hold additional strong references to the original modules. # This prevents a strong ref cycle where Dynamo created code holds on to references @@ -1120,7 +1125,10 @@ def append_prefix_insts(): append_prefix_insts() random_calls_instructions = [] self.random_values_var = self.new_var("random_values") - rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) + rand_fn = disable( + _get_gen_rand_values_fn(self.random_calls), + reason="do not trace into Dynamo rng recovery function", + ) rand_fn_name = self.install_global("__gen_rand_values", rand_fn) codegen = PyCodegen(tx, root, overridden_sources=overridden_sources) random_calls_instructions.extend( @@ -1155,9 +1163,7 @@ def append_prefix_insts(): append_prefix_insts() # optimization to generate better code in a common case self.add_output_instructions( - self.compile_and_call_fx_graph( - tx, list(reversed(stack_values)), root, output_replacements - ) + self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] ) # restore all the live local vars @@ -1190,9 +1196,7 @@ def append_prefix_insts(): output = [] if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: output.extend( - self.compile_and_call_fx_graph( - tx, pass2.graph_output_vars(), root, output_replacements - ) + self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) ) if len(pass2.graph_outputs) != 0: @@ -1344,8 +1348,14 @@ def run_compiler_collective(self, tx): }, payload_fn=lambda: ds.local_state.render(), ) + device_types = compile_pg._device_types + assert len(device_types) == 1, ( + "Expect only one device type but got {}".format("+".join(device_types)) + ) with ( - torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()), + get_interface_for_device(device_types.pop()).device( # type: ignore[attr-defined] + compile_pg.rank() % torch.accelerator.device_count() + ), dynamo_timed("compiler_collective", log_pt2_compile_event=True), ): all_states = [None] * compile_pg.size() @@ -1356,7 +1366,7 @@ def run_compiler_collective(self, tx): tx.speculation_log.clear() raise exc.CompileCollectiveRestartAnalysis - def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): + def compile_and_call_fx_graph(self, tx, rv, root): """ Generate code from self.graph and return the Instruction()s to call that generated code. @@ -1379,9 +1389,8 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), {}, ) - - for old_node, new_node in replaced_outputs.items(): - old_node.replace_all_uses_with(new_node) + sub_gms = self.dedup_pass() + root.add_nn_modules(sub_gms) tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node) if not config.do_not_emit_runtime_asserts: @@ -1473,7 +1482,9 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): # replace compiled_fn with the real forward method compiled_fn = lazy_gm.forward - compiled_fn = disable(compiled_fn) + compiled_fn = disable( + compiled_fn, reason="do not trace Dynamo-compiled graph" + ) counters["stats"]["unique_graphs"] += 1 # This is safe because we pre-process name to be unique @@ -1576,7 +1587,7 @@ def dedup_pass(self): if torch._dynamo.config.use_graph_deduplication: return apply_graph_deduplication(self) else: - return dict() + return {} def install_subgraph(self, name, sub_gm): next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 96ace1da75b4..8db484cab727 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -605,7 +605,9 @@ def hit(ty: str) -> defaultdict[CodeId, CodeState]: remote_cache = get_remote_cache() if remote_cache is not None: with dynamo_timed( - name := "pgo.get_remote_code_state", log_pt2_compile_event=True + name := "pgo.get_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_get_remote_code_state_time_us", ): CompileEventLogger.pt2_compile(name, cache_key=cache_key) # TODO: I don't really understand why there's a JSON container format @@ -716,7 +718,11 @@ def put_local_code_state(cache_key: str) -> None: def put_remote_code_state(cache_key: str) -> None: - with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True): + with dynamo_timed( + name := "pgo.put_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_put_remote_code_state_time_us", + ): CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d9be4e9febc9..f60aa57a5d40 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -21,6 +21,7 @@ "pytree", "sys", "fx", + "tensor", ) POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index c62f19e34406..f007b46800b2 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -56,9 +56,10 @@ def _(*args: Any, **kwargs: Any) -> bool: "structseq_fields", ): __func = getattr(optree, __name) - substitute_in_graph(__func, can_constant_fold_through=True)( + globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( __func.__python_implementation__ ) + __all__ += [__name] # noqa: PLE0604 del __func del __name diff --git a/torch/_dynamo/polyfills/tensor.py b/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 000000000000..002ccf5d1d4f --- /dev/null +++ b/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,37 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. + # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, "_make_subclass only supports requires_grad as keyword arg" + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. + return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 4c85d98cfd16..1deb09e2cc1e 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -295,9 +295,7 @@ def _track_obj( variable: VariableTracker, mutation_type_cls=ValueMutationExisting, ): - """Start tracking a new variable for mutation""" - assert variable.source is not None - + """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( f"{variable} is already tracked for mutation. This could be " @@ -576,12 +574,18 @@ def _get_modified_vars(self): return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen): - # Make sure we codegen these modified VT to their source by default, so - # that mutation and aliasing are properly accounted for. + # We must codegen modified VT to their source by default, so that + # mutation and aliasing are properly accounted for. + # + # Since newly constructed objects don't have a source, we manually + # codegen their construction and store them to a newly assigned local + # source. Note that `ValueMutationNew` isn't tracked by SideEffects. for var in self._get_modified_vars(): - if isinstance(var.mutation_type, AttributeMutationNew) and isinstance( - var, variables.CellVariable - ): + if not isinstance(var.mutation_type, AttributeMutationNew): + assert var.source is not None + continue + + if isinstance(var, variables.CellVariable): # Cells created in the root frame are created either by # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit # `make_cell` for the non-root-frame cells here. @@ -595,18 +599,38 @@ def codegen_save_tempvars(self, cg: PyCodegen): var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: var.source = LocalCellSource(var.local_name) - elif isinstance(var.mutation_type, AttributeMutationNew): - if isinstance(var, variables.AutogradFunctionContextVariable): - unimplemented_v2( - gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", - context="", - explanation="We cannot reconstruct a torch.autograd.Function's context object.", - hints=[], - ) - + elif isinstance(var, variables.TensorVariable): + # NOTE: for historical reasons we never assigned local sources + # to newly constructed tensor object, so we keep it that way. + # They are always loaded from output of the fx graph, so one can + # think of it as having a "OutputGraphSource" for codegen + # purposes. + # + # However, tensor subclass objects are different, because the + # reconstruction logic in `PyCodegen` loads the data tensor from + # graph output and then calls `as_subclass`, meaning we must + # assign a source to it to ensure we only reconstruct one + # subclass instance. + if isinstance( + var, variables.torch_function.TensorWithTFOverrideVariable + ): + # Don't codegen from temp source assigned from the 1st pass. + cg(var, allow_cache=False) + cg.add_cache(var) + # `add_cache` generates STORE and consumes TOS, but we never + # cleared it. TODO move this call into `add_cache` + cg.clear_tos() + var.source = LocalSource(cg.tempvars[var]) + elif isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented_v2( + gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", + context="", + explanation="We cannot reconstruct a torch.autograd.Function's context object.", + hints=[], + ) + else: # Reconstruct the bytecode for # base_cls.__new__(user_cls, *args) - if isinstance(var, variables.UserDefinedObjectVariable): def load_new_method(): @@ -630,10 +654,6 @@ def load_new_method(): cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) - else: - # The remaning cases here are `AttributeMutationExisting` and - # `MutableSideEffects`, which have sources already. - assert var.source is not None for ctx, args in self.save_for_backward: cg(ctx.source) @@ -993,7 +1013,7 @@ def codegen_update_mutated(self, cg: PyCodegen): else: cg.tx.output.update_co_names(name) cg(value) - cg(var.source) + cg(var) suffixes.append([create_instruction("STORE_ATTR", argval=name)]) elif isinstance(var, variables.ListIteratorVariable): for _ in range(var.index): diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index e01c166c97d2..f31d613170a5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -21,7 +21,7 @@ import dataclasses import enum -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union from torch._guards import ChainedSource, GuardSource, Source @@ -29,6 +29,9 @@ from .bytecode_transformation import create_call_function, create_instruction +if TYPE_CHECKING: + from .codegen import PyCodegen + # It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, # so those cases are omitted intentionally @@ -120,7 +123,7 @@ class LocalSource(Source): # or `co_freevars`. is_derefed_cell_contents: bool = False - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.is_derefed_cell_contents: codegen.load_deref(self.local_name) else: @@ -137,7 +140,7 @@ def name(self): class SyntheticLocalSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(self.local_name)) def guard_source(self): @@ -154,7 +157,7 @@ class RandomValueSource(Source): def guard_source(self): return GuardSource.RANDOM_VALUE - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) @@ -167,7 +170,7 @@ def name(self): class GlobalSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.global_name, add=True)) def guard_source(self): @@ -181,7 +184,7 @@ def name(self): class GlobalWeakRefSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_global(self.global_name, add=True) @@ -198,7 +201,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class WeakRefCallSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(self.base)) codegen.extend_output(create_call_function(0, False)) @@ -227,7 +230,7 @@ def __post_init__(self): ) object.__setattr__(self, "member", member_parts[-1]) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -249,7 +252,7 @@ class LocalCellSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, # Dynamo's bytecode transformation differentiates them slightly, so we # always emit `LOAD_CLOSURE` here. @@ -267,7 +270,7 @@ def reconstruct(self, codegen): class GradSource(ChainedSource): member: str = "grad" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -342,7 +345,7 @@ def __post_init__(self): else: assert self.idx is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, f"call_{self.prop.method_name()}" @@ -378,7 +381,7 @@ class IndexedSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -393,7 +396,7 @@ class NegateSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -450,7 +453,7 @@ def name(self): class AttrProxySource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -484,7 +487,7 @@ def __post_init__(self): self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.field)) codegen.append_output(codegen.create_load_const(self.idx_key)) @@ -509,7 +512,7 @@ def __post_init__(self): super().__setattr__("index", self.index.__reduce__()) super().__setattr__("index_is_slice", True) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) if self.index_is_slice: codegen.append_output(codegen.create_load_const(self.unpack_slice())) @@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") ) @@ -577,7 +580,7 @@ def __post_init__(self): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # reconstruct dict.__getitem__(dct, key) # Load dict.__getitem__ @@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource): Same as GetItemSource with reconstruct and name overridden to be list specific. """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Reconstruct list.__getitem__(lst, index) to avoid any side effects # from possibly overridden __getitem__. @@ -646,7 +649,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class TupleIteratorGetItemSource(GetItemSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") ) @@ -663,7 +666,7 @@ class TypeSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) @@ -677,7 +680,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class OptimizerSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -689,7 +692,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class NNModuleSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -738,7 +741,7 @@ def _get_index(self): return TorchFunctionModeStackVariable.get_mode_index(self.ind) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, "get_torch_function_mode_stack_at" @@ -755,7 +758,7 @@ def guard_source(self): class ConstantSource(Source): source_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.source_name, add=False)) def guard_source(self): @@ -776,7 +779,7 @@ def name(self) -> str: def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) @@ -842,6 +845,12 @@ def is_from_local_source(source: Source, *, only_allow_input=False): return True +def is_from_source(source: Source, target: Source): + if isinstance(source, ChainedSource): + return is_from_source(source.base, target) + return source == target + + def is_from_unspecialized_param_buffer_source(source: Source): if isinstance(source, UnspecializedParamBufferSource): return True diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 2ceb1368f7a7..fe634614db4e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -590,15 +590,7 @@ def jump_graph_break(self, inst, value, extra_msg=""): hints=_hints, ), ) - if not self.should_compile_partial_graph(): - unimplemented_v2( - gb_type="Should not compile partial graph (data-dependent branching)", - context="", - explanation="Dynamo has determined when encountering data-dependent " - "branching (e.g. `if my_tensor.item() > 0:`) that it should not " - "compile the partial graph.", - hints=[], - ) + assert self.should_compile_partial_graph() # compile a partial subgraph prefix then jump into user code if self.maybe_has_backedge(): msg = ( @@ -642,8 +634,24 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if value.is_python_constant(): if bool(value.as_python_constant()): return self.jump(inst) - else: + elif self.should_compile_partial_graph(): jump_graph_break(self, inst, value) + else: + unimplemented_v2( + gb_type="Data-dependent assertion failed (cannot compile partial graph)", + context=f"value: {value}", + explanation="Dynamo has determined when encountering a data-dependent assert failure " + "that it should not compile the partial graph.", + hints=[ + *graph_break_hints.FUNDAMENTAL, + "Use `torch._assert()` to raise a hard AssertionError when the check fails. " + "This error will propagate back the user code " + "that called the compiled function (i.e. Dynamo wil not trace any exception handling).", + "Remove the assert statement.", + "Move the assert statement outside of any context managers in order to graph break with " + "partial graph compilation (if fullgraph=False).", + ], + ) # TODO maybe should respect DtoH sync intention of users later?? # Manually insert torch._assert_async instead of python assert and jump over @@ -1701,34 +1709,6 @@ def WITH_CLEANUP_FINISH(self, inst): self.popn(2) self.push(None) - def CALL_FINALLY(self, inst): - """ - pushes the address of the next instruction onto the stack and increments - bytecode counter by delta - """ - # Python 3.8 only - addr = self.indexof[self.next_instruction] - self.push(ConstantVariable.create(addr)) - self.jump(inst) - - def END_FINALLY(self, inst): - # Python 3.8 only - # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY - tos = self.pop() - if isinstance(tos, ConstantVariable): - self.instruction_pointer = tos.as_python_constant() - else: - pass - - def POP_FINALLY(self, inst): - # Python 3.8 only - preserve_tos = inst.argval - if preserve_tos: - tos = self.pop() - _ = self.pop() - if preserve_tos: - self.push(tos) # type: ignore[possibly-undefined] - def FOR_ITER(self, inst): it = self.pop().realize() try: @@ -1749,18 +1729,22 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) - def _raise_exception_variable(self, val) -> NoReturn: - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type + def _create_exception_type(self, val): if isinstance( val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable) ): # Create the instance of the exception type # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 val = val.call_function(self, [], {}) # type: ignore[arg-type] + return val + + def _raise_exception_variable(self, val) -> NoReturn: + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + val = self._create_exception_type(val) # Handle https://peps.python.org/pep-0479/ # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this @@ -1787,6 +1771,10 @@ def _raise_exception_variable(self, val) -> NoReturn: def RAISE_VARARGS(self, inst): if inst.arg == 0: + if not len(self.exn_vt_stack): + msg = ConstantVariable("No active exception to reraise") + exc.raise_observed_exception(RuntimeError, self, args=[msg]) + # re-raise the previous exception. Here CPython refers to the exception # on top of the exception stack assert len(self.exn_vt_stack) @@ -1798,24 +1786,16 @@ def RAISE_VARARGS(self, inst): val = self.stack[-1] self._raise_exception_variable(val) else: - # raise .. from None + # raise .. from ... from_vt = self.pop() - if isinstance(from_vt, ConstantVariable) and from_vt.value is None: - val = self.pop() - try: - self._raise_exception_variable(val) - finally: - # Update __cause__/__supppress_context__ in the raised exception - curr_exc = self.exn_vt_stack.get_current_exception() - curr_exc.call_setattr( - self, ConstantVariable("__cause__"), ConstantVariable(None) - ) - unimplemented_v2( - gb_type="Re-raise with 2 arguments", - context=str(from_vt), - explanation="Dynamo does not support `raise ... from [not-None]`", - hints=[], - ) + val = self.pop() + try: + self._raise_exception_variable(val) + finally: + # Update __cause__/__supppress_context__ in the raised exception + curr_exc = self.exn_vt_stack.get_current_exception() + cause = self._create_exception_type(from_vt) + curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) def CLEANUP_THROW(self, inst): # https://github.com/python/cpython/pull/96010 @@ -2824,6 +2804,17 @@ def MATCH_KEYS(self, inst): def LOAD_ASSERTION_ERROR(self, inst): self.load_builtin_from_argval("AssertionError") + def LOAD_BUILD_CLASS(self, inst): + unimplemented_v2( + gb_type="LOAD_BUILD_CLASS bytecode not supported", + context="", + explanation="Dynamo does not support tracing classes that are defined in the compiled region.", + hints=[ + "Move the class definition out of the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + UNARY_POSITIVE = stack_op(operator.pos) UNARY_NEGATIVE = stack_op(operator.neg) UNARY_NOT = stack_op(operator.not_) @@ -3773,10 +3764,14 @@ def check_inlineable(func): if isinstance(func, UserFunctionVariable) and inspect.getattr_static( func.get_function(), "_torchdynamo_disable", False ): + msg = inspect.getattr_static( + func.get_function(), "_torchdynamo_disable_msg", None + ) unimplemented_v2( gb_type="Skip inlining `torch.compiler.disable()`d function", context=str(func.get_function()), - explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`", + explanation=f"Skip inlining function {func.get_function()} since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", hints=[ "Remove the `torch.compiler.disable` call", ], diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index e927fc4a1eaf..ac505c0de02a 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -95,3 +95,22 @@ def tearDown(self) -> None: if self._prior_is_grad_enabled is not torch.is_grad_enabled(): log.warning("Running test changed grad mode") torch.set_grad_enabled(self._prior_is_grad_enabled) + + +class CPythonTestCase(TestCase): + _stack: contextlib.ExitStack + + @classmethod + def tearDownClass(cls) -> None: + cls._stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls._stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._stack.enter_context( # type: ignore[attr-defined] + config.patch( + enable_trace_unittest=True, + ), + ) diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d44ad4b2408d..ce25a2969050 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -524,3 +524,9 @@ def reset_rng_state(use_xla: bool = False) -> None: import torch_xla.core.xla_model as xm xm.set_rng_state(1337, str(xm.xla_device())) + + +def _skipped_function_for_test_reconstruct( + f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs +) -> _T: + return f(*args, **kwargs) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 05739259dc5b..aaae72c86228 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -308,6 +308,8 @@ "torch._dynamo.mark_static": UserFunctionVariable, "torch._dynamo.nonstrict_trace": UserFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, "torch.set_default_device": UserFunctionVariable, @@ -362,6 +364,7 @@ "math.isinf", "math.isnan", "math.isqrt", + "math.lcm", "math.ldexp", "math.lgamma", "math.log", @@ -510,7 +513,6 @@ "torch._C._debug_set_fusion_group_inlining", "torch._C._demangle", "torch._C._disabled_torch_dispatch_impl", - "torch._C._disabled_torch_function_impl", "torch._C._dispatch_call_boxed", "torch._C._dispatch_check_all_invariants", "torch._C._dispatch_check_invariants", @@ -1622,6 +1624,7 @@ "torch._values_copy", "torch._weight_int4pack_mm", "torch._weight_int4pack_mm_for_cpu", + "torch._weight_int4pack_mm_with_scales_and_zeros", "torch._weight_int8pack_mm", "torch._weight_norm_interface", "torch._weight_norm", @@ -2234,7 +2237,6 @@ ) -torch_c_binding_in_graph_functions["math.lcm"] = TorchInGraphFunctionVariable if sys.version_info >= (3, 11): torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable @@ -3172,7 +3174,6 @@ def is_numpy_type_info(obj) -> bool: random, traceback, linecache, - unittest, ) # third party libraries skiplist is defined by str, because users may not use these libraries. @@ -3579,6 +3580,12 @@ def check_file(filename, is_inlined_call=False): ): return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + if ( + filename.startswith(_module_dir(unittest)) + and not torch._dynamo.config.enable_trace_unittest + ): + return SkipResult(True, "unittest") + if bool(SKIP_DIRS_RE.match(filename)): return SkipResult(True, "SKIP_DIRS") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 89e47e823cdd..04ee7aa86d69 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -94,7 +94,14 @@ if typing.TYPE_CHECKING: - from collections.abc import Generator, Iterable, Iterator, KeysView, ValuesView + from collections.abc import ( + Generator, + ItemsView, + Iterable, + Iterator, + KeysView, + ValuesView, + ) try: @@ -579,12 +586,27 @@ def instant( @staticmethod def try_add_pt2_compile(event_name: str, **metadata: object): """ - Adds to an existing pt2_compile event, but silently returns if the event doesn't exist. + Adds to an existing pt2_compile event, but silently returns if the event doesn't exist + or ChromiumEventLogger is not initialized. This function is syntactic sugar for chromium_event_logger().try_add_event_data. """ + if CHROMIUM_EVENT_LOG is None: + return chromium_log = get_chromium_event_logger() chromium_log.try_add_event_data(event_name, **metadata) + @staticmethod + def try_(method_fn, *args, **kwargs): + """ + Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set + """ + if CHROMIUM_EVENT_LOG is None: + return + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + return + method_fn(*args, **kwargs) + @contextmanager def dynamo_timed( @@ -1245,6 +1267,8 @@ class CompilationMetrics: ir_count: Optional[int] = None cudagraph_skip_reason: Optional[str] = None python_version: Optional[str] = None + pgo_put_remote_code_state_time_us: Optional[int] = None + pgo_get_remote_code_state_time_us: Optional[int] = None @classmethod def create(cls, metrics: dict[str, Any]): @@ -1408,6 +1432,7 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: "reorderable_logging_functions", "ignore_logger_methods", "traceable_tensor_subclasses", + "nontraceable_tensor_subclasses", "_custom_ops_profile", } @@ -2400,6 +2425,7 @@ def check_numpy_ndarray_args(args, kwargs): dict_keys: type[KeysView[Any]] = type({}.keys()) dict_values: type[ValuesView[Any]] = type({}.values()) +dict_items: type[ItemsView[Any, Any]] = type({}.items()) odict_values: type[ValuesView[Any]] = type(OrderedDict().values()) tuple_iterator: type[Iterator[Any]] = type(iter(())) range_iterator: type[Iterator[Any]] = type(iter(range(0))) @@ -4371,7 +4397,9 @@ def does_not_override_dict_iter_methods(user_cls): # compiled bytecode # They will be skipped which is the desired result def call_size(x, i): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ tracing helper function" + ) def fn(x, i): return x.size(i) @@ -4379,7 +4407,9 @@ def fn(x, i): def call_stride(x, i): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ tracing helper function" + ) def fn(x, i): return x.stride(i) @@ -4387,7 +4417,9 @@ def fn(x, i): def call_storage_offset(x): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ tracing helper function" + ) def fn(x): return x.storage_offset() diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index fbf780bf7fa3..e5274d0f0ce7 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -29,7 +29,8 @@ if TYPE_CHECKING: - from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase class SourceType(Enum): @@ -399,7 +400,7 @@ def maybe_fx_node(self): except NotImplementedError: return None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def unpack_var_sequence(self, tx) -> list["VariableTracker"]: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d00ea5edc90a..d85885449b06 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -68,7 +68,11 @@ SymbolicContext, ) from torch.fx.immutable_collections import immutable_dict, immutable_list -from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) from torch.utils._sympy.value_ranges import ValueRanges from torch.utils.weak import TensorWeakRef @@ -140,6 +144,7 @@ wrap_fake_exception, ) from .base import ( + AttributeMutationNew, typestr, ValueMutationExisting, ValueMutationNew, @@ -271,6 +276,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -343,7 +349,7 @@ def __post_init__(self): self._example = TensorWeakRef(self._example) assert is_fake(self.fake_tensor) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.source) def erase(self): @@ -364,7 +370,7 @@ def __init__(self) -> None: is_tensor=False, ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert codegen.tx.output.backward_state_var codegen.add_push_null( lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState") @@ -611,11 +617,30 @@ def create_2d_tma_descriptor(): return id_dispatch(self, value) # Everything else (NB: order matters!) - if is_traceable_wrapper_subclass(value) or istype( - value, config.traceable_tensor_subclasses + if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses ): - return self.wrap_tensor(value) - elif is_namedtuple(value): + if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: + # This case it's either tensor or subclass with default + # torch_dispatch (they might override torch_function or not), + # and we can always trace into them. + return self.wrap_tensor(value) + elif is_traceable_wrapper_subclass(value): + # For non-default torch_dispatch, we have more requirements. + return self.wrap_tensor(value) + + if is_namedtuple(value): self.install_guards(GuardBuilder.SEQUENCE_LENGTH) output = [ LazyVariableTracker.create( @@ -929,11 +954,6 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif ( - isinstance(value, torch._C._TensorMeta) - and value in config.traceable_tensor_subclasses - ): - return TensorSubclassVariable(value, source=self.source) elif ( istype(value, contextlib.nullcontext) and inspect.getattr_static(value, "enter_result", None) is None @@ -1186,6 +1206,20 @@ def build_key_value(i, k, v): if value is torch.autograd._unsafe_preserve_version_counter: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). + and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) # This is a userdefined class, so install an ID_MATCH even if its a # global variable. self.install_guards(GuardBuilder.ID_MATCH) @@ -1539,7 +1573,13 @@ def wrap_module(self, value: torch.nn.Module): # we graph break here, Dynamo does not know how to create # continuation functions for such bytecodes. So, we delay the # graph break to CALL_FUNCTION. - return DelayGraphBreakVariable(source=self.source) + msg = inspect.getattr_static( + value.forward, "_torchdynamo_disable_msg", None + ) + return DelayGraphBreakVariable( + source=self.source, + msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", + ) self.install_guards(GuardBuilder.TYPE_MATCH) self.source = AttrSource(self.source, "_orig_mod") @@ -1722,7 +1762,22 @@ def wrap_tensor(self, value: torch.Tensor): # Guards are added inside register_attr_or_module ) - if type(value) in config.traceable_tensor_subclasses: + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): # Ordinarily, we would fakeify a tensor so that it can get dynamic # shapes and be computed on without triggering actual operations. # However, how can we fakeify a tensor subclass? Ordinary @@ -1740,24 +1795,13 @@ def wrap_tensor(self, value: torch.Tensor): # To simplify things for now, the __dict__ tracking bits haven't # been implemented yet, but they can be added into this design at # a later point in time. - subclass_type = type(value) - else: - assert type(value) in ( - torch.Tensor, - torch.nn.Parameter, - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ) or is_traceable_wrapper_subclass(value), type(value) subclass_type = None - - # NB: this just says we accessed a tensor from the same source again - # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). - # This is distinct from two distinct sources mapping to the same - # Tensor (per id())! No guard is necessary here. See below for the - # other case. - is_duplicate_tensor = source in self.tx.output.input_source_to_var - if is_duplicate_tensor: - return self.tx.output.input_source_to_var[source] + else: + subclass_type = type(value) + options["torch_function_fn"] = build_torch_function_fn( + self.tx, value, self.source + ) + self.install_guards(GuardBuilder.TYPE_MATCH) if get_static_address_type(value) == "guarded": self.install_guards(GuardBuilder.ID_MATCH) @@ -1765,13 +1809,6 @@ def wrap_tensor(self, value: torch.Tensor): # By this point, we should have deduplicated all tensors self.assert_not_wrapped_by_this_graph(value) - options = {} - if type(value) in config.traceable_tensor_subclasses: - options["torch_function_fn"] = build_torch_function_fn( - self.tx, value, self.source - ) - self.install_guards(GuardBuilder.TYPE_MATCH) - if ( isinstance(value, torch.Tensor) and value.is_nested @@ -2464,7 +2501,9 @@ def _wrap_fx_preexisting_tensor( f"wrapped by this instance of Dynamo. Found: {tensor}" ) - return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls) + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) # This is 2 in the above comment (wrapping the output of a traced op) @@ -2498,36 +2537,23 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe import torch._utils if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - is_buffer = isinstance(example_value, torch.nn.Buffer) - - # NB: In most (all?) cases, this does not actually do a clone. - # (WARNING: this means that if we mutate metadata on the fake - # tensor, the stored example value will update too!) - example_value = _clone_input(example_value, tx.fake_mode) - set_example_value(proxy.node, example_value) - # We bind the unbacked symints in sizes/trdies of tensor lazily. - # So that subgraphs can access the unbacked symbol's proxy in parent graph - # when lifting unbacked symbols of input tensors to subgraph inputs. - # We do it lazily because the tensor may not be used in subgraphs. - tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) - specialized_props = target_cls.specialize(example_value) - # TODO: not sure about this fake mode test - if ( - isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) - and example_value.fake_mode is tx.fake_mode - ): - tensor_type = subclass_type if subclass_type else torch.Tensor - specialized_props["class_type"] = ( - torch.nn.Parameter - if is_parameter - else torch.nn.Buffer - if is_buffer - else tensor_type - ) - - options.update(specialized_props) - return target_cls(proxy, **options) + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # speically here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. + # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var elif ( hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" @@ -2696,6 +2722,43 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ) +def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options +): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + options.update(specialized_props) + return target_cls(proxy, **options) + + def get_automatic_dynamic_shapes_mark_as(): if config.automatic_dynamic_shapes_mark_as == "dynamic": return DimDynamic.DYNAMIC diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c66c369876b9..5a19e7076899 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -10,6 +10,7 @@ import sys import types import typing +import unittest from collections import defaultdict, OrderedDict from collections.abc import KeysView, Sequence from typing import Callable, TYPE_CHECKING, Union @@ -87,6 +88,7 @@ if TYPE_CHECKING: # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator log = logging.getLogger(__name__) @@ -730,7 +732,7 @@ def as_proxy(self): return DTYPE[self.fn] return super().as_proxy() - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"): + def reconstruct(self, codegen: "PyCodegen"): name = self.fn.__name__ assert self.fn.__module__ == "builtins" assert name not in codegen.tx.f_globals, "shadowed global" @@ -1267,6 +1269,12 @@ def call_str(self, tx: "InstructionTranslator", arg): # Inline the user function return tx.inline_user_function_return(user_func_variable, [arg], {}) + elif isinstance(arg, (variables.ExceptionVariable,)): + if len(arg.args) == 0: + value = f"{arg.exc_type}" + else: + value = ", ".join(a.as_python_constant() for a in arg.args) + return variables.ConstantVariable.create(value=value) def _call_min_max(self, tx: "InstructionTranslator", *args): if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): @@ -1650,7 +1658,10 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): ) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): - return args[0].call_method(tx, "__len__", args[1:], kwargs) + try: + return args[0].call_method(tx, "__len__", args[1:], kwargs) + except AttributeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__getitem__", args[1:], kwargs) @@ -1798,11 +1809,11 @@ def call_getattr( name_var: VariableTracker, default=None, ): - name = name_var.as_python_constant() - if not name_var.is_python_constant(): unimplemented("non-const getattr() name") + name = name_var.as_python_constant() + if tx.output.side_effects.is_attribute_mutation(obj): if isinstance(obj, variables.UnspecializedNNModuleVariable): if ( @@ -1864,6 +1875,30 @@ def call_getattr( variables.UserDefinedObjectVariable, ), ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertDictEqual", + "assertSequenceEqual", + "assertWarns", + ) + ): + unimplemented_v2( + gb_type="Failed to trace builtin operator", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace builtin operator `{name}` ", + hints=[ + f"Avoid calling builtin `{name}`. " + "Please report an issue to PyTorch.", + ], + ) + try: return obj.var_getattr(tx, name) except NotImplementedError: @@ -1933,6 +1968,20 @@ def call_setattr( "the middle of the graph, which aot_autograd does not currently know how to handle. " ) elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented_v2( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause # tracked fakes to produce incorrect guards. This is sound because the TensorVariable diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 6760bd1ff73a..f86d2d2062a7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -133,7 +133,7 @@ def const_getattr(self, tx: "InstructionTranslator", name): def call_method( self, - tx, + tx: "InstructionTranslator", name, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 04f552c54fa3..26d87113089a 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -41,8 +41,11 @@ from .base import VariableTracker from .functions import ( NestedUserFunctionVariable, + SkipFunctionVariable, UserFunctionVariable, UserMethodVariable, + WrappedNestedUserFunctionVariable, + WrappedSkipFunctionVariable, WrappedUserFunctionVariable, WrappedUserMethodVariable, ) @@ -50,6 +53,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -85,12 +89,12 @@ def exit(self, tx: "InstructionTranslator", *args): self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -111,9 +115,21 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": assert len(args) == 1 + assert isinstance( + args[0], + ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserMethodVariable, + UserFunctionVariable, + ), + ) + if isinstance(args[0], NestedUserFunctionVariable): - args[0] = UserFunctionVariable(args[0].get_function()) - assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable)) + return WrappedNestedUserFunctionVariable(args[0], self) + + if isinstance(args[0], SkipFunctionVariable): + return WrappedSkipFunctionVariable(args[0], self) if isinstance(args[0], UserMethodVariable): return WrappedUserMethodVariable(args[0], self) @@ -1057,7 +1073,7 @@ def exit(self, tx: "InstructionTranslator", *args): _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1278,7 +1294,7 @@ def call_method( def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this stream is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1340,7 +1356,7 @@ def call_method( def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this event is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1378,7 +1394,7 @@ def call_function( assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 6ed522f5a874..2703d8c4eb7d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -22,7 +22,9 @@ import collections import functools +import inspect import types +from collections.abc import Hashable as py_Hashable from typing import Optional, TYPE_CHECKING from torch._subclasses.fake_tensor import is_fake @@ -32,12 +34,19 @@ from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import is_from_local_source -from ..utils import cmp_name_to_op_mapping, dict_keys, dict_values, specialize_symnode +from ..utils import ( + cmp_name_to_op_mapping, + dict_items, + dict_keys, + dict_values, + specialize_symnode, +) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -46,6 +55,10 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl +def was_instancecheck_override(obj): + return type(obj).__dict__.get("__instancecheck__", False) + + def is_hashable(x): # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at @@ -65,6 +78,13 @@ def is_hashable(x): return x.as_proxy().node.meta.get("example_value") is not None elif isinstance(x, variables.TupleVariable): return all(is_hashable(e) for e in x.items) + elif ( + isinstance(x, variables.UserDefinedObjectVariable) + and not was_instancecheck_override(x.value) + and inspect.getattr_static(x.value, "__hash__") is int.__hash__ + and isinstance(x.value, int) + ): + return isinstance(x.value, py_Hashable) else: return isinstance( x, @@ -73,7 +93,7 @@ def is_hashable(x): variables.SymNodeVariable, variables.ConstantVariable, variables.EnumVariable, - variables.user_defined.UserDefinedClassVariable, + variables.UserDefinedClassVariable, variables.UserFunctionVariable, variables.SkipFunctionVariable, variables.misc.NumpyVariable, @@ -133,6 +153,11 @@ def underlying_value(self): # Access the underlying value inside the referent_vt for the key representation Hashable = ConstDictVariable._HashableTracker return Hashable(self.vt.referent_vt).underlying_value + elif isinstance(self.vt, variables.UserDefinedObjectVariable): + # The re module in Python 3.13+ has a dictionary (_cache2) with + # an object as key (`class _ZeroSentinel(int): ...`): + # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual + return self.vt.value else: x = self.vt.as_python_constant() return x @@ -257,7 +282,7 @@ def is_new_item(self, value, other): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.add_push_null( @@ -376,7 +401,7 @@ def call_method( # corresponding value VT. For __contains__, we add a DICT_CONTAINS # guard. But for all the other methods, we insert the DICT_KEYS_MATCH # guard to be conservative. - from . import BuiltinVariable, ConstantVariable, TupleVariable + from . import BuiltinVariable, ConstantVariable Hashable = ConstDictVariable._HashableTracker @@ -398,9 +423,7 @@ def call_method( self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source.name()) - return TupleVariable( - [TupleVariable([k.vt, v]) for k, v in self.items.items()] - ) + return DictItemsVariable(self) elif name == "keys": self.install_dict_keys_match_guard() if self.source: @@ -542,7 +565,7 @@ def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: def unpack_var_sequence(self, tx): return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # load types.MappingProxyType if self.source: unimplemented( @@ -677,7 +700,7 @@ def python_type(self): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) @@ -782,7 +805,7 @@ def python_type(self): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -858,7 +881,7 @@ class DictViewVariable(VariableTracker): def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: super().__init__(**kwargs) - assert self.kv in ("keys", "values") + assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @@ -873,12 +896,9 @@ def view_items_vt(self): raise NotImplementedError def unpack_var_sequence(self, tx): - def unwrap(x): - return x.vt if self.kv == "keys" else x - - return [unwrap(x) for x in self.view_items] + return self.view_items_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) @@ -938,3 +958,15 @@ def view_items_vt(self): def python_type(self): return dict_values + + +class DictItemsVariable(DictViewVariable): + kv = "items" + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] + + def python_type(self): + return dict_items diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f30b69e44b6b..fcbfb22c6d33 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -75,6 +75,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, @@ -470,7 +471,7 @@ def __str__(self): __repr__ = __str__ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -520,6 +521,7 @@ def next_variable(self, tx): with patch.dict(counters, {"unimplemented": counters["inline_call"]}): return tracer.inline_call_() except ObservedException as e: + tracer.generator_exhausted = True raise e except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit @@ -956,11 +958,15 @@ def call_function( self.context.exit(tx) return result + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + class WrappedUserFunctionVariable(UserFunctionVariable): def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("fn", None) - kwargs.pop("obj", None) super().__init__(wrapped.fn, **kwargs) self.wrapped = wrapped self.context = context @@ -976,6 +982,11 @@ def call_function( self.context.exit(tx) return result + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): def convert(x): @@ -1078,6 +1089,11 @@ def call_method(self, tx, name, args, kwargs): def has_closure(self): return self.closure is not None + def const_getattr(self, tx, name): + if name == "__name__": + return self.fn_name.as_python_constant() + return super().const_getattr(tx, name) + def has_self(self): return False @@ -1108,7 +1124,7 @@ def bind_args(self, parent, args, kwargs): return result - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) @@ -1167,6 +1183,46 @@ def reconstruct(self, codegen): codegen.store_attr(name) +class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("fn_name", None) + kwargs.pop("code", None) + kwargs.pop("f_globals", None) + kwargs.pop("defaults", None) + kwargs.pop("kwdefaults", None) + kwargs.pop("annotations", None) + kwargs.pop("closure", None) + kwargs.pop("wrapped_fn", None) + super().__init__( + wrapped.fn_name, + wrapped.code, + wrapped.f_globals, + wrapped.defaults, + wrapped.kwdefaults, + wrapped.annotations, + wrapped.closure, + wrapped.wrapped_fn, + ) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + class SkipFunctionVariable(VariableTracker): _nonvar_fields = { "value", @@ -1198,10 +1254,12 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) unimplemented_v2( gb_type="Skip calling `torch.compiler.disable()`d function", context=str(self.value), - explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`", + explanation=f"Skip calling function `{self.value}` since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", hints=[ "Remove the `torch.compiler.disable` call", ], @@ -1314,6 +1372,31 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): return fn_var_getattr(tx, self.value, self.source, name) +class WrappedSkipFunctionVariable(SkipFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("value", None) + kwargs.pop("reason", None) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + class WrapperUserFunctionVariable(VariableTracker): """ Used to represent a wrapper object that contains the actual callable as an @@ -1503,7 +1586,7 @@ def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: def python_type(self): return functools.partial - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -1959,7 +2042,7 @@ def to_metadata(self): self.element_size.as_proxy(), ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 6e971301687e..8eeaacccb38c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3129,6 +3129,7 @@ def install_subgraph_in_output_graph( # inputs have already been seen before. If yes, the subgraph is already # installed in the output graph and we can just access the subgraph # using the saved attr name. + fake_inputs = [ node.meta["example_value"] for node in body_gmod.graph.nodes diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 502616c440e9..3cf9c994ddc2 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -249,7 +250,7 @@ def __init__(self, item: VariableTracker, **kwargs) -> None: def next_variable(self, tx): return self.item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -279,7 +280,7 @@ def next_variable(self, tx): self.item = self.item.call_method(tx, "__add__", [self.step], {}) return old_item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -425,7 +426,7 @@ def get_item(it): self.index += 1 return variables.TupleVariable(args) - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): for it in self.iterables: if isinstance(it, list): remaining_items = it[self.index :] @@ -436,7 +437,7 @@ def reconstruct_items(self, codegen): else: codegen(it) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True ) @@ -481,7 +482,7 @@ def next_variable(self, tx): args = super().next_variable(tx) return self.fn.call_function(tx, args.items, {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True ) @@ -555,7 +556,7 @@ def _next(): if pred_res.as_python_constant(): return item - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): if isinstance(self.iterable, list): remaining_items = self.iterable[self.index :] codegen.foreach(remaining_items) @@ -565,7 +566,7 @@ def reconstruct_items(self, codegen): else: codegen(self.iterable) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) codegen(self.fn) self.reconstruct_items(codegen) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 241bfb2c808b..1430dc912cb1 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -32,10 +32,10 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, variables +from .. import config, trace_rules, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import raise_observed_exception, unimplemented +from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource @@ -57,6 +57,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -81,7 +82,7 @@ def __init__(self, typevar, objvar=None, **kwargs) -> None: # cls for a classmethod) self.objvar = objvar - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) codegen(self.typevar) if self.objvar is not None: @@ -161,6 +162,14 @@ def call_method( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -266,6 +275,37 @@ def call_method( source = self.source and AttrSource(self.source, attr_name) return VariableTracker.build(tx, attr_value, source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. + # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + output_old = tx.output.torch_function_enabled + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.output.torch_function_enabled = False + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.output.torch_function_enabled = output_old + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) unimplemented(f"non-function or method super: {inner_fn}") @@ -292,7 +332,7 @@ def __init__(self, exc_type, args, **kwargs) -> None: def set_context(self, context: "ExceptionVariable"): self.__context__ = context - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", self.exc_type.__name__) ) @@ -396,6 +436,24 @@ class DelayGraphBreakVariable(UnknownVariable): Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. """ + def __init__(self, msg=None, **kwargs): + super().__init__(**kwargs) + self.msg = msg + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented_v2( + gb_type="Unsupported function call (delayed)", + context=f"source: {self.source}", + explanation="Dynamo determined that a graph break should occur " + f"when calling `{self.source.name()}`. Reason: {self.msg}", + hints=[], + ) + class ComptimeVariable(VariableTracker): """ @@ -403,7 +461,7 @@ class ComptimeVariable(VariableTracker): Dynamo compile time """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError("comptime is special form") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": @@ -620,11 +678,10 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ): - from ..trace_rules import is_callable_allowed from .builder import wrap_fx_proxy if name == "apply": - if is_callable_allowed(self.fn_cls): + if trace_rules.is_callable_allowed(self.fn_cls): trampoline_autograd_apply = produce_trampoline_autograd_apply( self.fn_cls ) @@ -642,8 +699,6 @@ def call_method( elif name == "backward": return self.call_backward(tx, args, kwargs) else: - from .. import trace_rules - source = AttrSource(self.source, name) if self.source is not None else None try: obj = inspect.getattr_static(self.fn_cls, name) @@ -890,7 +945,7 @@ def const_getattr(self, tx: "InstructionTranslator", name): raise NotImplementedError return inspect.getattr_static(step2, name) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.obj) codegen.extend_output(codegen.create_load_attrs(self.name)) @@ -1107,7 +1162,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): def as_python_constant(self): return self.value - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen") -> None: # We're just trying to load the type here. Reconstructing the type from # scratch is tricky - for a type like `typing.List[int]` we'd need to # deconstruct the origin and args. The origin for `List[int]` is `list` @@ -1282,7 +1337,7 @@ def __init__(self, **kwargs) -> None: def __repr__(self) -> str: return "NullVariable" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if sys.version_info < (3, 11): unimplemented("cannot reconstruct NullVariable in < Python 3.11") codegen.append_output(create_instruction("PUSH_NULL")) @@ -1323,7 +1378,7 @@ def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1372,7 +1427,7 @@ def call_function(self, tx: "InstructionTranslator", args, kwargs): tx.debug_locals.append((self, list(args))) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): return self.source.reconstruct(codegen) @staticmethod @@ -1667,7 +1722,7 @@ def call_random_meth(*args, **kwargs): return call_random_fn(tx, call_random_meth, args, kwargs) return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1708,7 +1763,7 @@ def call_function( ) -> "VariableTracker": return self.referent_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) codegen(self.referent_vt) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 51c1ea6bf141..6edd4a7c8ea4 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split() @@ -36,7 +37,7 @@ def __init__(self, proxy, param_vars, **kwargs) -> None: self.param_vars = param_vars super().__init__(**kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert self.source is None assert self.param_vars is not None codegen.add_push_null( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 3470adfa2c7e..ef6a69ceee7c 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -67,9 +67,10 @@ set_example_value, tensortype_to_dtype, ) -from .base import VariableTracker +from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable +from .user_defined import UserDefinedClassVariable try: @@ -79,6 +80,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -410,8 +412,6 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): return ConstantVariable(ret_val) def var_getattr(self, tx: "InstructionTranslator", name): - from . import UserDefinedClassVariable - if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): unimplemented( @@ -453,7 +453,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): ): # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. return variables.misc.DelayGraphBreakVariable( - source=AttrSource(self.source, name) + source=AttrSource(self.source, name), + msg="Getting an inplace view on a graph input is not supported", ) # For attributes (not methods) that were not caught in the special handling above, @@ -613,7 +614,7 @@ def call_method( """ # This is seen in inspect signature where we check if the value is a default value - if name == "__eq__" and isinstance(args[0], variables.UserDefinedClassVariable): + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): return variables.ConstantVariable(False) try: @@ -788,9 +789,23 @@ def method_as_subclass(self, cls): tx = InstructionTranslator.current_tx() py_cls = cls.as_python_constant() - return TensorWithTFOverrideVariable.from_tensor_var( + var = TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, cls.source ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + unimplemented_v2( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) def method_get_device(self): if isinstance(self.device, torch.device): @@ -1431,25 +1446,58 @@ def from_tensor_variable(cls, tensor_variable): return FakeItemVariable(**dict(tensor_variable.__dict__)) -class TensorSubclassVariable(VariableTracker): - def __init__(self, value, *args, **kwargs) -> None: - self.value = value - super().__init__(*args, **kwargs) - +class TensorSubclassVariable(UserDefinedClassVariable): def call_function( self, tx: "InstructionTranslator", args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable - return TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if ( + len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, data, self.value, self.source + ) + else: + unimplemented_v2( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs ) - return super().call_function(tx, args, kwargs) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) + + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var def as_python_constant(self): return self.value @@ -1511,7 +1559,7 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("untyped_storage") codegen.call_method(0) @@ -1526,7 +1574,7 @@ def __init__( super().__init__(**kwargs) self.from_tensor = from_tensor - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("data_ptr") codegen.call_method(0) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1e7a9baf9494..c85d5b5c577c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -34,7 +34,7 @@ import math import re from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import torch._C import torch._refs @@ -64,7 +64,7 @@ proxy_args_kwargs, unwrap_if_wrapper, ) -from .base import VariableTracker +from .base import typestr, VariableTracker from .ctx_manager import ( AutocastModeVariable, ProfilerContextVariable, @@ -76,6 +76,7 @@ from .torch_function import ( can_dispatch_torch_function, dispatch_torch_function, + TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) @@ -168,19 +169,23 @@ constant_fold_functions = dict.fromkeys(constant_fold_functions) -tracing_state_functions = { - torch.jit.is_scripting: False, - torch.jit.is_tracing: False, - torch._C._get_tracing_state: None, - torch.fx._symbolic_trace.is_fx_tracing: False, - torch.onnx.is_in_onnx_export: False, - torch._dynamo.external_utils.is_compiling: True, - torch._utils.is_compiling: True, - torch.compiler.is_compiling: True, - torch.compiler.is_dynamo_compiling: True, - torch.compiler.is_exporting: True, - torch.nn.modules.activation._is_make_fx_tracing: False, -} +@functools.lru_cache(None) +def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: + # Defined as a function to avoid circular import like torch.onnx + return { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, + torch.compiler.is_exporting: True, + torch.nn.modules.activation._is_make_fx_tracing: False, + } + bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) @@ -222,7 +227,7 @@ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): try: name = f"{self.value.__module__}.{self.value.__name__}" except Exception: @@ -455,7 +460,7 @@ def _register(handler): ) from .builder import wrap_fx_proxy, wrap_fx_proxy_cls - @register(*tracing_state_functions) + @register(*tracing_state_functions()) def handle_tracing_state_functions( self, tx: "InstructionTranslator", *args, **kwargs ): @@ -469,7 +474,7 @@ def handle_tracing_state_functions( torch.compiler.is_exporting, ): tx.mark_inconsistent_side_effects() - return ConstantVariable.create(tracing_state_functions[self.value]) + return ConstantVariable.create(tracing_state_functions()[self.value]) @register(*dispatch_key_set_functions) def handle_dispatch_key_set_functions( @@ -896,6 +901,28 @@ def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): elif isinstance(expr, ConstantVariable): return expr + @register(torch.fx.experimental.symbolic_shapes.guard_or_true) + def handle_guard_or_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_false) + def handle_guard_or_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + @register(torch._C._autograd._unsafe_set_version_counter) def handle_unsafe_set_version_counter( self, tx: "InstructionTranslator", *args, **kwargs @@ -1157,6 +1184,28 @@ def patched_fn(*args, **kwargs): ) if self.is_tensor_method(): + name = self.value.__name__ + # Guard against inplace view op on input tensor (not supported) + if args and isinstance(args[0], variables.TensorVariable): + tensor_var = args[0] + # Check if input tensor and inplace_view op specifcally + if tensor_var.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view + in getattr(fn, fn.overloads()[0]).tags + ): + unimplemented_v2( + gb_type="Inplace op on input tensor", + context="", + explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + "Ensure you do not modify input tensor in place.", + ], + ) return self.call_tensor_method(tx, args, kwargs) special_handler = self._get_handlers().get(self.value) @@ -1350,7 +1399,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if is_traceable_wrapper_subclass_type(data.class_type): + if isinstance( + data, TensorWithTFOverrideVariable + ) or is_traceable_wrapper_subclass_type(data.class_type): unimplemented("Parameter constructor with tensor subclass NYI") if not can_convert_to_tracable_parameter(): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index e51f6ccd6c9d..982a65117717 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -24,9 +24,6 @@ See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w for more information on the design. - -To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses -in torch/_dynamo/config.py """ import collections @@ -62,6 +59,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -69,6 +67,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -384,7 +383,7 @@ def __init__(self, value, source=None, **kwargs): self.cm_obj = value # needed for BC with calling enter from CM code self.source = source - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) @@ -428,7 +427,7 @@ def exit(self, tx: "InstructionTranslator", *args): ) return ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): ty = NoEnterTorchFunctionMode codegen( AttrSource( @@ -592,15 +591,13 @@ def __init__(self, *args, **kwargs) -> None: def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. - # This isn't sound if a __torch_function__ tensor subclass defines a - # constructor, but if only a __torch_function__ impl is defined, this is - # okay to call. It is up to the user whether this is correct behavior - # or not. import torch + # This simulates shallow-copying the tensor object. kwargs = dict(tensor_var.__dict__) - assert kwargs.pop("class_type") is torch.Tensor, ( - "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" ) torch_fn_var = build_torch_function_fn(tx, class_type, cls_source) var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs) @@ -640,30 +637,56 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - - if tx.output.torch_function_enabled and hasattr(torch.Tensor, name): - if self.source: - install_guard( - AttrSource(AttrSource(self.source, "__class__"), name).make_guard( - GuardBuilder.FUNCTION_MATCH + # Handle non-overriden attributes inherited from `torch.Tensor`. + attr_is_overriden = _is_attr_overidden(tx, self, name) + if hasattr(torch.Tensor, name) and not attr_is_overriden: + if tx.output.torch_function_enabled: + if self.source: + install_guard( + AttrSource( + AttrSource(self.source, "__class__"), name + ).make_guard(GuardBuilder.FUNCTION_MATCH) ) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + [self], + {}, ) - get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) - - return self.call_torch_function( - tx, - get_fn, - TupleVariable([self.class_type_var(tx)]), - [self], - {}, - ) else: - return super().var_getattr(tx, name) + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. + try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) + if isinstance(attr, types.FunctionType): + install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = UserMethodVariable(getter, self, source=getter_source) + return getter_var.call_function(tx, [], {}) + + elif attr_is_overriden: + unimplemented( + f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950 + ) + + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): return call_torch_function( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index b842a552649f..fc39d238f309 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,7 @@ ) from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker from .dicts import DefaultDictVariable +from .lists import SizeVariable try: @@ -96,6 +97,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -579,6 +581,10 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) items = list(args) @@ -1502,7 +1508,7 @@ def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs): return variables.ConstantVariable.create(None) super().call_method(tx, method_name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.idx == self.REMOVED: # Hook has already been removed, return a dummy handle codegen.add_push_null( diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 445465bffb64..3db84d43a484 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -27,6 +27,7 @@ _check_dynamic_shapes, _combine_args, _DimHint, + _DimHintType, _process_dynamic_shapes, _RelaxedConstraint, _tree_map_with_path, @@ -52,6 +53,7 @@ SequenceKey, tree_map_with_path, ) +from torch.utils._sympy.numbers import int_oo if TYPE_CHECKING: @@ -362,17 +364,21 @@ def make_constraints( (used only to enumerate the user-input nodes) """ + def is_int(x: object) -> bool: + return isinstance(x, int) or ( + isinstance(x, torch.SymInt) and x.node.expr.is_number + ) + shape_env = fake_mode.shape_env assert shape_env is not None inline_constraints = gm.meta.get("inline_constraints", []) - range_constraints = { - symbol: inline_constraints[symbol] for symbol in inline_constraints - } + range_constraints = defaultdict(lambda: ValueRanges(0, int_oo)) | inline_constraints if not dynamic_shapes: - return range_constraints + return dict(range_constraints) # clean up dynamic markers from tensors - for arg in pytree.tree_flatten(combined_args)[0]: + flat_paths, flat_args = zip(*pytree.tree_flatten_with_path(combined_args)[0]) + for arg in flat_args: if isinstance(arg, torch.Tensor): _clean_dynamic_markers(arg) @@ -388,6 +394,7 @@ def make_constraints( input_dims = defaultdict(list) free_symbols = set() + range_violations = [] for input_index, node in enumerate(gm.graph.nodes): if input_index < num_lifted_inputs or node.op != "placeholder": continue @@ -397,19 +404,63 @@ def make_constraints( continue shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] for i, d in enumerate(node.meta["val"].shape): - if isinstance(d, torch.SymInt) and not d.node.expr.is_number: + dim = None + if isinstance(shape_spec, (list, tuple)): + dim = shape_spec[i] + elif isinstance(shape_spec, dict): + dim = shape_spec.get(i) + if not is_int(d): # Compute the range constraint for the symbolic expression corresponding # to this shape dimension and store it. - dim = shape_spec[i] if shape_spec else None if dim is None or isinstance(dim, _DimHint): - range_constraints[d.node.expr] = shape_env.bound_sympy(d.node.expr) + range_constraints[d.node.expr] &= shape_env.bound_sympy(d.node.expr) else: - range_constraints[d.node.expr] = ValueRanges( + range_constraints[d.node.expr] &= ValueRanges( lower=dim.min, upper=dim.max ) + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) free_symbols.update(d.node.expr.free_symbols) + # check user-specified min/max range for DimHints; + # we might want to do this even if model tracing inferred a static dimension. + if isinstance(dim, _DimHint): + trace_vr = ( + range_constraints[d.node.expr] + if not is_int(d) + else ValueRanges(int(d), int(d)) + ) + try: + user_vr = ValueRanges( + lower=0 if dim.min is None else dim.min, + upper=int_oo if dim.max is None else dim.max, + ) + if is_int(d): + out_vr = trace_vr & user_vr + else: + range_constraints[d.node.expr] &= user_vr + shape_env.var_to_range[d.node._expr] &= user_vr + out_vr = range_constraints[d.node.expr] + # check for specializations + if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton(): + msg = ( + f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + f"but tracing inferred a static shape of {out_vr.lower} for dimension " + f"inputs{pytree.keystr(flat_paths[input_index])}.shape[{i}]." + ) + range_violations.append(msg) + except torch.utils._sympy.value_ranges.ValueRangeError: + msg = ( + f"- Received user-specified min/max range of [{dim.min}, {dim.max}], " + f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], " + f"for inputs{pytree.keystr(flat_paths[input_index])}.shape[{i}]." + ) + range_violations.append(msg) + + if range_violations: + prefix = "Found the following conflicts between user-specified ranges and inferred ranges from model tracing:\n" + raise ValueError(prefix + "\n".join(range_violations)) + for symbol in free_symbols: if symbol not in range_constraints: # Placeholders can have symbolic shapes that are derived expressions. @@ -418,7 +469,7 @@ def make_constraints( # we want to record range constraints for their root symbols. range_constraints[symbol] = shape_env.var_to_range[symbol] - return range_constraints + return dict(range_constraints) def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: @@ -658,9 +709,38 @@ def _override(self, func, args, kwargs): ): return torch._refs.tensor, args, kwargs if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor): - # Redirect to torch.select for indexing with symint. - if isinstance(args[1], torch.SymInt): - return torch.select, [args[0], 0, args[1]], {} + + def rewrite(dim, item): + # Redirect to torch.select for indexing. + if isinstance(item, (int, torch.SymInt)): + return dim, (torch.select, [dim, item]) + # Redirect to torch.ops.aten.slice for slicing. + if isinstance(item, slice): + return dim + 1, ( + torch.ops.aten.slice, + [dim, item.start, item.stop, item.step or 1], + ) + # Otherwise do nothing. + + items = args[1] if isinstance(args[1], tuple) else (args[1],) + dim = 0 + # Sequence rewrites. + sequence = [] + for item in items: + if (r := rewrite(dim, item)) is None: + return func, args, kwargs + dim, call_spec = r + sequence.append(call_spec) + + def run(): + # Run sequence. + t = args[0] + for _method, _args in sequence: + t = _method(t, *_args) + return t + + return run, [], {} + return func, args, kwargs def __torch_function__(self, func, types, args=(), kwargs=None): diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 77255c8d07d7..8ecb84b7adf4 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import collections -import warnings +import logging from typing import Any, Union import torch @@ -19,6 +19,9 @@ from torch.fx.graph_module import _get_attr +log = logging.getLogger(__name__) + + class ConstantAttrMap(collections.abc.MutableMapping): """A mapping class that understands how to use module constants (tensors, ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, @@ -175,6 +178,8 @@ def lift_constants_pass( continue if "LoweredBackendModule" in type(constant_val).__name__: continue + if "AOTInductorRunnerWrapper" in type(constant_val).__name__: + continue if isinstance(constant_val, torch.utils._pytree.TreeSpec): continue @@ -213,9 +218,11 @@ def lift_constants_pass( elif isinstance(constant_val, torch.Tensor): # Remove the parameterness of constant_val if isinstance(constant_val, torch.nn.Parameter): - warnings.warn( - f"{node.target} created when tracing {node.meta.get('stack_trace', '')} is a parameter. But" - f"it's not registered with register_parameter(). export will treat it as a constant tensor" + log.debug( + "%s created when tracing %s is a parameter. But " + "it's not registered with register_parameter(). export will treat it as a constant tensor", + str(node.target), + str(node.meta.get("stack_trace", "")), ) # We get the real data out of the parameter by disabling the surrounding fake mode. with unset_fake_temporarily(): @@ -232,7 +239,6 @@ def lift_constants_pass( constant_name = f"lifted_tensor_{num_tensor_constants}" constant_fqn = get_constant_fqn(node, constant_name) num_tensor_constants += 1 - else: raise SpecViolationError( f"getattr node {node} referencing unsupported type {type(constant_val)}" diff --git a/torch/_export/serde/aoti_schema.py b/torch/_export/serde/aoti_schema.py deleted file mode 100644 index d19add43705c..000000000000 --- a/torch/_export/serde/aoti_schema.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from torch._export.serde.schema import Node - - -@dataclass -class ExternKernelNode: - name: str - node: Node - - -@dataclass -class ExternKernelNodes: - nodes: list[ExternKernelNode] diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index fbf0be7d78f6..4274fc431dda 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<3a8a6be8158821263b71ad9018c921664cd32c2f9b4deeac119e2292d186a02b>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -358,4 +358,5 @@ struct ExternKernelNode { struct ExternKernelNodes { 10: list nodes; + 20: optional string protocol; } diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 0fbaf8644d74..d1d74c624c43 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -8,7 +8,7 @@ from torch._export.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 7) +SCHEMA_VERSION = (8, 8) TREESPEC_VERSION = 1 @@ -484,3 +484,4 @@ class ExternKernelNode: @dataclass class ExternKernelNodes: nodes: Annotated[list[ExternKernelNode], 10] + protocol: Annotated[Optional[str], 20] = None diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 3898303bda4b..e5f9ad4f8e28 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>> +# checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>> AOTInductorModelPickleData: kind: struct fields: @@ -141,6 +141,9 @@ ExternKernelNodes: fields: nodes: type: List[ExternKernelNode] + protocol: + type: Optional[str] + default: None GradientToParameterSpec: kind: struct fields: @@ -530,5 +533,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 7 +- 8 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 14cc7d2731bb..d630896f69c6 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec): def store_namedtuple_fields(ts): if ts.type is None: return - if ts.type == namedtuple: + if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name if serialized_type_name in self.treespec_namedtuple_fields: field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names @@ -1861,7 +1861,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: "as_none", "as_string", ): - node_name = self.signature.input_specs[i].arg.name + node_name = self.signature.input_specs[i].arg.name or f"arg{i}" placeholder_node = self.graph.placeholder(node_name) placeholder_node.meta["val"] = self.deserialize_input(input_) else: diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 833fec60cb69..c3562c470f1b 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -58,6 +58,8 @@ InputKind.TOKEN: "token", } +_DISABLE_ATEN_TO_ASSERTION_PASS = False + def _collect_and_set_constant_attrs( graph_signature, constants, mod @@ -577,6 +579,59 @@ def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.No return [node for node in nodes if node_call_back(node)] +@contextmanager +def _disable_aten_to_metadata_assertions(): + global _DISABLE_ATEN_TO_ASSERTION_PASS + orig_val = _DISABLE_ATEN_TO_ASSERTION_PASS + _DISABLE_ATEN_TO_ASSERTION_PASS = True + try: + yield + finally: + _DISABLE_ATEN_TO_ASSERTION_PASS = orig_val + + +def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if _DISABLE_ATEN_TO_ASSERTION_PASS: + return + + aten_to_variants = [ + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + torch.ops.aten.to.dtype_layout, + ] + for node in gm.graph.nodes: + if node.target in aten_to_variants: + if ( + node.prev.target == torch.ops.aten._assert_tensor_metadata.default + and node.args[0] == node.prev.args[0] + ): + # skip if already guarded + continue + + if (tensor_val := node.args[0].meta.get("val")) is not None: + with gm.graph.inserting_before(node), _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + ), + ): + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(node.args[0],), + kwargs={ + "dtype": tensor_val.dtype, + "device": tensor_val.device, + "layout": tensor_val.layout, + }, + ) + + def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): from torch._export.passes._node_metadata_hook import ( _node_metadata_hook, @@ -600,6 +655,10 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): f"exported program: {first_call_function_nn_module_stack(gm.graph)}", export=True, ) + + # insert runtime assertions for aten.to nodes + _insert_aten_to_metadata_assert_pass(gm) + # update output specs gm.recompile() graph_signature.user_outputs = _graph_output_names(gm) diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 4940973c5f0d..8f80f2a6bcc4 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -149,7 +149,12 @@ def allowed_getattr_types(self) -> tuple[type[Any], ...]: def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]: # subgm in HOP's argument could has have getattr(weight) nodes, thus stateful - return (torch.fx.GraphModule, torch.nn.parameter.Parameter, torch.utils._pytree.TreeSpec) + return ( + torch.fx.GraphModule, + torch.nn.parameter.Parameter, + torch.Tensor, # for buffer and constant tensor + torch.utils._pytree.TreeSpec + ) def check_valid_op(self, op): pass @@ -271,10 +276,12 @@ def _is_type(name, ty): elif type(attr).__name__ == "AOTInductorEPModule": continue + elif type(attr).__name__ == "AOTInductorRunnerWrapper": + continue if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)): raise SpecViolationError( - f"Invalid get_attr type {type(attr)}. \n" + f"Invalid get_attr type {type(attr)} on target {node.target}. \n" f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}" ) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 4533d3f12cae..6e31070fd7a5 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -18,7 +18,12 @@ import torch from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions -from torch._dynamo.utils import CompileEventLogger, counters +from torch._dynamo.utils import ( + CHROMIUM_EVENT_LOG, + CompileEventLogger, + counters, + dynamo_timed, +) from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -373,6 +378,8 @@ def load(self, example_inputs) -> CompiledFxGraph: # so we can call it only after we're sure both forward and backward have # TODO: We don't cache debug lines for now, but we should for improved debugging + # Clear CompiledTritonKernels before loading from FXGraphCache + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() remote_cache = None constants = CompiledFxGraphConstants() if should_use_remote_fx_graph_cache(): @@ -547,45 +554,45 @@ def wrap_post_compile( torch._logging.trace_structured( "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } - compiled_fw_func = self.compiled_fw.load(args) - compiled_bw_func = None - if self.compiled_bw is not None: - compiled_bw_func = self.compiled_bw.load(args) - needs_autograd = True - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - # Now that we've loaded forward and backward, call post compile on both - # This avoids setting things like BoxedBools in fx_config until - # after both forward and backward cache hit - fw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - bw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": True, - } - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, fw_fx_config - ) - compiled_bw_func = self.compiled_bw.post_compile( - compiled_bw_func, bw_fx_config - ) - else: - inference_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - - needs_autograd = False - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, inference_fx_config - ) + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) # Wrap the forward function in post compile wrappers compiled_fw_func = AOTDispatchSubclassWrapper( @@ -598,7 +605,7 @@ def wrap_post_compile( ) req_subclass_dispatch = self.maybe_subclass_meta is not None - CompileEventLogger.pt2_compile( + CompileEventLogger.try_add_pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -841,21 +848,22 @@ def load( "components": debug_lines, } ) - CompileEventLogger.instant( - f"autograd_cache_{cache_state}", - metadata=cache_info, - time_ns=cache_event_time, - ) - CompileEventLogger.try_add_pt2_compile( - "backend_compile", - cache_state=cache_state, - cache_event_time=cache_event_time, - key=cache_info.get("key"), - components=cache_info.get("components"), - cache_bypass_reason=cache_info.get("cache_bypass_reason"), - remote_cache_enabled=remote, - local_cache_enabled=local, - ) + if CHROMIUM_EVENT_LOG: + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) torch._logging.trace_structured( "artifact", diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 0ac2144cd77a..aab77b80d40b 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -486,7 +486,7 @@ def prepare_for_partitioner(mod, num_primals, num_fw_outputs): new_graph.lint() - out = torch.fx.GraphModule(joint_gm, new_graph) + out = torch.fx.GraphModule(mod, new_graph) return out new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index cc7be374e35a..6e1d3a714c12 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -31,6 +31,7 @@ from torch._prims_common import CUDARngStateHelper from torch._subclasses import FakeTensor from torch.fx.experimental._backward_state import BackwardState +from torch.monitor import _WaitCounter from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -2225,7 +2226,9 @@ def _backward_impl(ctx, all_args): dynamo_compile_column_us="backward_cumulative_compile_time_us", log_waitcounter=True, waitcounter_name_override="entire_backward_compile", - ): + ), _WaitCounter( + "pytorch.wait_counter.dynamo_compile" + ).guard(): CompileEventLogger.compilation_metric(is_forward=False) # See Note: [Backward graph lazy lowering] CompiledFunction.compiled_bw = aot_config.bw_compiler( diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 51292fb00985..97b53b6c9a88 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1538,7 +1538,7 @@ def get_default_op_list() -> OpTypes: aten.argmax, aten.maximum, prims.iota, - prims._low_memory_max_pool2d_offsets_to_indices, + prims._low_memory_max_pool_offsets_to_indices, ] # noqa: E501,B950 # Natalia said that we should allow recomputing indexing :) default_recomputable_ops += [aten.index, aten.gather] diff --git a/torch/_guards.py b/torch/_guards.py index ad5f4a7b130a..c85c7b0d7325 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -631,8 +631,12 @@ def update(self, *others: set[Guard]): self.add(g, skip=1) def remove_guards_with_source(self, source): - """Delete all guards with a given source""" - self.inner = {g for g in self.inner if g.originating_source != source} + """Delete all guards that contains a given source""" + from ._dynamo.source import is_from_source + + self.inner = { + g for g in self.inner if not is_from_source(g.originating_source, source) + } class GuardsContext(Checkpointable[GuardsCheckpointState]): @@ -668,12 +672,19 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... @abstractmethod def get_proxy_dispatch_entry(self, identifier: str): ... + @abstractmethod + def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): ... + + @abstractmethod + def get_lazy_bwd_entry(self, identifier: str): ... + class InvokeSubgraphCache(HopSubgraphCache): def __init__(self) -> None: self.autograd_cache: dict[str, Callable] = {} self.proxy_dispatch_cache: dict[str, Callable] = {} self.dynamo_identifiers: dict[str, str] = {} + self.lazy_bwd_cache: dict[str, torch.fx.GraphModule] = {} def add_dynamo_identifier(self, cache_key: str, identifier: str): self.dynamo_identifiers[cache_key] = identifier @@ -693,6 +704,12 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable): def get_proxy_dispatch_entry(self, identifier: str): return self.proxy_dispatch_cache.get(identifier, None) + def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): + self.lazy_bwd_cache[identifier] = gmod + + def get_lazy_bwd_entry(self, identifier: str): + return self.lazy_bwd_cache.get(identifier, None) + class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index 286575726dc2..d90586f8950d 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -1,20 +1,25 @@ +# mypy: allow-untyped-defs + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict - from __future__ import annotations import torch import torch.utils._pytree as pytree from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) -AOTI_LOWERED_MODULE = "AOTInductorEPModule" +AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper" class AOTICallDelegate(HigherOrderOperator): @@ -22,7 +27,7 @@ class AOTICallDelegate(HigherOrderOperator): It has the following signature: aoti_call_delegate( - lowered_module: AOTInductorEPModule, + lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper] original_gm:fx.GraphModule, weight_args: List[Tensor], input_args: List[Tensor], @@ -30,15 +35,9 @@ class AOTICallDelegate(HigherOrderOperator): where, - lowered_module is the AOTInductor lowered submodule, backed by compiled .so file, supporting real tensor inputs - - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation + - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation - weight_args is the list of weights in original GraphModule, including parameters and buffers - input_args is the list of flatten inputs - - NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weight as get_attr nodes. - This will fail functionalization during retrace. When we move AOTI to accept stateless GraphModule, we can enable retracing. - - When serialization, we have special hanlding for aoti_call_delegate, as AOTInductorEPModule is not serializable - and stateful original_gm is failing the verifier. """ def __init__(self) -> None: @@ -62,7 +61,6 @@ def __call__( @aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) -# pyre-ignore def call_delegate_cpu( lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] original_gm: torch.fx.GraphModule, @@ -77,27 +75,60 @@ def call_delegate_cpu( new_args = pytree.tree_map_only( tuple(map_types.keys()), lambda a: map_types[type(a)](a), - input_args, + weight_args + input_args, lambda a: isinstance(a, tuple(map_types.keys())), ) - - has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args) - has_fake_params = any( - isinstance(param, FakeTensor) for param in original_gm.parameters() - ) - has_fake_buffers = any( - isinstance(buffer, FakeTensor) for buffer in original_gm.buffers() + has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args) + if has_fake_args: + # use stateless original_gm for tracing with fake tensors + fake_out = original_gm(*new_args) + return fake_out + else: + # use AOTI Runner for real tensors + new_input_args = new_args[len(weight_args) :] + if type(lowered_module).__name__ == "AOTInductorRunnerWrapper": + return lowered_module(*new_input_args) # type: ignore[misc] + elif type(lowered_module).__name__ == "AOTInductorEPModule": + return lowered_module(new_input_args) # type: ignore[misc] + else: + raise RuntimeError( + f"Unexpected lowered_module type: {type(lowered_module)}." + ) + + +def trace_aoti_call_delegate( + proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args +): + proxy_mode.tracer.root.register_module("lowered_module", lowered_module) + proxy_mode.tracer.root.register_module("original_gm", original_gm) + + node_args = (lowered_module, original_gm, weight_args, input_args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate" ) + with disable_proxy_modes_tracing(): + out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) - if has_fake_input_args or has_fake_params or has_fake_buffers: - # aoti lowered module doesn't support fake tensor - return original_gm(*new_args) - else: - return lowered_module(new_args) # type: ignore[misc] + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@aoti_call_delegate.py_impl(ProxyTorchDispatchMode) +def call_delegate_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + res = trace_aoti_call_delegate( + mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args + ) + return res @aoti_call_delegate.py_impl(FakeTensorMode) -# pyre-ignore def call_delegate_fake_tensor_mode( mode: FakeTensorMode, lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] @@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode( ) -> list[torch.Tensor]: with mode: return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) + + +@aoti_call_delegate.py_functionalize_impl +def call_delegate_functionalize( + ctx, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + unwrapped_weight_args = tuple( + ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args + ) + unwrapped_input_args = tuple( + ctx.unwrap_tensors(input_arg) for input_arg in input_args + ) + with ctx.redispatch_to_next(): + res = aoti_call_delegate( + lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type] + ) + return ctx.wrap_tensors(res) diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index 02eee4b2c07b..af47b3e5fdc5 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -6,7 +6,10 @@ import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization -from torch._higher_order_ops.utils import reenter_make_fx +from torch._higher_order_ops.utils import ( + check_input_alias_and_mutation_return_ouputs, + reenter_make_fx, +) from torch._ops import HigherOrderOperator from torch._subclasses import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode @@ -126,6 +129,64 @@ def _call_Functionalize(self, ctx, subgraph, *operands, **kwargs): out = self(functionalized_subgraph, *unwrapped_operands, **kwargs) return ctx.wrap_tensors(out) + def gen_schema(self, *args, **kwargs): + from .schema import CFunctionSchemaGen, HopArgumentInfoGen + + subgraph, *operands = args + + assert isinstance( + subgraph, torch.fx.GraphModule + ), f"NYI non GraphModule subgraph got {subgraph}" + + fake_args = [ + ph.meta["example_value"] + for ph in subgraph.graph.find_nodes(op="placeholder") + ] + ( + mutated_inp_idx, + inp_inp_alias, + inp_out_alias, + out_out_alias, + output, + ) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args) + + assert ( + len(inp_inp_alias) == 0 + and len(inp_out_alias) == 0 + and len(out_out_alias) == 0 + ), "Aliasing is not suppported for HOP subgraph." + args = [ + HopArgumentInfoGen.from_example( + subgraph, name="subgraph", default_value=None, is_mutated=False + ) + ] + for idx, arg in enumerate((*operands, *kwargs.items())): + if isinstance(arg, tuple): + # kwargs value are treated as default argument + arg_name, example_value = arg + default = example_value + else: + arg_name = f"arg{idx}" + example_value = arg + default = None + args.append( + HopArgumentInfoGen.from_example( + example_value=example_value, + name=arg_name, + default_value=default, + is_mutated=idx in mutated_inp_idx, + ) + ) + + # The output is represented as a single argument + out = HopArgumentInfoGen.from_example( + example_value=output, + name="out", + default_value=None, + is_mutated=False, + ) + return CFunctionSchemaGen.from_hop_argument_info(str(self), args, out) + class BaseHOPFunction(torch.autograd.Function): @staticmethod @@ -151,9 +212,11 @@ def backward(ctx, *grad_outputs): from .utils import _from_fun fw_inputs = pytree.tree_map(_from_fun, operands) - _, joint_graph, _ = create_fw_bw_graph( - subgraph, fw_inputs, grad_outputs - ) + ( + _, + joint_graph, + _, + ) = create_fw_bw_graph(subgraph, fw_inputs, grad_outputs) # The joint graph returns (*grad_inputs, *fwd_outputs). # We only need the grad_inputs. diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 6501ca6ad1ca..31846752e3db 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -204,7 +204,7 @@ def materialize_as_graph( exclude_key_set: torch._C.DispatchKeySet, force_enable_grad=False, ) -> torch.fx.GraphModule: - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable(recursive=True, reason=None) def _materialize_as_graph_inner(): with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 4ac832cc6221..c899370b8d5a 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs +import contextlib from contextlib import nullcontext +from dataclasses import dataclass, field from typing import Optional, Union import torch @@ -31,11 +33,22 @@ track_tensor_tree, ) from torch.fx.graph_module import GraphModule +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts invoke_subgraph_counter = 0 +# During the tracing of the joint graph, we construct this information. This is +# used to filter out grad_outs/tangents in the `backward` method of +# InvokeSubgraphAutogradOp. +@dataclass +class OutputMetadata: + num_fw_outs: Optional[int] = None + indexes_with_none: set[int] = field(default_factory=set) + indexes_with_no_grad: set[int] = field(default_factory=set) + + class InvokeSubgraphHOP(HigherOrderOperator): def __init__(self) -> None: super().__init__("invoke_subgraph") @@ -133,6 +146,7 @@ def get_invoke_subgraph_cache(): return cache +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra def trace_joint_graph(fn, fw_inputs, fw_outputs): """ Naively trace out a joint graph. This simplifies the reconstruction of joint @@ -173,6 +187,7 @@ def joint_fn(*primals_and_tangents): return _maybe_reenter_make_fx(joint_fn)(*joint_operands) +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra def create_fw_bw_graph(subgraph, operands, grad_outputs=None): with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): @@ -188,11 +203,41 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): else fake_mode.shape_env.ignore_fresh_unbacked_symbols() ) + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + + num_fw_outs = len(fw_outs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. + output_metadata = OutputMetadata() + + output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + output_metadata.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + if grad_outputs is None: # Infer grad_outputs to be the same properties as the fw_outputs # if they're not passed in - with context: - grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + # Although fw_outs are equivalent to grad_outputs for tracing + # purposes, we have to carefully handle the None and fw_out that do + # not have require_grad. At those indexes, we will have None in the + # backward graph. + grad_outputs = fw_outs + grad_outputs = [grad for grad in grad_outputs if grad is not None] + grad_outputs = [grad for grad in grad_outputs if grad.requires_grad] + + # Force grad_out to be contiguous. This is because at runtime, + # grad_out could have different strides than fw_outs. So, we + # force the grad_outs to be contiguous for both tracing and + # runtime. + grad_outputs = [grad.contiguous() for grad in grad_outputs] if any( not isinstance(out, torch.Tensor) @@ -213,60 +258,193 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): fw_inputs, grad_outputs, ) - return fw_graph, bw_graph, len(grad_outputs) + return fw_graph, bw_graph, output_metadata + + +def get_output_metadata(subgraph, operands): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + # args are functional tensors, generate some example tensors + fw_inputs = pytree.tree_map(_from_fun, operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(fw_inputs) + context = ( + nullcontext() + if fake_mode is None or fake_mode.shape_env is None + else fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + + num_fw_outs = len(fw_outs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. + output_metadata = OutputMetadata() + + output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + output_metadata.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + return output_metadata + + +def trace_joint_graph_as_bwd( + subgraph, num_primals, joint_operands, include_key_set, exclude_key_set +): + """ + Naively trace out a joint graph. This simplifies the reconstruction of joint + graph in the min-cut partitioner later on. + """ + from torch._functorch.aot_autograd import create_joint + + dummy_aot_config = get_dummy_aot_autograd_config() + + if isinstance(subgraph, torch.fx.GraphModule): + + def graph_with_interpreter(*args): + # Running graph with interpreter is needed for propagating the stack_trace + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(subgraph).run(*args) + + fn = graph_with_interpreter + else: + fn = subgraph + + # This joint_fn is inserted as the backward graph as is. This simplifies the + # min-cut partitioner work later on. + # Input signature - (*primals, *tangents) + # Output signature - (*grads, *fw_outs) + # The output signature is deliberately kept grads first and fw_outs second. + # Having grads first makes the min-cut partitioner HOP graph stitching + # easier. + def joint_fn(*primals_and_tangents): + primals = primals_and_tangents[:num_primals] + tangents = primals_and_tangents[num_primals:] + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fn), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + + # return signature is deliberately kept (*grads, *fw_outs). This + # simplifies partitioning work later on. + return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs))) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + joint_operands = [_from_fun(arg) for arg in joint_operands] + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), + ) + with torch.enable_grad(): + return _maybe_reenter_make_fx(joint_fn)(*joint_operands) class InvokeSubgraphAutogradOp(torch.autograd.Function): """ - This autograd function op is to stash the backward graph in the ctx while - running forward. + Saves the subgraph, i.e. original callable, in the forward method. And then + traces out a joint graph in the backward. This delaying of tracing in + backward, also called as lazy backward, ensures that the assumptions about + the grad_out strides and tensor-subclass-ness are already accounted for. """ @staticmethod - def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands): - ctx._fw_graph = fw_graph - ctx._bw_graph = bw_graph + def forward( + ctx, + subgraph, + identifier, + output_metadata, + *operands, + ): + # We want to delay the backward graph construction until the backward. + # So in forward, we just run the fw callable as is. And save all the + # information necessary to construct the backward graph in the ctx. + ctx._subgraph = subgraph ctx._identifier = identifier - ctx._num_fw_outs = num_fw_outs + ctx._output_metadata = output_metadata + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + save_tensors_and_symints_for_backward(ctx, operands) with torch._C._AutoDispatchBelowAutograd(): out = invoke_subgraph( - fw_graph, + subgraph, f"___forward_{identifier}", operands, ) - save_tensors_and_symints_for_backward(ctx, operands) + # Check that None is at expected indexes. + for idx, o in enumerate(out): + if o is None: + assert idx in output_metadata.indexes_with_none + return out @staticmethod - def backward(ctx, *grad_outs): - bw_graph = ctx._bw_graph + def backward( + ctx, + *grad_outs, + ): + subgraph = ctx._subgraph identifier = ctx._identifier + output_metadata = ctx._output_metadata primals = saved_tensors_and_symints(ctx) - num_fw_outs = ctx._num_fw_outs - # While tracing we made the assumption that tangents are contiguous. So, - # force the grad_outs to be contiguous. - contiguous_grad_outs = tuple([o.contiguous() for o in grad_outs]) + # Filter out grads that are None or do not require_grad. This was + # the assumption we made during the tracing of joint_graph. + filtered_grad_outs = [] + for idx, o in enumerate(grad_outs): + if o is None: + assert idx in output_metadata.indexes_with_none + elif idx in output_metadata.indexes_with_no_grad: + # Deliberately skip over the grad_outs which we know should be + # None because the corresponding fwd_out does not require_grad. + pass + else: + filtered_grad_outs.append(o) + filtered_grad_outs = tuple(filtered_grad_outs) # bw_graph is a joint graph with signature (*primals_and_tangents) and # returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs # to extract the grads. - primals_and_tangents = primals + contiguous_grad_outs - grads = invoke_subgraph( - bw_graph, f"___backward_{identifier}", primals_and_tangents - )[:-num_fw_outs] - return None, None, None, None, *grads + primals_and_tangents = primals + filtered_grad_outs + # Check if we have already traced the bwd subgraph. + bw_graph = None + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + bw_graph = invoke_subgraph_cache.get_lazy_bwd_entry(identifier) -@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) -def _(subgraph, identifier, operands): - from torch.utils._python_dispatch import _get_current_dispatch_mode + if bw_graph is None: + bw_graph = trace_joint_graph_as_bwd( + subgraph, + len(primals), + primals_and_tangents, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + ) - mode = _get_current_dispatch_mode() - assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return subgraph(*operands) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_lazy_bwd_entry(identifier, bw_graph) + + grads = invoke_subgraph( + bw_graph, f"___backward_{identifier}", primals_and_tangents + )[: -output_metadata.num_fw_outs] + return None, None, None, *grads @invoke_subgraph.py_impl(DispatchKey.Autograd) @@ -293,11 +471,11 @@ def _(subgraph, identifier, operands): ): return saved_autograd_fn(*operands) - fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands) + output_metadata = get_output_metadata(subgraph, operands) def autograd_fn_callable(*args): return InvokeSubgraphAutogradOp.apply( - fw_graph, bw_graph, identifier, num_fw_outs, *args + subgraph, identifier, output_metadata, *args ) # Save the autograd_fn_callable in the dispatch set cache. @@ -307,6 +485,15 @@ def autograd_fn_callable(*args): return autograd_fn_callable(*operands) +@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) +def _(subgraph, identifier, operands): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return subgraph(*operands) + + @invoke_subgraph.py_functionalize_impl def _(ctx, subgraph, identifier, operands): unwrapped_operands = ctx.unwrap_tensors(operands) @@ -335,6 +522,18 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands): if graph is None: graph = reenter_make_fx(subgraph)(*operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(operands) + insert_deferred_runtime_asserts( + graph, + fake_mode.shape_env, + "invoke_subgraph_proxy_torch_dispatch_mode", + export=True, + ) + graph.recompile() + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") proxy_mode.tracer.root.register_module(qualname, graph) diff --git a/torch/_higher_order_ops/schema.py b/torch/_higher_order_ops/schema.py new file mode 100644 index 000000000000..1cf4e9a5032c --- /dev/null +++ b/torch/_higher_order_ops/schema.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +# Below is an implementation of generating FunctionSchema from example values. +# This is helpful for generating FunctionSchema for HigherOrderOperator, where +# we don't have a function to inspect and each call of the higher order operator +# would have different schema. +@dataclass(frozen=True) +class HopArgumentInfo: + # Could give a name to the operand by default it's empty string. + name: str + example_value: Any + # Provide an default_value + default_value: Any + # Whether this arugment gets mutated in the hop subgraph. + # For output, this should always be False + is_mutated: bool + + +class HopArgumentInfoGen: + @staticmethod + def from_example( + example_value: Any, + *, + name: str = "", + default_value: Optional[Any], + is_mutated: bool = False, + ) -> HopArgumentInfo: + if default_value is not None: + assert type(example_value) == type(default_value) + return HopArgumentInfo( + name=name, + example_value=example_value, + default_value=default_value, + is_mutated=is_mutated, + ) + + +class CTypeGen: + convert_to_base_ty = { + int: torch._C.IntType.get(), + float: torch._C.FloatType.get(), + str: torch._C.StringType.get(), + bool: torch._C.BoolType.get(), + } + + # should return torch._C.JitType but that annotation is busted + @staticmethod + def from_example(obj: Any) -> Any: + import torch + + if isinstance(obj, torch.fx.GraphModule): + return torch._C.AnyType.get() + return torch._C._jit_try_infer_type(obj).type() + + +class CArgumentGen: + @staticmethod + def from_hop_argument_info( + arg_idx: int, arg_info: HopArgumentInfo, is_output: bool = False + ) -> Any: + typ = CTypeGen.from_example(arg_info.example_value) + if is_output: + return torch._C.Argument("", typ, None, None, False, None) + + alias_set = set({f"alias::a{arg_idx}"}) if arg_info.is_mutated else set() + alias_info = torch._C._AliasInfo(arg_info.is_mutated, alias_set, alias_set) # type: ignore[attr-defined] + return torch._C.Argument( + arg_info.name, typ, None, arg_info.default_value, False, alias_info + ) + + +class CFunctionSchemaGen: + """ + Note: [HigherOrderOperator schema generation] + Each invocation of a HigherOrderOperator will have a different schema. + For example, the schema of torch.cond varies depending on the true_fn and + false_fn. So we need a way to generate the schema for each invocation of a HOP. + + We want to enforce the following invariants for HOP's schema: + 1. Flattened inputs. There should be no pytree structure in it. + 2. Flattened outputs. Note even if the hop returns a single value, it should be wrapped as a tuple. + 3. No aliasing. This includes inp-inp aliasing, inp-out aliasing and out-out aliasing. + + By enforcing these invariants, we could make HOP's schema meets the requirement of schema parser + and makes hop easier to handle downstream. For example, suppose we have an invoke_quant_test HOP: + + class GraphModule(torch.nn.Module): + def forward(self, l_x_, l_y_): + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_, l_y_): + add_ = l_x_.add_(1) + matmul = l_x_ @ l_y_ + sin = matmul.sin() + child = sin.cos() + child_1 = l_x_ + l_y_ + child_2 = l_x_ - l_y_ + child_3 = l_x_ @ l_y_ + return (child, child_1, child_2, child_3) + + By encoding the inputs of hop into a list of HopArgumentInfo and output as a single HopArgumentInfo, + we would get the following schema: + invoke_quant_test(Any arg0, Tensor(!) arg1, Tensor arg2, str scheme="\\"nf4\\"") -> (Tensor, Tensor, Tensor, Tensor) + """ + + @staticmethod + def from_hop_argument_info( + op_name: str, + inp_argument_info: list[HopArgumentInfo], + out_argument_info: HopArgumentInfo, + ) -> Any: + args = [] + for i, arg_info in enumerate(inp_argument_info): + args.append(CArgumentGen.from_hop_argument_info(i, arg_info)) + + # NOTE: we want the output to always be a single argument with torch._C.TupleType. + assert isinstance( + out_argument_info.example_value, tuple + ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" + assert ( + not out_argument_info.is_mutated + ), "out_argument_info.is_mutated should always be set to False." + rets = None + if len(out_argument_info.example_value) == 1: + rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)] + else: + rets = [ + CArgumentGen.from_hop_argument_info( + i, + HopArgumentInfoGen.from_example( + name=f"out{i}", + example_value=val, + default_value=None, + is_mutated=False, + ), + is_output=True, + ) + for i, val in enumerate(out_argument_info.example_value) + ] + + return torch._C.FunctionSchema( + op_name, + "", + args, + rets, + False, + False, + ) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 4fb0ca60098b..fd3f327a68ae 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -15,6 +15,7 @@ disable_proxy_modes_tracing, make_fx, ) +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.multiprocessing.reductions import StorageWeakRef @@ -293,6 +294,10 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): pre_dispatch=pre_dispatch, _error_on_data_dependent_ops=False, )(*inputs) + if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts( + gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True + ) return gm @@ -698,6 +703,25 @@ def check_input_alias_and_mutation( gm: torch.fx.GraphModule, fake_args: list[FakeTensor], ) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]: + ( + mutated_inputs, + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + ) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1] + return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map + + +def check_input_alias_and_mutation_return_ouputs( + gm: torch.fx.GraphModule, + fake_args: list[FakeTensor], +) -> tuple[ + list[int], + dict[int, int], + dict[int, int], + dict[int, int], + Union[tuple[Any, ...], list[Any]], +]: with disable_proxy_modes_tracing(): """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias in the graph module gm. It checks whether input tensor versions have @@ -760,7 +784,13 @@ def _tensor_storage(t) -> StorageWeakRef: for i, inp in enumerate(cloned) if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map } - return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map + return ( + mutated_inputs, + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + outputs, + ) registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {} @@ -816,4 +846,10 @@ def __repr__(self): return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})" def __call__(self, *args, **kwargs): + if isinstance(self.subgraph, torch.fx.GraphModule): + # Running graph with interpreter is needed for propagating the stack_trace + with fx_traceback.preserve_node_meta(): + return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)( + *args, **kwargs + ) return self.ctx.functionalize(self.subgraph)(*args, **kwargs) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index a2acd6570a20..f9d05e24fff7 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -283,12 +283,14 @@ def aot_compile( flat_example_inputs, options = _aoti_flatten_inputs( gm, args, kwargs, options=options ) + from torch._export.utils import _compiling_state_context - return compile_fx_aot( - gm, - flat_example_inputs, # type: ignore[arg-type] - config_patches=options, - ) + with _compiling_state_context(): + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) def list_mode_options( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e331badcff34..8a01ae5d6429 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1047,12 +1047,17 @@ def iterate_over_candidates() -> Generator[ triton_bundler_meta = TritonBundler.read_and_emit(bundle) if (meta := triton_bundler_meta) is not None: cache_info["triton_bundler_meta"] = str(meta) - # TODO: Clean up autograd cache integration CompileEventLogger.try_add_pt2_compile( "inductor_compile", cached_kernel_names=meta.cached_kernel_names ) + CompileEventLogger.try_add_pt2_compile( + "AOTAutogradCache.inductor_load", + cached_kernel_names=meta.cached_kernel_names, + ) if len(meta.cached_kernel_names) > 0: - CompileEventLogger.increment_toplevel("num_triton_bundles") + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, "num_triton_bundles" + ) try: artifact_path = graph.after_deserialization(constants) @@ -1306,17 +1311,22 @@ def load_with_key( cache_info["cache_state"] = "hit" if remote_cache: # Count remote cache hit stats - CompileEventLogger.increment_toplevel( - "inductor_fx_remote_cache_hit_count" + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_hit_count", ) - CompileEventLogger.add_to_set_toplevel( - "inductor_fx_remote_cache_hit_keys", key + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_hit_keys", + key, ) if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns - CompileEventLogger.increment_toplevel( - "distributed_ephemeral_timeout_us", time_saved_ns // 1000 + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "distributed_ephemeral_timeout_us", + time_saved_ns // 1000, ) if ( ephemeral_increase @@ -1326,11 +1336,14 @@ def load_with_key( else: if remote_cache: # Count remote cache miss stats - CompileEventLogger.increment_toplevel( - "inductor_fx_remote_cache_miss_count" + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_miss_count", ) - CompileEventLogger.add_to_set_toplevel( - "inductor_fx_remote_cache_miss_keys", key + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_miss_keys", + key, ) log.info("fx graph cache miss for key %s", key) counters["inductor"]["fxgraph_cache_miss"] += 1 @@ -1451,6 +1464,10 @@ def compile( extra=cpp_command, specified_dir=specified_output_path, ) + kernel_code = ( + f"// Triton kernels are embedded as comments in {wrapper_path}\n" + + kernel_code + ) _, kernel_path = write( kernel_code, "kernel.cpp", @@ -3160,6 +3177,7 @@ class CUDACodeCache: class CacheEntry: input_path: str output_path: str + error_json: Optional[str] = None cache: dict[str, CacheEntry] = {} cache_clear = staticmethod(cache.clear) @@ -3196,6 +3214,14 @@ def compile( lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if os.path.exists(output_path + ".error"): + with open(output_path + ".error", encoding="utf-8") as fh: + error_json = fh.read() + cmd_parts, error_output = json.loads(error_json) + cls.cache[key] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + raise exc.CUDACompileError(cmd_parts, error_output) if not os.path.exists(output_path): cmd = cuda_compile_command( [input_path], output_path, dst_file_ext, extra_args @@ -3211,6 +3237,14 @@ def compile( cmd_parts, stderr=subprocess.STDOUT, env=os.environ ) except subprocess.CalledProcessError as error: + error_json = json.dumps( + [cmd_parts, error.output.decode("utf-8")] + ) + cls.cache[key] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + with open(output_path + ".error", "w", encoding="utf-8") as fh: + fh.write(error_json) raise exc.CUDACompileError(cmd_parts, error.output) from error end_time = time() log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" @@ -3220,8 +3254,12 @@ def compile( "CUDA Compilation skipped: %s since output already exists", input_path, ) - cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) - + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, None) + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key] + if cache_entry.error_json is not None: + # Restore cached Exception and raise it as if we had compiled + cmd_parts, error_output = json.loads(cache_entry.error_json) + raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) return (cls.cache[key].output_path, key, input_path) @classmethod diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 7fce40e869ef..417e215d4f57 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1397,6 +1397,8 @@ def output(self, name: str) -> str: return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name: str, output_name: str) -> None: + if input_name in V.graph.unaligned_buffers: + V.graph.unaligned_buffers.add(output_name) assert output_name not in self.inplace_buffers if input_name in self.inplace_buffers: buf = self.inplace_buffers[input_name] diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 60d151d59b10..b05e20ada3ee 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2990,7 +2990,9 @@ def store_reduction(self, name, index, value): else: # Vertical reduction if out_dtype != dtype: - converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + converted_value = ( + f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}" + ) if out_dtype == torch.bool: convert = f"{value}.template cast()" else: @@ -5457,20 +5459,21 @@ def max_parallel_depth(self): start_depth = 0 max_depth = 0 is_reduction = self.loops[0].is_reduction - loop_sizes = sympy.Integer(1) + num_steps = sympy.Integer(1) for loop in self.loops: if loop.is_reduction != is_reduction: break - loop_sizes = loop_sizes * loop.size + num_steps = num_steps * FloorDiv(loop.size, loop.steps) max_depth += 1 - # When the range of the first inner loop is much larger than the range of all outer loops, - # change `start_depth` to the first inner loop and recalculate `max_depth`. + # When the number of steps of the first inner loop is much larger than the number of steps of + # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. if ( max_depth < len(self.loops) - and isinstance(loop_sizes, sympy.Integer) + and isinstance(num_steps, sympy.Integer) and isinstance(self.loops[max_depth].size, sympy.Integer) - and loop_sizes * 300 < self.loops[max_depth].size + and num_steps * 300 + < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps) ): start_depth = max_depth max_depth = 0 diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 67a1b08cb5c4..77cf270ad894 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1413,7 +1413,7 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): int64_t ldb, int64_t ldc, int64_t q_group_size, - const bfloat16* {{restrict_keyword}} ScaleAndZeros, + const at:BFloat16* {{restrict_keyword}} ScaleAndZeros, int64_t lds, // leading dimension of ScaleAndZeros int64_t k_start) { constexpr int BLOCK_K = {{block_k}}; @@ -1551,7 +1551,7 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): def get_kernel_extra_args_declare(self) -> str: return ( "const int64_t q_group_size,\n" - " const bfloat16* __restrict__ ScaleAndZeros,\n" + " const at:BFloat16* __restrict__ ScaleAndZeros,\n" " const int64_t lds,\n" " int64_t k_start," ) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 8254363cbdcb..415c979c0eda 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -43,12 +43,6 @@ #include #endif -typedef at::Half half; -typedef at::BFloat16 bfloat16; - -typedef at::Float8_e4m3fn float8_e4m3fn; -typedef at::Float8_e5m2 float8_e5m2; - template struct Welford { T mean = T(0); @@ -86,7 +80,7 @@ struct WelfordHelper { std::vector> welford_stk; uint64_t depth; // depth of welford_stk. uint64_t num_chunks; // number of chunks stored in welford_stk. - WelfordHelper() {} + WelfordHelper() = default; WelfordHelper(uint64_t N) { uint64_t m = (N + kChunkSize - 1) / kChunkSize; //div up depth = m > 0 ? ceil(log2(m)) : 0; @@ -635,7 +629,7 @@ inline int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_ template struct AsIntegerType { typedef T type; }; template <> struct AsIntegerType { typedef uint32_t type; }; template <> struct AsIntegerType { typedef uint64_t type; }; -template <> struct AsIntegerType { typedef uint16_t type; }; +template <> struct AsIntegerType { typedef uint16_t type; }; template typename std::enable_if_t, T> diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 8362d052f773..8707ee4d9bb2 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -30,7 +30,7 @@ DTYPE_TO_CPP = { torch.float32: "float", torch.float64: "double", - torch.float16: "half", + torch.float16: "at::Half", torch.int64: "int64_t", torch.int32: "int32_t", torch.int16: "int16_t", @@ -40,14 +40,14 @@ torch.uint16: "uint16_t", torch.uint8: "uint8_t", torch.bool: "bool", - torch.bfloat16: "bfloat16", - torch.complex32: "c10::complex", - torch.complex64: "c10::complex", - torch.complex128: "c10::complex", - torch.float8_e4m3fn: "float8_e4m3fn", - torch.float8_e5m2: "float8_e5m2", - torch.float8_e4m3fnuz: "float8_e4m3fnuz", - torch.float8_e5m2fnuz: "float8_e5m2fnuz", + torch.bfloat16: "at::BFloat16", + torch.complex32: "at::complex", + torch.complex64: "at::complex", + torch.complex128: "at::complex", + torch.float8_e4m3fn: "at::Float8_e4m3fn", + torch.float8_e5m2: "at::Float8_e5m2", + torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz", } DTYPE_TO_ATEN = { diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1ea1459659ae..6a0a4b4ba888 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -403,6 +403,19 @@ def gen_check(handle_kind, idx, name, tensor): """ ) + # Create a separate function for each input check to avoid "too big to optimize" error + for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): + self.prefix.splice( + f""" + AOTI_NOINLINE static void check_input_{idx}( + AtenTensorHandle* input_handles + ) {{ + """ + ) + with self.prefix.indent(): + gen_check("input_handles", idx, name, tensor) + self.prefix.writeline("}") + # force noinline to avoid any potential compilation slowdown due to aggressive # inline done by the host compiler self.prefix.splice( @@ -422,8 +435,8 @@ def gen_check(handle_kind, idx, name, tensor): """ ) with self.prefix.indent(): - for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): - gen_check("input_handles", idx, name, tensor) + for idx in range(len(V.graph.graph_inputs)): + self.prefix.writeline(f"check_input_{idx}(input_handles);") self.prefix.writeline("}") def write_wrapper_decl(self): @@ -475,13 +488,10 @@ def write_wrapper_decl(self): DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor ) { + __check_inputs_outputs(input_handles, output_handles); """ self.generate_input_output_runtime_checks() - run_impl_proto += """ - __check_inputs_outputs(input_handles, output_handles); - """ - self.prefix.splice(run_impl_proto) else: # cpp entry function for JIT with cpp wrapper @@ -1107,11 +1117,25 @@ def generate_c_shim_extern_kernel_call( debug_printer_manager.set_printer_args( debug_args if debug_args is not None else args, kernel, None, None, "extern" ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel, device) - self.writeline( + shim_fn_codes = ( f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" ) + if enable_kernel_profile: + shim_fn_codes = textwrap.dedent( + f""" + {{ + RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); + {shim_fn_codes} + }} + """ + ) + self.writeline(shim_fn_codes) def generate_c_shim_extern_kernel_alloc( self, extern_kernel: ir.ExternKernelAlloc, args: list[str] @@ -1558,17 +1582,12 @@ def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]: return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs def create_new_tensor_handle() -> tuple[str, list[str]]: - # TODO (benjaminglass1): uncomment this and remove the call to - # create_reinterpret_view after the AOTI forwards compatibility window has - # passed. - # - # tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" - # tmp_call_strs = [ - # f"AtenTensorHandle {tmp_AtenTensorHandle};", - # f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", - # ] - # return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs - return create_reinterpret_call(), [] + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs if ( size == data.layout.size @@ -2041,11 +2060,11 @@ def load_custom_op_wrapper(self): lines = """ RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); -if (codecache_module.get() == NULL) { +if (!codecache_module) { throw std::runtime_error("Failed to load torch._inductor.codecache"); } custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); -if (custom_op_wrapper.get() == NULL) { +if (!custom_op_wrapper) { throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); }""" @@ -2070,11 +2089,6 @@ def generate_float_value(self, val): def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): def generate_py_arg_inner(lines, raw_arg, arg_type): - def add_py_newref(): - if sys.version_info < (3, 10): - # Py_NewRef is only available since Python 3.10 - self.include_extra_header("torch/csrc/utils/pythoncapi_compat.h") - def handle_scalar(scalar): if isinstance(scalar, int): return f"PyLong_FromLongLong({scalar})" @@ -2135,24 +2149,13 @@ def handle_scalar(scalar): # torch/_prims_common/__init__.py return handle_scalar(raw_arg) elif isinstance(raw_arg, torch.device): - # device - self.include_extra_header("torch/csrc/Device.h") device_str, device_index = self.codegen_device(raw_arg).split(", ") return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))" elif isinstance(raw_arg, torch.dtype): - # dtype - add_py_newref() - self.include_extra_header("torch/csrc/DynamicTypes.h") return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" elif isinstance(raw_arg, torch.layout): - # memory layout - add_py_newref() - self.include_extra_header("torch/csrc/DynamicTypes.h") return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))" elif isinstance(raw_arg, torch.memory_format): - # memory_format - add_py_newref() - self.include_extra_header("torch/csrc/utils/tensor_memoryformats.h") return ( "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast(" f"{self.codegen_memory_format(raw_arg)})))" @@ -2204,7 +2207,7 @@ def generate_fallback_kernel_with_runtime_lookup_jit( lines = textwrap.dedent( f""" RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1})); - if ({py_args_var}.get() == NULL) {{ + if (!{py_args_var}) {{ throw std::runtime_error("PyTuple_New {py_args_var} failed"); }} PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); @@ -2224,7 +2227,7 @@ def generate_fallback_kernel_with_runtime_lookup_jit( f""" // Call the custom op in Python RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); - if (py_{buf_name}.get() == NULL) {{ + if (!py_{buf_name}) {{ if (PyErr_Occurred()) {{ return; }} diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index a3e472834518..67ea2e2166e8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -835,16 +835,12 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: if (name := data.get_name()) in self.stack_allocated_buffers: return name, [] - # TODO (benjaminglass1): uncomment this and remove create_reinterpret_view - # after the AOTI forwards compatibility window has passed. - # - # tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" - # tmp_call_strs = [ - # f"AtenTensorHandle {tmp_AtenTensorHandle};", - # f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", - # ] - # return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs - return create_reinterpret_call(), [] + tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs if ( size == data.layout.size diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 56f0941715a9..e0f0726e7a89 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -54,6 +54,7 @@ class DeferredTritonCallWrapper: wrapper_name: str kernel_name: str + kernel_name_to_body: dict[str, str] arg_types: list[Any] def generate(self, wrapper: CppWrapperGpu): @@ -122,6 +123,11 @@ def generate(self, wrapper: CppWrapperGpu): ) prefix.writeline("){") with prefix.indent(): + if V.graph.aot_mode: + # Emit the original Triton kernel for debugging purposes + prefix.writeline("/*") + prefix.splice(self.kernel_name_to_body[self.kernel_name]) + prefix.writeline("*/") self.generate_grid(prefix, inductor_meta, params) self.generate_load_kernel(prefix, kernel_var_name, params) self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) @@ -205,6 +211,7 @@ def __init__(self) -> None: self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() + self._kernel_name_to_body: dict[str, str] = {} self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" @@ -296,6 +303,7 @@ def define_kernel( cpp_definition: Optional[str] = None, ): if gpu: + self._kernel_name_to_body[kernel_name] = kernel_body if config.triton.autotune_at_compile_time: # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.define_kernel( @@ -432,7 +440,7 @@ def process_args(arg, arg_type, arg_signature=None): is not None ): global_scratch_def, global_scratch_var = global_scratch - code.writeline(global_scratch_def) + code.writeline(maybe_hipify_code_wrapper(global_scratch_def)) new_args.append(f"&{global_scratch_var}") return ", ".join(new_args) @@ -502,7 +510,10 @@ def generate_kernel_call( wrapper_name = f"call_{kernel_name}" if wrapper_name not in self._triton_call_wrappers: self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper( - wrapper_name, kernel_name, arg_types + wrapper_name, + kernel_name, + self._kernel_name_to_body, + arg_types, ) call_args.append(stream) if V.graph.aot_mode: diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 9f4469698207..28dbbfb446ba 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -96,6 +96,8 @@ def _print_floor(self, expr): assert len(expr.args) == 1 return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + _print_FloorToInt = _print_floor + def _print_Trunc(self, expr): assert len(expr.args) == 1 return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") @@ -140,39 +142,42 @@ def _print_Abs(self, expr): def _print_OpaqueUnaryFn_cos(self, expr): assert len(expr.args) == 1 - return f"hl.cos(({self._print(expr.args[0])})" + return f"hl.cos({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_cosh(self, expr): assert len(expr.args) == 1 - return f"hl.cosh(({self._print(expr.args[0])})" + return f"hl.cosh({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_acos(self, expr): assert len(expr.args) == 1 - return f"hl.acos(({self._print(expr.args[0])})" + return f"hl.acos({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_sin(self, expr): assert len(expr.args) == 1 - return f"hl.sin(({self._print(expr.args[0])})" + return f"hl.sin({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_sinh(self, expr): assert len(expr.args) == 1 - return f"hl.sinh(({self._print(expr.args[0])})" + return f"hl.sinh({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_asin(self, expr): assert len(expr.args) == 1 - return f"hl.asin(({self._print(expr.args[0])})" + return f"hl.asin({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_tan(self, expr): assert len(expr.args) == 1 - return f"hl.tan(({self._print(expr.args[0])})" + return f"hl.tan({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_tanh(self, expr): assert len(expr.args) == 1 - return f"hl.tanh(({self._print(expr.args[0])})" + return f"hl.tanh({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 - return f"hl.atan(({self._print(expr.args[0])})" + return f"hl.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr): + raise NotImplementedError("log2") def _print_FloorDiv(self, expr): if expr.is_integer: @@ -453,6 +458,10 @@ def pow(a, b): def log(x): return f"hl.log({x})" # hl.fast_log fails accuracy + @staticmethod + def log2(x): + raise NotImplementedError("log2") + @staticmethod def isinf(x): # workaround https://github.com/halide/Halide/issues/8309 diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index a5ea219eb037..c83c572e6e94 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -13,7 +13,7 @@ from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_ from torch.utils._sympy.value_ranges import ValueRanges -from ..utils import get_bounds_index_expr, get_kernel_metadata +from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata from ..virtualized import ops, OpsWrapper, V from .common import ( CSEVariable, @@ -298,6 +298,12 @@ def atan2(x: CSEVariable, y: CSEVariable) -> str: def sqrt(x: CSEVariable) -> str: return f"metal::sqrt({x})" + @staticmethod + def neg(x: CSEVariable) -> str: + # TODO: Does it rely on undefined behavior? + # If so, add special logic for unsigned types + return f"static_cast(-{x})" + @staticmethod def rsqrt(x: CSEVariable) -> str: return f"metal::rsqrt({x})" @@ -443,6 +449,10 @@ def chebyshev_polynomial_v(x: CSEVariable, n: CSEVariable) -> str: def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str: return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})" + @staticmethod + def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str: + return f"c10::metal::hermite_polynomial_h_forward({x}, {n})" + MetalOverrides._initialize_pointwise_overrides("mps") @@ -452,6 +462,7 @@ class MetalKernel(SIMDKernel): suffix = ";" newvar_prefix = "auto " max_threadgroup_size = 1024 + simd_group_size = 32 pexpr = PythonPrinter().doprint sexpr = MetalExprPrinter().doprint kexpr = sexpr @@ -487,22 +498,33 @@ def store( else: self.stores.writeline(DeferredLine(name, line)) - def _new_accvar( + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + reduction_dim = next(t for t in self.range_trees if t.is_reduction) + # Only one thread in the reduction group needs to store the results + line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});" + line = f"if ({reduction_dim.name} == 0) {line}" + self.stores.writeline(DeferredLine(name, line)) + + def _new_idxvar( self, dtype: torch.dtype, elem_count: Optional[int] = None, + default_value: Optional[Any] = None, + is_threadgroup: bool = True, bounds: ValueRanges[Any] = ValueRanges.unknown(), ) -> CSEVariable: var_name = f"tmp_acc_{next(self.acc_var_ids)}" var = V.kernel.create_cse_var(var_name, bounds, dtype) + var_def = "threadgroup " if is_threadgroup else "" + var_def += f"{self.dtype_to_str(dtype)} {var_name}" if elem_count: - self.indexing_code.writeline( - f"threadgroup {self.dtype_to_str(dtype)} {var_name}[{elem_count}];" - ) - else: - self.indexing_code.writeline( - f"threadgroup {self.dtype_to_str(dtype)} {var_name};" - ) + var_def += f"[{elem_count}]" + if default_value is not None: + var_def += f" = {default_value}" + self.indexing_code.writeline(var_def + self.suffix) return var def reduction( @@ -513,10 +535,20 @@ def reduction( value: Union[CSEVariable, tuple[CSEVariable, ...]], ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: """Codegen a reduction operation""" - reduction_dim = next(t for t in self.range_trees if t.is_reduction) - acc_buf_size = min(reduction_dim.numel, self.max_threadgroup_size) + # Establish reduction buffer size and index expression + reduction_idx = "" + acc_buf_size = 1 + for rd in self.range_trees: + if not rd.is_reduction: + continue + if reduction_idx: + reduction_idx += " + " + reduction_idx += f"{rd.name} * {acc_buf_size}" + acc_buf_size *= rd.numel + acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) + if reduction_type == "any": - acc = self._new_accvar(dtype) + acc = self._new_idxvar(dtype) self.indexing_code.writeline(f"{acc} = false;") self.indexing_code.writeline( "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" @@ -533,27 +565,28 @@ def reduction( ) return acc if reduction_type in ["prod", "sum"]: - acc_buf = self._new_accvar(src_dtype, acc_buf_size) - if self.multistage_reduction: + acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] + acc_buf = self._new_idxvar( + acc_dtype, ceildiv(acc_buf_size, self.simd_group_size) + ) + if not self.multistage_reduction: + val = value + else: default_val, reduction_op = ( (0, "+") if reduction_type == "sum" else (1, "*") ) - self.indexing_code.writeline( - f"{acc_buf}[{reduction_dim.name}] = {default_val};" - ) - self.compute.splice( - f"{acc_buf}[{reduction_dim.name}] {reduction_op}= {value};" + val = self._new_idxvar( + acc_dtype, default_value=default_val, is_threadgroup=False ) - else: - self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") + self.compute.splice(f"{val} {reduction_op}= {value};") return self.cse.generate( self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["max", "min", "argmin", "argmax"]: - acc_buf = self._new_accvar(src_dtype, acc_buf_size) - acc_thread_var = f"{acc_buf}[{reduction_dim.name}]" + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" src_metal_type = DTYPE_TO_METAL[src_dtype] if not self.multistage_reduction: self.compute.splice( @@ -572,9 +605,9 @@ def reduction( idx_var = next( t for t in self.range_tree_nodes.values() if t.is_reduction ) - idx_acc_buf = self._new_accvar(torch.long, acc_buf_size) + idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size) cmp_op = ">" if reduction_type == "argmax" else "<" - idx_thread_var = f"{idx_acc_buf}[{reduction_dim.name}]" + idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" self.indexing_code.splice(f"{idx_thread_var} = -1;") self.compute.splice(f""" if ({value} {cmp_op} {acc_thread_var}) {{ @@ -599,8 +632,8 @@ def reduction( assert not self.multistage_reduction, ( f"Multistage reduction not yet supported for {reduction_type}" ) - acc_buf = self._new_accvar(src_dtype, acc_buf_size) - self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( self.compute, f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index f33c39623acb..db8091c78648 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -18,6 +18,7 @@ import torch import torch._logging +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.immutable_collections import immutable_dict from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing @@ -1764,7 +1765,11 @@ def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: return tilings pointwise_ranges, reduction_ranges = node.get_ranges() - if len(pointwise_ranges) <= 1 and len(reduction_ranges) <= 1: + if ( + len(pointwise_ranges) <= 1 + and len(reduction_ranges) <= 1 + or free_unbacked_symbols(pointwise_ranges + reduction_ranges) + ): return [] # Tile either pointwise or reduction dims. @@ -2013,7 +2018,11 @@ def convert_tiling_to_3d( ) -> Optional[dict[str, sympy.Expr]]: a0, a1 = tiling0["x"], tiling0.get("y", 1) b0, b1 = tiling1["x"], tiling1.get("y", 1) - if V.graph.sizevars.size_hint(a1 - b1) == 0: + + if ( + free_unbacked_symbols([a1, b1]) + or V.graph.sizevars.size_hint(a1 - b1) == 0 + ): return None if V.graph.sizevars.size_hint(a1 - b1) < 0: # swap so a0 is bigger diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2fe5c304c24a..b125efd6bdbf 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -605,7 +605,12 @@ def _print_FloatPow(self, expr: sympy.Expr) -> str: f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" ) - _print_PowByNatural = _print_FloatPow + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + if expr.args[0].is_Integer: + return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})" + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) def _print_Where(self, expr: sympy.Expr) -> str: c = self.doprint(expr.args[0]) @@ -678,6 +683,10 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return ( @@ -924,7 +933,11 @@ def _shaped_constant(value, dtype, shape): # NOTE: We use a tensor here in order to get the expected type. # Otherwise, e.g. float64 constants would be trunctated to float32. - return f"tl.full({shape}, {triton_val}, {triton_type})" + if value < 0 and not dtype.is_signed: + triton_signed_type = f"tl.{triton_type[4:]}" + return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" + else: + return f"tl.full({shape}, {triton_val}, {triton_type})" @classmethod def constant(cls, value, dtype): @@ -2543,17 +2556,19 @@ def _mask_value(value, default) -> CSEVariable: masked_value = _mask_value(value, default) if reduction_type in ("argmax", "argmin"): + accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() accumulator_index = str( self.cse.generate( self.compute, f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", - dtype=V.kernel.get_index_dtype_as_torch_dtype(), + dtype=accumulator_dtype, ) ) root_op = {"argmax": "max", "argmin": "min"}[reduction_type] final_argreduce( self.compute, result_var, masked_value, accumulator_index ) + result_var.dtype = accumulator_dtype elif reduction_type == "welford_reduce": if self.cooperative_reduction: # cooperative reductions require full welford for correctness diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 2d5f6a55b4cc..ddd4ec515516 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -122,9 +122,14 @@ def signature_to_meta( def is_unaligned_buffer(arg: TensorArg): buf_name = arg.buffer + if buf_name in V.graph.unaligned_buffers: + return True + if buf_name in V.graph.graph_inputs: # See Note: [Input Alignment handling in Inductor] - return buf_name not in V.graph.aligned_inputs + # For graph inputs that is not recorded in V.graph.unaligned_buffers, + # we know for sure the tensor is aligned. + return False if buf_name in V.graph.constants: # all constants are assumed to be aligned diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index bfb78977dba4..c10831bd8278 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -78,12 +78,13 @@ pexpr = PythonPrinter().doprint -ReuseKey = tuple[torch.device, torch.dtype, str] +ReuseKey = tuple[torch.device, torch.dtype, str, bool] BufferLike = Union[ir.Buffer, WorkspaceArg] def buffer_reuse_key(node: BufferLike) -> ReuseKey: storage_size = V.graph.get_allocation_storage_size(node) + alignment = node.get_name() not in V.graph.unaligned_buffers return ( node.get_device_or_error(), node.get_dtype(), @@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey: # for s0 for s1, just because they happen to share the same # size hint sympy_str(V.graph.sizevars.simplify(storage_size)), + alignment, ) @@ -620,6 +622,7 @@ def __init__(self): # Map key is the kernel argument name; value is a tuple of the resulting example # tensor name with the kernel where that tensor was most recently used. self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} + self.kernel_autotune_tmp_arg_idx: int = 0 # If the generated source code is exactly the same, reuse the # pre-existing kernel for it self.src_to_kernel: dict[str, str] = {} @@ -1991,7 +1994,7 @@ def wrap_arg(arg): return [wrap_arg(arg) for arg in call_args] - def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): + def generate_example_arg_value(self, arg, arg_type, raw_arg=None): if isinstance(arg_type, torch_dtype): if isinstance(raw_arg, ir.TMADescriptor): # first we generate the underlying buffer @@ -2004,8 +2007,9 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): assert raw_arg is not None, ( "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" ) - buf_name = f"tmp_arg_{index}" + buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}" buf = raw_arg + self.kernel_autotune_tmp_arg_idx += 1 size = tuple( V.graph.sizevars.atomically_apply_size_hint( @@ -2182,13 +2186,13 @@ def get_autotune_deletion_call() -> str: arg_str = arg elif arg not in self.kernel_autotune_example_args: arg_str = self.generate_example_arg_value( - arg, arg_type, raw_arg, i + arg, arg_type, raw_arg ) else: arg_str = self.kernel_autotune_example_args[arg][0] self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) else: - arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i) + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) all_args.append(arg_str if key is None else f"{key}={arg_str}") self.kernel_autotune_calls.writeline( diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index a94d224a2d8a..ba77e78240a2 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -620,15 +620,6 @@ def compile_fx_inner( dynamo_compile_column_us="inductor_cumulative_compile_time_us", ) ) - # NB: Why is this the dynamo_compile counter? The rule here is that - # if it gets an entry in the dynamo_compile table, we also want to - # tick up the wait counter. We have to displeasingly manually trigger - # the counter here because we may dropped into compile_fx directly - # from lazy backwards compilation. - stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) - stack.enter_context( - _WaitCounter("pytorch.wait_counter.all_compilation_types").guard() - ) if torch._dynamo.callback_handler.prevent_duplicate_callbacks: stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) @@ -660,6 +651,10 @@ def _compile_fx_inner( """ aot_mode: bool = V.aot_compilation + # Clean up Compiled Triton Kernels per inductor compile, as the future objects + # may not be valid for use after they are run/autotuned + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: # trigger the real recompilation for _LazyGraphModule before returning # the forward method. @@ -691,7 +686,6 @@ def _compile_fx_inner( with ( _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _, - _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(), ): use_cache = ( not config.force_disable_caches @@ -876,7 +870,8 @@ def _compile_fx_inner( if log.isEnabledFor(logging.INFO): mm_table_data = [] for key, value in counters["aten_mm_info"].items(): - name, m, n, k = key.split("_") + m, n, k = key.split("_")[-3:] + name = "_".join(key.split("_")[:-3]) mm_table_data.append([name, m, n, k, value]) log.info("Overview info of inductor aten mms: ") log.info( @@ -889,8 +884,8 @@ def _compile_fx_inner( log.info("{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 log.info("-" * 100) - # Clear Compiled Triton Kernels per inductor compile, as the future objects - # may not be valid for use after they are run/autotuned + # Not strictly necessary, but good to clean up straggling futures + # that are unused to reclaim memory. torch._inductor.async_compile.CompiledTritonKernels.cache_clear() _step_logger()( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 10904cd53991..27b77d199f09 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -500,6 +500,9 @@ def use_autoheuristic(name: str) -> bool: # automatically create fallbacks when encountering an unhandled op implicit_fallbacks = True +assume_unaligned_fallback_output = ( + os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" +) # fuse even in cases without common reads aggressive_fusion = False @@ -1129,7 +1132,7 @@ class triton: ) # type: ignore[assignment] # hint to Triton when arguments are divisible by 16 - divisible_by_16 = True + divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1" # Minimum R0_BLOCK to be used for a TritonSplitScanKernel # NOTE: This also indirectly controls the size of workspace buffer required @@ -1165,7 +1168,7 @@ class triton: # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 # with a verison of triton new enough to support TMA enable_persistent_tma_matmul = ( - os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "1") == "1" + os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" ) # Skip L1 cache for buffers that are used only once. Disabled by default skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1" diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index a8f25056dd52..aeef51ae6cc0 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -838,8 +838,13 @@ def _get_python_related_args() -> tuple[list[str], list[str]]: python_include_dirs.append(python_include_path) if _IS_WINDOWS: - python_path = os.path.dirname(sys.executable) - python_lib_path = [os.path.join(python_path, "libs")] + python_lib_path = [ + str( + ( + Path(sysconfig.get_path("include", scheme="nt")).parent / "libs" + ).absolute() + ) + ] else: python_lib_path = [sysconfig.get_config_var("LIBDIR")] diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 68ea4a010e6e..f6ce7e43ad95 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -167,6 +167,9 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]: def check_multiple_devices_or_any_cpu_nodes( device_node_mapping: dict[torch.device, torch.fx.Node], ) -> Optional[str]: + # meta tensors are supported since there is no compute + device_node_mapping.pop(torch.device("meta"), None) + if torch._inductor.config.graph_partition: # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index ddf044ebf1ff..2dd8a47feb4a 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -2,6 +2,7 @@ import functools import logging import math +import operator import sys import typing from typing import Any, Callable, Optional, TypeVar, Union @@ -261,7 +262,9 @@ def bmm( self: torch.Tensor, batch2: torch.Tensor, ) -> torch.Tensor: - if config.coordinate_descent_tuning and self.device.type != "cpu": + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious( batch2.shape[2] == 1 ): @@ -315,7 +318,10 @@ def mm( ) -> torch.Tensor: # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. # todo: Look into why and fix it (hopefully) - if config.coordinate_descent_tuning and self.device.type != "cpu": + + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious( input2.shape[1] == 1 ): @@ -963,38 +969,40 @@ def index_reduce( ) -@register_decomposition(aten.max_pool2d_with_indices) -def max_pool2d_with_indices( +def _max_pool_with_indices( x: torch.Tensor, kernel_size: list[int], - stride: Optional[Union[int, list[int]]] = None, - padding: Union[int, list[int]] = 0, - dilation: Union[int, list[int]] = 1, - ceil_mode: bool = False, + stride: Optional[Union[int, list[int]]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + ceil_mode: bool, + dim: int, ) -> tuple[torch.Tensor, torch.Tensor]: if dilation == 1: - dilation = [1, 1] + dilation = [1] * dim if padding == 0: - padding = [0, 0] + padding = [0] * dim if not stride: stride = kernel_size - kernel_size = pad_listlike(kernel_size, 2) - dilation = pad_listlike(dilation, 2) - padding = pad_listlike(padding, 2) - stride = pad_listlike(stride, 2) + kernel_size = pad_listlike(kernel_size, dim) + dilation = pad_listlike(dilation, dim) + padding = pad_listlike(padding, dim) + stride = pad_listlike(stride, dim) - window_size = kernel_size[0] * kernel_size[1] - # We fallback when the window size is too large + window_size = functools.reduce(operator.mul, kernel_size) + # We fallback when using non-default dilation or when the window size is too large if ( - torch._inductor.lowering.should_fallback_max_pool2d_with_indices(kernel_size) + torch._inductor.lowering.should_fallback_max_pool_with_indices( + kernel_size, n_dim=dim + ) or window_size > torch.iinfo(torch.int8).max ): return NotImplemented - vals, offsets = prims._low_memory_max_pool2d_with_offsets( + vals, offsets = prims._low_memory_max_pool_with_offsets( x, kernel_size, stride, @@ -1002,10 +1010,10 @@ def max_pool2d_with_indices( dilation, ceil_mode, ) - indices = prims._low_memory_max_pool2d_offsets_to_indices( + indices = prims._low_memory_max_pool_offsets_to_indices( offsets, - kernel_size[1], - x.size(-1), + kernel_size, + x.shape[-dim:], stride, padding, dilation, @@ -1013,6 +1021,34 @@ def max_pool2d_with_indices( return vals, indices +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=2 + ) + + +@register_decomposition(aten.max_pool3d_with_indices) +def max_pool3d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=3 + ) + + @register_decomposition(aten.adaptive_max_pool2d) def adaptive_max_pool2d( x: torch.Tensor, output_size: list[int] diff --git a/torch/_inductor/extern_node_serializer.py b/torch/_inductor/extern_node_serializer.py index ffd390152034..19bf39fdd2e7 100644 --- a/torch/_inductor/extern_node_serializer.py +++ b/torch/_inductor/extern_node_serializer.py @@ -1,6 +1,6 @@ import json -from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode @@ -19,6 +19,7 @@ def extern_node_json_serializer( extern_kernel_nodes: list[inductor_ExternKernelNode], ) -> str: serialized_nodes = ExternKernelNodes( - nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes], + protocol="json", ) - return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder, indent=2) diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 4a42b71559c5..a95d8419d662 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -174,6 +174,7 @@ def failing(self) -> bool: "autoheuristic_collect": ["pad_mm", "mixed_mm"], "autoheuristic_use": ["pad_mm", "mixed_mm"], "traceable_tensor_subclasses": [OrderedSet()], + "nontraceable_tensor_subclasses": [OrderedSet()], } SamplingType = Callable[[str, type[Any], Any], Any] @@ -499,6 +500,7 @@ def keys(self) -> KeysView[ComboType]: }, "torch._dynamo.config": { "traceable_tensor_subclasses": DEFAULT, # Typing + "nontraceable_tensor_subclasses": DEFAULT, # Typing "compiled_autograd_kwargs_override": DEFAULT, # Typing "fail_on_recompile_limit_hit": DEFAULT, # fails in combo with suppress_errors "suppress_errors": DEFAULT, diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index f2ab19dd720f..327e15cce92c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -197,9 +197,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass( decompose_auto_functionalized ) - GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( - comms.reinplace_fsdp_all_gather - ) + if not torch._dynamo.config.skip_fsdp_hooks: + GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( + comms.reinplace_fsdp_all_gather + ) GraphTransformObserver(gm, "lower_scan_to_while_loop").apply_gm_pass( lower_scan_to_while_loop ) diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index e1dff0162cb5..8df1c1e1f2a6 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -163,9 +163,9 @@ def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): ) -def get_qconv2d_pt2e_pattern(users=1): +def get_qconv_pt2e_pattern(users=1): return CallFunction( - torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -345,13 +345,13 @@ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_val return actual_value == expected_value -def _is_valid_quantized_conv2d_optimization_pattern(): +def _is_valid_quantized_conv_optimization_pattern(): def fn(match): output_dtype = _get_pattern_output_dtype(match) if output_dtype in [torch.float32, torch.bfloat16]: # Only keep matched pattern with same output_dtype qconv_node_after_weight_prepack = filter_nodes( - match.nodes, torch.ops.onednn.qconv2d_pointwise + match.nodes, torch.ops.onednn.qconv_pointwise )[0] return _check_node_kwarg_arg_value( qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype @@ -365,7 +365,7 @@ def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): return ( _is_valid_qconv_binary_optimization_pattern() if has_binary_post_op - else _is_valid_quantized_conv2d_optimization_pattern() + else _is_valid_quantized_conv_optimization_pattern() ) @@ -374,8 +374,8 @@ def fn(match): if len(match.nodes) != 1: return False return match.nodes[0].target in ( - torch.ops.onednn.qconv2d_pointwise.default, - torch.ops.onednn.qconv2d_pointwise.tensor, + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv_pointwise.tensor, torch.ops.onednn.qconv2d_pointwise.binary, torch.ops.onednn.qconv2d_pointwise.binary_tensor, ) @@ -444,8 +444,8 @@ def qconv(match: Match, *args, **kwargs): postop_args, postop_algorithm, ) - counters["inductor"]["qconv2d_unary_lower_count"] += 1 - counters["inductor"]["qconv2d_unary_lower_nodes"] += len(match.nodes) + counters["inductor"]["qconv_unary_lower_count"] += 1 + counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qconv @@ -630,7 +630,7 @@ def qlinear_binary(match: Match, *args, **kwargs): def _is_valid_qconv_binary_optimization_pattern(): return _is_valid_quantized_op_binary_optimization_pattern( - torch.ops.onednn.qconv2d_pointwise + torch.ops.onednn.qconv_pointwise ) @@ -801,11 +801,11 @@ def qconv_binary(match: Match, *args, **kwargs): def _register_quantization_unary_lowering(): # QConv2d for users in [1, 2]: - qconv_pattern = get_qconv2d_pt2e_pattern(users) + qconv_pattern = get_qconv_pt2e_pattern(users) _register_quantized_conv_lowering( qconv_pattern, 2, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op ) # QLinear @@ -940,7 +940,7 @@ def _register_quantization_maxpool2d(): *max_pool2d_args, ) dequantize_lowmem_maxpool2d_pattern = CallFunction( - prims._low_memory_max_pool2d_with_offsets.default, + prims._low_memory_max_pool_with_offsets.default, get_dequantize_per_tensor_activation_pattern(), KeywordArg("kernel_size"), *max_pool2d_args, @@ -1375,7 +1375,7 @@ def _find_first_node_in_dequant_pattern(_node): counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) -def _is_valid_dequant_conv2d_pattern(dtype): +def _is_valid_dequant_conv_pattern(dtype): def _inner(match): # Here we do some further check to ensure: # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. @@ -1390,9 +1390,9 @@ def _inner(match): if ( meta_value is None or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") - or meta_value.dim() != 4 + or meta_value.dim() not in [3, 4] ): - # Only support conv2d now + # Only support conv1d/2d now return False assert dtype in [torch.float32, torch.bfloat16] @@ -1415,7 +1415,7 @@ def _inner(match): def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, - extra_check=_is_valid_dequant_conv2d_pattern(dtype), + extra_check=_is_valid_dequant_conv_pattern(dtype), pass_number=pass_number, ) def qconv_weight_prepack(match: Match, *args, **kwargs): @@ -1430,7 +1430,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): Insert weight prepack node and change the pattern to: int8 activation | - onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight """ assert dtype in [torch.float32, torch.bfloat16] conv_node = match.output_node() @@ -1532,7 +1532,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): "", # algorithm ) new_conv_node = graph.call_function( - torch.ops.onednn.qconv2d_pointwise.default, args=new_args + torch.ops.onednn.qconv_pointwise.default, args=new_args ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -1549,8 +1549,8 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] graph.erase_node(dequant_per_channel) # type: ignore[arg-type] - counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes ) @@ -2803,12 +2803,12 @@ def qconv(match: Match, *args, **kwargs): count_key = ( "qconv2d_binary_matcher_count" if has_binary_post_op - else "qconv2d_unary_matcher_count" + else "qconv_unary_matcher_count" ) nodes_key = ( "qconv2d_binary_matcher_nodes" if has_binary_post_op - else "qconv2d_unary_matcher_nodes" + else "qconv_unary_matcher_nodes" ) counters["inductor"][count_key] += 1 counters["inductor"][nodes_key] += len(match.nodes) @@ -2828,13 +2828,13 @@ def _register_qconv_unary_fusion(): PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv2d_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), ), PostOpAttr( @@ -2842,7 +2842,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), @@ -2853,7 +2853,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2864,7 +2864,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2877,21 +2877,21 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv2d_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), @@ -2903,7 +2903,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2915,7 +2915,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2929,7 +2929,7 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) @@ -2947,7 +2947,7 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2959,7 +2959,7 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2986,7 +2986,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -3024,7 +3024,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index a4d6f482e25d..ee258dfd4158 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -21,7 +21,7 @@ ) from torch._inductor.virtualized import V from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode -from torch.fx.immutable_collections import immutable_dict +from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.fx.passes.reinplace import _is_view_op from torch.utils import _pytree as pytree from torch.utils._ordered_set import OrderedSet @@ -720,6 +720,14 @@ def tensor_with_same_storage_already_reinplaced(arg): kwargs = dict(node.kwargs) kwargs["tensors_to_clone"] = tensors_to_clone node.kwargs = immutable_dict(kwargs) + if "eager_input_vals" in node.meta: + # We changed the kwargs, so we need to update eager_input_vals + # to something sane. + args, kwargs = node.meta["eager_input_vals"] + new_kwargs = {**kwargs} + new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) + new_kwargs = immutable_dict(new_kwargs) + node.meta["eager_input_vals"] = (args, new_kwargs) elif ( inplaceable_op := inplaceable_foreach_ops.get(node.target, None) ) is not None: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index a2e668d698b9..ca989c431aa7 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -366,9 +366,7 @@ def __init__( from torch._inductor.extern_node_serializer import extern_node_json_serializer self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = ( - extern_node_serializer - if config.is_fbcode() and extern_node_serializer - else extern_node_json_serializer + extern_node_json_serializer ) self.current_node: torch.fx.Node = None # type: ignore[assignment] @@ -430,7 +428,10 @@ def __init__( self.get_backend_features = functools.lru_cache(None)(get_backend_features) self.effectful_ops: dict[_EffectType, ir.Buffer] = {} - self.aligned_inputs: OrderedSet[str] = OrderedSet() + # Track the buffers that we know is unaligned + # This can either be a graph input or the output of fallback + # kernels. + self.unaligned_buffers: OrderedSet[str] = OrderedSet() self.no_fuse_buffer_names = OrderedSet[str]() self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet() @@ -1116,8 +1117,8 @@ def placeholder( # expensive and cause recompiles; Instead, we're generating code # based on the alignment of the example input without guarding. with maybe_get_suppress_shape_guards_ctx(): - if should_assume_input_aligned(example): - self.aligned_inputs.add(target) + if not should_assume_input_aligned(example): + self.unaligned_buffers.add(target) return tensor def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override] @@ -1151,7 +1152,9 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> # use contiguous unless the (custom) op asks something else # explicitly - if torch._C.Tag.needs_fixed_stride_order in target.tags: + if torch._C.Tag.needs_exact_strides in target.tags: + decided_constraint = constrain_to_fake_tensors # type: ignore[assignment] + elif torch._C.Tag.needs_fixed_stride_order in target.tags: decided_constraint = constrain_to_fx_strides # type: ignore[assignment] elif torch._C.Tag.flexible_layout in target.tags: decided_constraint = None # type: ignore[assignment] @@ -1192,7 +1195,34 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> layout_constraints = maybe_layout_constraints(target) if layout_constraints: old_args, old_kwargs = args, kwargs - args, kwargs = layout_constraints(n, *args, **kwargs) + if layout_constraints is constrain_to_fake_tensors: + # only constrain_to_fake_tensor if this exists. + # otherwise, no constraints at all: the implication is + # that this operator was inserted by a custom pass + # so we'll give them the freedom. + if "eager_input_vals" in n.meta: + fake_args, fake_kwargs = n.meta["eager_input_vals"] + + # (fake_args, fake_kwargs) might not align with (args, kwargs). + # we need to normalize them based on the schema + assert isinstance(target, torch._ops.OpOverload) + + def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]: + result = torch.fx.operator_schemas.normalize_function( + target, args, kwargs + ) + assert result is not None + return result[0], result[1] + + fake_args, fake_kwargs = normalize(fake_args, fake_kwargs) + args, kwargs = normalize(args, kwargs) + old_args, old_kwargs = normalize(old_args, old_kwargs) + + args, kwargs = constrain_to_fake_tensors( + args, kwargs, fake_args, fake_kwargs + ) + else: + args, kwargs = layout_constraints(n, *args, **kwargs) out = lowerings[target](*args, **kwargs) # type: ignore[index] @@ -1506,9 +1536,9 @@ def debug(msg: str) -> None: old_args = args # type: ignore[possibly-undefined] old_kwargs = kwargs # type: ignore[possibly-undefined] - if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"): - inp_args = arg_kwarg_vals[0] - inp_kwargs = arg_kwarg_vals[1] + if eager_input_vals := n.meta.get("eager_input_vals"): + inp_args = eager_input_vals[0] + inp_kwargs = eager_input_vals[1] args, kwargs = constrain_to_fake_tensors( args, kwargs, inp_args, inp_kwargs ) @@ -1662,7 +1692,7 @@ def debug(msg: str) -> None: torch.ops.mkldnn._convolution_pointwise.binary, torch.ops.mkldnn._convolution_pointwise_.binary, torch.ops.mkldnn._convolution_transpose_pointwise.default, - torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, torch.ops.onednn.qconv2d_pointwise.binary, ] if torch._C.has_mkl: diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 170c2f00d44a..d764744d857a 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs from __future__ import annotations +import functools import logging +import operator from typing import Optional, TYPE_CHECKING import torch @@ -119,7 +121,28 @@ def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: ) -def _low_memory_max_pool2d_with_offsets_aten( +def _flattened_index_to_nd(indices, width): + dim = len(width) + + if dim == 1: + return [indices] + elif dim >= 2: + m = functools.reduce(operator.mul, width[1:]) + ih = indices // m + indices_new = indices - (ih * m) + return [ih, *_flattened_index_to_nd(indices_new, width[1:])] + else: + raise ValueError(f"Unknown dim: {dim}") + + +def _flatten_index(indices, width): + result = indices[0] + for d in range(1, len(indices)): + result = width[d] * result + indices[d] + return result + + +def _low_memory_max_pool_with_offsets_aten( self, kernel_size, stride, @@ -127,80 +150,69 @@ def _low_memory_max_pool2d_with_offsets_aten( dilation, ceil_mode, ): - vals, indices = torch.ops.aten.max_pool2d_with_indices( - self, kernel_size, stride, padding, dilation, ceil_mode - ) - - input_width = self.shape[-1] - kernel_width = kernel_size[1] - - bh_shape = [1] * self.ndim - bh_shape[-2] = -1 - bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view( - bh_shape - ) - - bw_shape = [1] * self.ndim - bw_shape[-1] = -1 - bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view( - bw_shape - ) + dim = len(kernel_size) + if dim == 2: + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + else: + vals, indices = torch.ops.aten.max_pool3d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) - hbase = bh * stride[0] - padding[0] - wbase = bw * stride[1] - padding[1] + idhw = _flattened_index_to_nd(indices, self.shape[-dim:]) - ih = indices // input_width - iw = indices - (ih * input_width) + dhw_inc = [] - h_inc = (ih - hbase) // dilation[0] - w_inc = (iw - wbase) // dilation[1] + for d in range(dim): + bh_shape = [1] * self.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + indices.shape[-dim + d], dtype=torch.int64, device=self.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + h_inc = (idhw[d] - hbase) // dilation[d] + dhw_inc.append(h_inc) - offsets = h_inc * kernel_width + w_inc + offsets = _flatten_index(dhw_inc, kernel_size) return vals, offsets.to(torch.int8) -def _low_memory_max_pool2d_offsets_to_indices_aten( +def _low_memory_max_pool_offsets_to_indices_aten( offsets, - kernel_width, - input_width, + kernel_size, + input_size, stride, padding, dilation, ): + dim = len(kernel_size) offsets = offsets.to(torch.int64) - h_inc = offsets // kernel_width - w_inc = offsets - (h_inc * kernel_width) - - bh_shape = [1] * offsets.ndim - bh_shape[-2] = -1 - bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view( - bh_shape - ) - - bw_shape = [1] * offsets.ndim - bw_shape[-1] = -1 - bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view( - bw_shape - ) + dhw_inc = _flattened_index_to_nd(offsets, kernel_size) - hbase = bh * stride[0] - padding[0] - wbase = bw * stride[1] - padding[1] + idhw = [] + for d in range(dim): + bh_shape = [1] * offsets.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + offsets.shape[-dim + d], dtype=torch.int64, device=offsets.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + idhw.append(hbase + dhw_inc[d] * dilation[d]) - ih = hbase + h_inc * dilation[0] - iw = wbase + w_inc * dilation[1] - return ih * input_width + iw + return _flatten_index(idhw, input_size) -_low_memory_max_pool2d_with_offsets = make_prim( - "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 - _low_memory_max_pool2d_with_offsets_aten, +_low_memory_max_pool_with_offsets = make_prim( + "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool_with_offsets_aten, return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), doc="Instead of returning indices, returns indices offsets.", ) -_low_memory_max_pool2d_offsets_to_indices = make_prim( - "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", # noqa: B950 - _low_memory_max_pool2d_offsets_to_indices_aten, +_low_memory_max_pool_offsets_to_indices = make_prim( + "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", # noqa: B950 + _low_memory_max_pool_offsets_to_indices_aten, doc="Convert small int offsets to regular indices.", ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 84069bbdf829..a312ea3ca11c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -95,6 +95,7 @@ sympy_index_symbol_with_prefix, sympy_product, sympy_subs, + tensor_is_aligned, ) from .virtualized import ops, OpsValue, V @@ -6996,11 +6997,16 @@ def generate_output(output, indices): # type: ignore[no-untyped-def] for key, val in output.items() } elif isinstance(output, torch.Tensor): - return MultiOutput( + buf = MultiOutput( cls.tensor_to_layout(output), packed, indices, ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + output + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return buf elif isinstance(output, int): return output elif isinstance(output, torch.SymInt): @@ -8051,6 +8057,11 @@ def create_out_of_place( # type: ignore[no-untyped-def] ) for i, tensor in enumerate(example_output) ] + for buf, tensor in zip(packed.outputs, example_output): + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + tensor + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] return packed.outputs else: packed = cls( @@ -8060,6 +8071,10 @@ def create_out_of_place( # type: ignore[no-untyped-def] non_tensor_args, unflatten_args, ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + example_output + ): + V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type] packed.outputs = [packed] return packed diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index c3886111cb02..cd074e2c36d4 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -74,6 +74,8 @@ def _is_large_block_for_cpu(m, n, k): group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index ffa1531efd42..604f4523793a 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Optional +from typing import Any, Optional import torch from torch._dynamo.utils import counters @@ -21,10 +21,16 @@ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.wrapper import PythonWrapperCodegen from ..ir import FlexibleLayout, is_triton -from ..lowering import register_lowering +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, + realize_inputs, TritonTemplate, ) from ..utils import ( @@ -46,6 +52,8 @@ mm_options, persistent_mm_grid, persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, should_fallback_to_aten, ) @@ -90,6 +98,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -119,7 +129,12 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -159,6 +174,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -188,7 +205,11 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -302,9 +323,183 @@ # inductor generates a suffix {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + """, ) +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + # prevent duplication registration of extern functions @functools.lru_cache(None) @@ -326,6 +521,10 @@ def lazy_register_extern_choice(fn): has_out_variant=False, ) +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) @@ -336,6 +535,16 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**13 +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt @@ -347,6 +556,32 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + aten_bias_addmm = ExternKernelChoice(bias_addmm, None) @@ -746,6 +981,151 @@ def tuned_sparse_semi_structured_mm( ) +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] + + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if k == 16 and config.kwargs["BLOCK_M"] >= 64: + continue # Triton crashes in this case + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and k < 32: + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + if should_fallback_to_aten(choices): + return aten_choice.output_node() + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + @functools.lru_cache(None) def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index d990536c4362..079d6e83d623 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -72,16 +72,23 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): not inductor_config.force_same_precision or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) ) - return dict( - GROUP_M=8, + options_dict = dict( EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=acc_type(layout.dtype), num_stages=config.num_stages, num_warps=config.num_warps, **config.kwargs, ) + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + def persistent_mm_options(mat1, mat2): return dict( @@ -92,6 +99,47 @@ def persistent_mm_options(mat1, mat2): ) +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + return mm_template_options + + def mm_args( mat1, mat2, @@ -154,6 +202,34 @@ def epilogue(acc, bias): return epilogue +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + def _is_static_problem(layout: Layout) -> tuple[bool, bool]: """ Check if input tensors and output layout have static shapes and non-zero sizes. diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index ac6bbee6c75a..2e190595c0d1 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -53,6 +53,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py deleted file mode 100644 index aa917e120168..000000000000 --- a/torch/_inductor/kernel/mm_scaled.py +++ /dev/null @@ -1,608 +0,0 @@ -import functools -import logging -from collections.abc import Sequence -from typing import Any, Optional - -import sympy - -import torch -from torch._dynamo.utils import counters -from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate -from torch.utils._triton import has_triton_tma_device - -from ..config import triton as triton_config -from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox -from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering -from ..select_algorithm import ( - autotune_select_algorithm, - ExternKernelChoice, - realize_inputs, - TritonTemplate, -) -from ..utils import ( - get_num_sms, - get_tma_workspace_arg, - TMA_DESCRIPTOR_SIZE, - use_aten_gemm_kernels, - use_ck_gemm_template, - use_triton_template, -) -from ..virtualized import V -from .mm_common import ( - _is_static_problem, - mm_args, - mm_grid, - persistent_mm_grid, - should_fallback_to_aten, -) - - -log = logging.getLogger(__name__) -aten = torch.ops.aten - -load_scales = r""" -@triton.jit -def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): - if SCALING_ROWWISE: - # For row-wise scaling, we'll return the pointers - return a_scale_ptr, b_scale_ptr - else: - # For per-tensor scaling, we'll load the scalar values - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr) - return a_scale, b_scale -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALING_ROWWISE: - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -device_tma = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALING_ROWWISE: - stride_a_scale_m = 1 - stride_b_scale_n = 1 - else: - stride_a_scale_m = 0 - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - - -scaled_mm_device_tma_template = TritonTemplate( - name="scaled_mm_device_tma", - grid=persistent_mm_grid, - source=device_tma + load_scales + apply_scaling, -) - - -scaled_mm_template = TritonTemplate( - name="scaled_mm", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) - if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -# Inductor does not allow optional tensor input arguments currently (pass None as an -# input node to template choices), but since for _scaled_mm there is only one such arg -# (bias), work around by having a second template when bias is provided. -scaled_mm_bias_template = TritonTemplate( - name="scaled_mm_bias", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) - if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # bias - bias = tl.load(bias_ptr + rn, mask=rn < N) - acc += bias - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -aten__fp8_mm = ExternKernelChoice( - torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out -) - - -def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool: - # Same sized scales are compatable - if len(size_a) == len(size_b): - return True - - # Both need to be scalars or len(1) tensors - if len(size_a) <= 1 and len(size_b) <= 1: - return True - - return False - - -def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: - def is_row_major(stride: Sequence[_IntLike]) -> bool: - return stride[1] == 1 - - def is_col_major(stride: Sequence[_IntLike]) -> bool: - return stride[0] == 1 - - def has_zero_dim(size: Sequence[_IntLike]) -> bool: - return bool(size[0] == 0 or size[1] == 0) - - # Check mat_a (self) stride requirements - torch._check( - is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), - lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", - ) - - # Check mat_b stride requirements - torch._check( - is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), - lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", - ) - - -def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - TMA_SIZE=TMA_DESCRIPTOR_SIZE, - NUM_SMS=get_num_sms(), - **config.kwargs, - ) - - -def scaled_mm_options( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - **config.kwargs, - ) - - -add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) - - -def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool: - available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul - # _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes - # When K is 16, BLOCK_K = 16 and is not valid - min_k = k >= 32 - return available and min_k and not has_bias - - -@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] -def tuned_scaled_mm( - mat_a: TensorBox, - mat_b: TensorBox, - scale_a: TensorBox, - scale_b: TensorBox, - bias: Optional[TensorBox] = None, - scale_result: Optional[TensorBox] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - layout: Optional[Layout] = None, -) -> TensorBox: - m, n, k, layout, mat_a, mat_b = mm_args( - mat_a, mat_b, layout=layout, out_dtype=out_dtype - ) - - # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 - log.info( - "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", - m, - n, - k, - mat_a.get_dtype(), - mat_b.get_dtype(), - layout, - ) - - device_type = get_device_type(mat_a) - - check_supported_striding(mat_a, mat_b) - - scale_a, scale_b = realize_inputs(scale_a, scale_b) - - input_nodes: tuple[Any, ...] - # workaround for Inductor not supporting optional tensor input arguments - if bias is None: - input_nodes = (mat_a, mat_b, scale_a, scale_b) - triton_template = scaled_mm_template - else: - bias = realize_inputs(bias) - input_nodes = (mat_a, mat_b, scale_a, scale_b, bias) - triton_template = scaled_mm_bias_template - - aten_choice = aten__fp8_mm.bind( - input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum - ) - - choices: list[ChoiceCaller] = [] - if use_aten_gemm_kernels(): - choices.append(aten_choice) - - _, is_nonzero = _is_static_problem(layout) - - scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) - scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( - device_type - ) - - if is_nonzero and use_triton_template(layout, enable_float8=True): - if use_persistent_tma(k, bias is not None): - for config in scaled_persistent_mm_configs(m, n, k): - kwargs = scaled_mm_options_device_tma( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - input_nodes = (mat_a, mat_b, scale_a, scale_b) - scaled_mm_device_tma_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - workspace_arg=get_tma_workspace_arg( - num_tma_descriptors=2, - device=mat_a.get_device(), - ), - **kwargs, - ) - else: - for config in scaled_mm_configs(m, n, k): - if k == 16 and config.kwargs["BLOCK_M"] >= 64: - continue # Triton crashes in this case - - # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid - # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape - if using_b200() and k < 32: - continue - - kwargs = scaled_mm_options( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - # possibly appends a TritonTemplateCaller to choices - triton_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - ) - - if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) - - if should_fallback_to_aten(choices): - return aten_choice.output_node() - - return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) - - -@functools.lru_cache -def using_b200() -> bool: - """Returns true if the device is a NVIDIA B200, otherwise returns false.""" - if not torch.cuda.is_available(): - return False - # compute capability 10.0 or 10.0a is NVIDIA B200 - device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) - return device_properties.major == 10 diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 64b505d5cdac..24520887f6aa 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -40,6 +40,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -222,6 +223,7 @@ def add_layout_constraint(fn, constraint): aten.convolution, aten.convolution_backward, aten.max_pool2d_with_indices, + aten.max_pool3d_with_indices, aten.max_pool2d_with_indices_backward, aten.mm, aten.upsample_nearest2d, @@ -1087,8 +1089,6 @@ def trunc(x): @register_lowering(aten.expand, type_promotion_kind=None) def expand(x, sizes): - from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols - (x,) = promote_constants([x]) if isinstance(x, ir.BaseConstant): return ExpandView.create(x, tuple(sizes)) @@ -1165,8 +1165,9 @@ def inner_fn(index): return x_loader(index) old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) - if old_size_product > 0: - # maybe realize the input + if old_size_product > 0 and not free_unbacked_symbols(new_size): + # maybe realize the input but skip for unbacked symints since it'll + # choke on the size hint. x.mark_reuse( V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product ) @@ -2616,7 +2617,6 @@ def is_aligned(x): make_fallback(aten._adaptive_avg_pool3d) # @isuruf make_fallback(aten.adaptive_max_pool3d) # @isuruf make_fallback(aten.fractional_max_pool3d) # @isuruf -make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) # 1) Easy @@ -3376,6 +3376,11 @@ def fn(idx): @register_lowering(aten.embedding, type_promotion_kind=None) def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + if sparse: + return fallback_handler(aten.embedding.default)( + weight, indices, padding_idx, scale_grad_by_freq, sparse + ) + assert not sparse assert isinstance(weight, TensorBox) assert isinstance(indices, TensorBox) @@ -4348,57 +4353,62 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None return x_out, ceil_mode -def should_fallback_max_pool2d_with_indices(kernel_size): - kernel_size = pad_listlike(kernel_size, 2) - window_size = kernel_size[0] * kernel_size[1] +def should_fallback_max_pool_with_indices(kernel_size, *, n_dim): + kernel_size = pad_listlike(kernel_size, n_dim) + window_size = functools.reduce(operator.mul, kernel_size) return window_size > 25 -def max_pool2d_checks( - x, kernel_size, stride, padding, dilation, *, assert_fallback=None +def max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None ): if padding == 0: - padding = [0, 0] + padding = [0] * n_dim if dilation == 1: - dilation = [1, 1] + dilation = [1] * n_dim if not stride: stride = kernel_size - kernel_size = pad_listlike(kernel_size, 2) - stride = pad_listlike(stride, 2) - padding = pad_listlike(padding, 2) - dilation = pad_listlike(dilation, 2) + kernel_size = pad_listlike(kernel_size, n_dim) + stride = pad_listlike(stride, n_dim) + padding = pad_listlike(padding, n_dim) + dilation = pad_listlike(dilation, n_dim) assert isinstance(x, TensorBox) - assert len(kernel_size) == 2 - assert len(stride) == 2 - assert len(padding) == 2 - assert len(dilation) == 2 - assert len(x.get_size()) in (3, 4) + assert len(kernel_size) == n_dim + assert len(stride) == n_dim + assert len(padding) == n_dim + assert len(dilation) == n_dim + assert len(x.get_size()) in (n_dim + 1, n_dim + 2) - use_fallback = should_fallback_max_pool2d_with_indices(kernel_size) + use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim) if assert_fallback is not None: assert use_fallback == assert_fallback return kernel_size, stride, padding, dilation, use_fallback -def _max_pool2d_with_offsets( +def _max_pool_with_offsets( x, kernel_size, stride, padding, dilation, - ceil_mode=False, + ceil_mode, + *, + n_dim, ): x.realize_hint() - *batch, h, w = x.get_size() + batch = x.shape[:-n_dim] + dhw = x.shape[-n_dim:] - h_out, ceil_mode1 = pooling_size( - h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation - ) - w_out, ceil_mode2 = pooling_size( - w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation + dhw_out, ceil_mode = zip( + *[ + pooling_size( + dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation + ) + for d in range(n_dim) + ] ) dtype = x.dtype @@ -4408,27 +4418,18 @@ def _max_pool2d_with_offsets( else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) ) - new_size = list(batch) + [h_out, w_out] - if ( - padding[0] - or padding[1] - or ceil_mode1 - or ceil_mode2 - or (dilation[0] > 1) - or (dilation[1] > 1) - ): - x_loader = constant_boundary_condition(x, min_value, dim=2) + new_size = list(batch) + list(dhw_out) + if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation): + x_loader = constant_boundary_condition(x, min_value, dim=n_dim) else: x_loader = x.make_loader() - dim = 2 - def fn_inner(idx, reduction_idx): - prefix = idx[:-dim] - bh = idx[-dim:] + prefix = idx[:-n_dim] + bh = idx[-n_dim:] ih = [ (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] - for i in range(dim) + for i in range(n_dim) ] return x_loader([*prefix, *ih]) @@ -4462,8 +4463,8 @@ def fn_inner(idx, reduction_idx): return result, offsets -@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None) -def _low_memory_max_pool2d_with_offsets( +@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool_with_offsets( x, kernel_size, stride, @@ -4471,53 +4472,60 @@ def _low_memory_max_pool2d_with_offsets( dilation, ceil_mode=False, ): + n_dim = len(kernel_size) + # assert we are not on a fallback path, the inductor decomp should have guaranteed this - kernel_size, stride, padding, dilation, _ = max_pool2d_checks( + kernel_size, stride, padding, dilation, _ = max_pool_checks( x, kernel_size, stride, padding, dilation, + n_dim, assert_fallback=False, ) with config.patch(unroll_reductions_threshold=25): - result, offsets = _max_pool2d_with_offsets( + result, offsets = _max_pool_with_offsets( x, kernel_size, stride, padding, dilation, ceil_mode, + n_dim=n_dim, ) return result, to_dtype(offsets, torch.int8) @register_lowering( - prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None + prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None ) -def _low_memory_max_pool2d_offsets_to_indices( - offsets, kernel_width, input_width, stride, padding, dilation +def _low_memory_max_pool_offsets_to_indices( + offsets, kernel_size, input_size, stride, padding, dilation ): - # TODO: Generalize to other max pooling flavors, and arbitrary dim - + # TODO: Generalize to other max pooling flavors + n_dim = len(kernel_size) offsets_loader = offsets.make_loader() - def increments_to_index(h_inc, w_inc, bh, bw): - w_in = ops.index_expr(input_width, torch.int64) - hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) - wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) - ih = hbase + h_inc * ops.constant(dilation[0], torch.int64) - iw = wbase + w_inc * ops.constant(dilation[1], torch.int64) - return ih * w_in + iw + def increments_to_index(dhw_inc, bh): + w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)] + hbase = [ + ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64) + for d in range(n_dim) + ] + idhw = [ + hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64) + for d in range(n_dim) + ] + return inductor_prims._flatten_index(idhw, w_in) def offsets_to_indices(idx): - *prefix, bh, bw = idx - offset = offsets_loader([*prefix, bh, bw]) - kw_const = ops.constant(kernel_width, torch.int32) - h_inc = offset // kw_const - w_inc = offset - (h_inc * kw_const) - return increments_to_index(h_inc, w_inc, bh, bw) + bh = idx[-n_dim:] + offset = offsets_loader(idx) + k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)] + dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const) + return increments_to_index(dhw_inc, bh) indices = Pointwise.create( device=offsets.get_device(), @@ -4528,6 +4536,35 @@ def offsets_to_indices(idx): return indices +def _max_pool_with_indices( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim, +): + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim=n_dim + ) + + out, offsets = _max_pool_with_offsets( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim + ) + + indices = _low_memory_max_pool_offsets_to_indices( + offsets, + kernel_size, + x.shape[-n_dim:], + stride, + padding, + dilation, + ) + + return out, indices + + # Fallback when we do not decompose to the low-memory path. @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) def max_pool2d_with_indices( @@ -4538,20 +4575,25 @@ def max_pool2d_with_indices( dilation=1, ceil_mode=False, ): - kernel_size, stride, padding, dilation, _ = max_pool2d_checks( - x, kernel_size, stride, padding, dilation + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2 ) - out, offsets = _max_pool2d_with_offsets( - x, kernel_size, stride, padding, dilation, ceil_mode - ) - indices = _low_memory_max_pool2d_offsets_to_indices( - offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None) +def max_pool3d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3 ) - return out, indices - fallback_max_pool2d_with_indices_backward = fallback_handler( aten.max_pool2d_with_indices_backward.default, diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 422b256ca96f..74999462abc8 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -73,6 +73,22 @@ def _conv_input_size( input_size.append(input_size_d) return list(map(int, input_size)) + # Port from aten/src/ATen/native/ConvUtils.h: _conv_output_size + def _conv_output_size(input_size, weight_size, padding, stride, dilation=None): + has_dilation = dilation is not None + dim = len(input_size) + output_size = [] + output_size.append(input_size[0]) + output_size.append(weight_size[0]) + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[ + d - 2 + ] + 1 + output_size.append(output_size_d) + return output_size + # The size of prepacked_weight is the prepacked weight size of deconv: # Groups > 1: [g*o, i/g, ...] # Groups == 1: [o, i, ...] @@ -130,21 +146,18 @@ def _original_deconv_weight_size( groups, ) else: - bias_fake = ( - ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias - ) - output = torch.ops.aten.convolution( - x_fake, - weight_fake, - bias_fake, - stride, + x_shape = list(x_fake.shape) + weight_shape = list(weight_fake.shape) + if len(x_shape) != len(weight_shape): + assert len(x_shape) == 3 and len(weight_shape) == 4 + weight_shape.pop(2) + output_size = _conv_output_size( + x_shape, + weight_shape, padding, + stride, dilation, - transposed, - output_padding, - groups, ) - output_size = output.size() req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) req_stride_order = [len(req_stride_order)] + req_stride_order @@ -562,8 +575,8 @@ def __init__( inputs, constant_args, None, - op_overload=torch.ops.onednn.qconv2d_pointwise.default, - cpp_kernel_name="aoti_torch_cpu__qconv2d_pointwise_tensor", + op_overload=torch.ops.onednn.qconv_pointwise.default, + cpp_kernel_name="aoti_torch_cpu__qconv_pointwise_tensor", ) def codegen(self, wrapper): diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index d665aa3b892d..7ac5ee02ac43 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -130,7 +130,7 @@ def register_onednn_fusion_ops(): torch.ops.mkldnn._convolution_transpose_pointwise, torch.ops.mkldnn._linear_pointwise, aten.mkldnn_rnn_layer.default, - torch.ops.onednn.qconv2d_pointwise, + torch.ops.onednn.qconv_pointwise, ] @register_lowering(torch.ops.mkldnn._convolution_pointwise) @@ -428,7 +428,7 @@ def mkldnn_rnn_layer( ), ) - @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None) def qconvolution_unary( x: TensorBox, x_scale, diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 1891c7d15dca..792a6b4385a2 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -251,14 +251,75 @@ def replace_by_example( else contextlib.nullcontext() ) + def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool: + if len(nodes) != 1: + return False + node = nodes[0] + if "eager_input_vals" not in node.meta: + return False + return node.target in OrderedSet( + [ + torch.ops.higher_order.triton_kernel_wrapper_functional, + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, + ] + ) + with context: if trace_fn is None: trace_fn = functools.partial( fwd_only, run_functional_passes=run_functional_passes ) - replacement = trace_fn( - replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) - ) + + if should_propagate_eager_input_vals(self.nodes): + # Our strategy is: + # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata) + # 2) trace out the graph with vals (which have the accurate Inductor metadata) + # 3) Propagate the eager_input_vals from the first graph to the second. + # 4) Use the second graph as the replacement graph. + + # Construct a map of node -> FakeTensor val in eager_input_vals + node_to_val = {} + + fake_args, fake_kwargs = self.nodes[0].meta["eager_input_vals"] + fake_kwargs = {**fake_kwargs} + match_args, match_kwargs = tuple(self.args), self.kwargs + + def record(node: torch.fx.Node, val: Any) -> None: + if isinstance(node, torch.fx.Node): + node_to_val[node] = val + + torch.utils._pytree.tree_map( + record, (match_args, match_kwargs), (fake_args, fake_kwargs) + ) + # map args to their FakeTensor val in eager_input_vals + example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg]) + + # first graph + graph_with_eager_vals = trace_fn(replacement_fn, example_vals) + + # second graph + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(graph_with_eager_vals, example_vals) + + # propagate metadata from first graph to second + # NB: This assertion might not be true in general, but it is true for + # the two use cases we have + # (triton_kernel_wrapper_functional, auto_functionalized) + assert len(graph_with_eager_vals.graph.nodes) == len( + replacement.graph.nodes + ) + for old_node, new_node in zip( + graph_with_eager_vals.graph.nodes, replacement.graph.nodes + ): + if "eager_input_vals" in old_node.meta: + new_node.meta["eager_input_vals"] = old_node.meta[ + "eager_input_vals" + ] + + else: + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(replacement_fn, example_vals) if len(self.nodes) == 1: for n in replacement.graph.nodes: _transfer_meta( @@ -1083,6 +1144,11 @@ def run_node(self, node: torch.fx.Node) -> Any: old_node=node, pass_name="Interpreter_Replacer", ) + # This function copy-pastes the replacement graph into + # the graph. If the replacement graph had any eager_input_vals, + # or val/tensor_meta, we propagate those over. + if "eager_input_vals" in node.meta: + result.meta["eager_input_vals"] = node.meta["eager_input_vals"] if "val" in node.meta and "val" not in result.meta: result.meta["val"] = node.meta["val"] if isinstance(node.meta["val"], torch.Tensor): diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index d19a96a85604..4988f3780812 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -121,7 +121,21 @@ def _setup_local_cache( if not inductor_meta.get("autotune_local_cache", True): return - cache_filename = f"{dirname}/{cache_key}.best_config" + from ..codecache import torch_key + + """ + [Note: torch_key in autotune cache key] + Include torch_key() in the cache key so that different versions + of torch result in cache invalidation. This is important in case + of changes to the best_config format or other code changes that + are not backward compatible w.r.t. the cache. + """ + hasher = hashlib.sha256() + hasher.update(cache_key.encode("utf-8")) + hasher.update(torch_key()) + updated_cache_key = hasher.hexdigest() + + cache_filename = f"{dirname}/{updated_cache_key}.best_config" local_cache = LocalAutotuneCache() self.local_cache = (local_cache, cache_filename) @@ -139,10 +153,13 @@ def _setup_remote_autotune_cache( return assert isinstance(backend_hash, str) + from ..codecache import torch_key + is_fbcode = bool(inductor_meta.get("is_fbcode", False)) salt = "autotune-best-config-v2" - key = backend_hash + self.configs_hash + salt + # re: torch_key - see [Note: torch_key in autotune cache key] + key = torch_key().hex() + backend_hash + self.configs_hash + salt key = hashlib.sha256(key.encode("utf-8")).hexdigest() remote_cache = create_cache( diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 3bc8df35a838..f224217db22b 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -181,14 +181,14 @@ class HalideInputSpec(typing.NamedTuple): alias_of: Optional[str] = None def bindings_type(self) -> str: - if self.ctype in ("half*", "bfloat16*"): + if self.ctype in ("at::Half*", "at::BFloat16*"): return "uint16_t*" # half not defined return self.ctype def halide_type(self) -> str: - if self.ctype == "half*": + if self.ctype == "at::Half*": return "halide_type_t(halide_type_float, 16)" # half not defined - if self.ctype == "bfloat16*": + if self.ctype == "at::BFloat16*": return "halide_type_t(halide_type_bfloat, 16)" # half not defined return f"halide_type_of<{self.ctype.replace('*', '')}>()" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 4e1a67139f38..daf1afa8ec28 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -31,6 +31,7 @@ import torch from torch._prims_common import compute_required_storage_length +from torch.monitor import _WaitCounter from torch.utils._ordered_set import OrderedSet from ..triton_bundler import TritonBundler @@ -815,13 +816,18 @@ def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: return self.maybe_clone_args(OrderedSet(), *args, **kwargs) def benchmark_all_configs(self, *args, **kwargs): - with dynamo_timed( - "CachingAutotuner.benchmark_all_configs", - log_pt2_compile_event=True, - metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, - dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us", - compile_id=self.compile_id, - is_backward=self.is_backward, + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + _WaitCounter("pytorch.wait_counter.dynamo_compile").guard(), ): timings = { launcher: self.bench(launcher, *args, **kwargs) @@ -972,7 +978,15 @@ def benchmark_one_config(config): self.autotune_time_taken_ns + coordesc_time_taken_ns, found_by_coordesc=True, ) - return config2launcher.get(best_config) + + if best_config not in config2launcher: + # On a Coordesc cache hit, we might not have loaded the launcher + # This can happen because PyCodeCache saves CachingAutotuners in memory, + # even for separate compile IDs (which can have different inputs without changing output code) + config2launcher[best_config] = self._precompile_config( + best_config + ).make_launcher() + return config2launcher[best_config] def run( self, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b756cdc4aa98..34df7f21b595 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -464,6 +464,38 @@ def decide_inplace_update(self) -> None: | self.scheduler.completed_operations ) + def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: + # Inside of NodeUser, we track that the read and write are equivalent + # before deciding if the use can be inplace. + # But if that use is fused into a larger kernel, we need to check equivalence + # of other accesses in fused scheduler node as well. + fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self) + buf_name = buf_to_be_inplaced.get_name() + # Dedup read/writes with equivalent indices + # TODO - would be nice if we could just cache accesses on ReadWrites, + # and inforce variant that this class & members are functional.. + deps: OrderedSet[Dep] = OrderedSet() + for user in buf_to_be_inplaced.users: + user_node = user.node + if not isinstance(user_node, BaseSchedulerNode): + continue + + if ( + buf_to_be_inplaced.scheduler.get_fused_node(user_node) + is not fused_node + ): + continue + + deps |= ( + o + for o in user_node.read_writes.reads_and_writes() + if o.name == buf_name + ) + if len(deps) > 1: + return False + + return True + for buf in self.get_outputs(): buf_node = buf.node assert buf_node is not None @@ -515,6 +547,7 @@ def decide_inplace_update(self) -> None: and len(input_buf.node.get_inputs_that_alias_output()) > 0 ) and can_match_buffer_size(input_buf.node, buf.node) + and single_index_in_fused_node(input_buf) ): # if there isn't a triton kernel, then we don't need to call triton-specific things. # but TODO this might be a convenient place to signal to the Collective kernels to inplace diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 558e59af81c7..35dae177bac7 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -746,11 +746,10 @@ def load_input( indices, self.range_trees[0].construct_entries(lengths) ): range_tree_entry.set_name(name) - contiguous_index = sympy_dot( - ir.FlexibleLayout.contiguous_strides(lengths), index_symbols - ) - contiguous_index = self.rename_indexing(contiguous_index) - self.body.writeline("xindex = " + texpr(contiguous_index)) + + strided_index = sympy_dot(input_node.get_stride(), index_symbols) + strided_index = self.rename_indexing(strided_index) + self.body.writeline("xindex = " + texpr(strided_index)) xindex_range_root = self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) @@ -823,7 +822,7 @@ def store( output_index = self.rename_indexing(output_index) - if output_index == contiguous_index: + if output_index == strided_index: output_index_str = "xindex" else: out_indexing = self.indexing( @@ -1283,6 +1282,7 @@ def make_kernel_render(out_node): ), "num_stages": num_stages, "num_warps": num_warps, + "GROUP_M": kwargs.get("GROUP_M", -1), "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), "acc_type": str(kwargs.get("ACC_TYPE", None)), }, @@ -2205,6 +2205,9 @@ def log_results( for n in input_nodes ] ) + + strides = ", ".join([str(n.get_stride()) for n in input_nodes]) + dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes]) if config.autotune_num_choices_displayed == 0: return # when autotune_num_choices_displayed is None, [:None] means all @@ -2252,6 +2255,9 @@ def get_choice_info(choice): best_time = timings[best] sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + sys.stderr.write(f"strides: {strides}\n") + sys.stderr.write(f"dtypes: {dtypes}\n") + for choice in top_k: result = timings[choice] if result: diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 400d1ad2b6de..fe6476f317f3 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -1,7 +1,7 @@ from __future__ import annotations +import dataclasses import itertools -from collections import namedtuple from functools import partial from threading import Lock from typing import Any, Callable, TYPE_CHECKING @@ -14,12 +14,59 @@ if TYPE_CHECKING: - from collections.abc import Generator, Sequence + from collections.abc import Generator from triton import Config as TritonConfig -class BaseConfigSingleton(type): +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = 8 + + +ConvConfig = BaseConfig + + +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): """ Thread-safe implementation of single to be used in the config heuristic subclasses to ensure heavy __init__ calls are not repeatedly run @@ -29,7 +76,7 @@ class BaseConfigSingleton(type): _lock: Lock = Lock() def __call__( - cls: BaseConfigSingleton, *args: Any, **kwargs: Any + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any ) -> BaseConfigHeuristic: with cls._lock: if cls not in cls._instances: @@ -38,12 +85,7 @@ def __call__( return cls._instances[cls] -Config = namedtuple( - "Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"] -) - - -class BaseConfigHeuristic(metaclass=BaseConfigSingleton): +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): """ Base class for mm_configs, device specific triton kernels config inherit from here """ @@ -52,36 +94,37 @@ def __init__(self) -> None: # List of dictionaries to store the kernel configs. Configs that evaluate to true # will be utilised on the target platform. The configs are as follows: # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) - self.mm_configs = [ - Config(32, 32, 16, 1, 2), - Config(32, 32, 128, 2, 4), - Config(32, 64, 32, 5, 8), - Config(64, 32, 32, 5, 8), - Config(64, 32, 128, 5, 4), - Config(64, 64, 16, 2, 4), - Config(64, 64, 32, 2, 4), - Config(64, 64, 64, 3, 8), - Config(64, 64, 128, 5, 4), - Config(64, 128, 32, 3, 4), - Config(64, 128, 32, 4, 8), - Config(64, 128, 64, 3, 4), - Config(64, 128, 128, 4, 4), - Config(128, 64, 32, 3, 4), - Config(128, 64, 32, 4, 8), - Config(128, 128, 32, 2, 8), - Config(128, 128, 32, 3, 4), - Config(128, 128, 64, 3, 4), - Config(128, 128, 64, 5, 8), + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), ] # Exhaustive search for mm configs - self.exhaustive_configs = [ - Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m) for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( [16, 32, 64, 128, 256], repeat=3 ) for num_stages in [1, 2, 3, 4, 5] for num_warps in [2, 4, 8] + for group_m in [8] ] # these are only used in tuned_mm when AutoHeuristic is enabled @@ -89,220 +132,237 @@ def __init__(self) -> None: # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 # which saves compilation time (since less configs are autotuned) and potentially increase performance # because the learned heuristic might predict a config that is not part mm_configs - self.extra_mm_configs = [ - Config(16, 32, 16, 3, 2), - Config(16, 32, 32, 4, 2), - Config(16, 32, 32, 5, 2), - Config(64, 64, 128, 3, 4), - Config(128, 64, 32, 2, 2), - Config(128, 64, 64, 3, 8), - Config(128, 64, 128, 4, 8), - Config(128, 128, 32, 4, 4), - Config(128, 128, 64, 3, 8), - Config(128, 128, 64, 5, 4), + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), ] - self.int8_mm_configs = [ - Config(64, 64, 32, 2, 4), - Config(64, 128, 32, 3, 4), - Config(128, 64, 32, 3, 4), - Config(64, 128, 32, 4, 8), - Config(128, 64, 32, 4, 8), - Config(64, 32, 32, 5, 8), - Config(32, 64, 32, 5, 8), - Config(128, 128, 32, 2, 8), - Config(64, 64, 64, 3, 8), - Config(128, 256, 128, 3, 8), - Config(256, 128, 128, 3, 8), + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), ] - self.mixed_mm_configs = [ - Config(16, 128, 256, 3, 4), - Config(16, 128, 256, 5, 8), + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), ] - self.persistent_mm_configs = [ - Config(128, 256, 64, 3, 8), - Config(128, 128, 64, 3, 8), - Config(128, 128, 128, 3, 8), - Config(128, 128, 128, 3, 4), - Config(128, 128, 64, 4, 8), + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), ] - self.scaled_mm_configs = [ - Config(128, 256, 32, 3, 8), - Config(256, 128, 32, 3, 8), - Config(256, 64, 32, 4, 4), - Config(64, 256, 32, 4, 4), - Config(128, 128, 32, 4, 4), - Config(128, 64, 32, 4, 4), - Config(64, 128, 32, 4, 4), - Config(128, 32, 32, 4, 4), - Config(64, 32, 32, 5, 2), - Config(256, 128, 128, 3, 8), - Config(256, 64, 128, 4, 4), - Config(64, 256, 128, 4, 4), - Config(128, 128, 128, 4, 4), - Config(128, 64, 64, 4, 4), - Config(64, 128, 64, 4, 4), - Config(128, 32, 64, 4, 4), - Config(64, 32, 64, 5, 2), - Config(16, 32, 32, 2, 2), - Config(16, 64, 32, 2, 2), - Config(16, 128, 32, 2, 4), - Config(16, 256, 32, 2, 4), - Config(16, 32, 64, 2, 2), - Config(16, 64, 64, 2, 2), - Config(16, 128, 64, 2, 4), - Config(16, 256, 64, 2, 4), - Config(32, 32, 32, 2, 2), - Config(32, 64, 32, 2, 2), - Config(32, 128, 32, 2, 4), - Config(32, 256, 32, 2, 4), - Config(32, 32, 64, 2, 2), - Config(32, 64, 64, 2, 2), - Config(32, 128, 64, 2, 4), - Config(32, 256, 64, 2, 4), - Config(16, 32, 32, 3, 2), - Config(16, 64, 32, 3, 2), - Config(16, 128, 32, 3, 4), - Config(16, 256, 32, 3, 4), - Config(16, 32, 64, 3, 2), - Config(16, 64, 64, 3, 2), - Config(16, 128, 64, 3, 4), - Config(16, 256, 64, 3, 4), - Config(32, 32, 32, 3, 2), - Config(32, 64, 32, 3, 2), - Config(32, 128, 32, 3, 4), - Config(32, 256, 32, 3, 4), - Config(32, 32, 64, 3, 2), - Config(32, 64, 64, 3, 2), - Config(32, 128, 64, 3, 4), - Config(32, 256, 64, 3, 4), - Config(16, 32, 32, 4, 2), - Config(16, 64, 32, 4, 2), - Config(16, 128, 32, 4, 4), - Config(16, 256, 32, 4, 4), - Config(16, 32, 64, 4, 2), - Config(16, 64, 64, 4, 2), - Config(16, 128, 64, 4, 4), - Config(16, 256, 64, 4, 4), - Config(32, 32, 32, 4, 2), - Config(32, 64, 32, 4, 2), - Config(32, 128, 32, 4, 4), - Config(32, 256, 32, 4, 4), - Config(32, 32, 64, 4, 2), - Config(32, 64, 64, 4, 2), - Config(32, 128, 64, 4, 4), - Config(32, 256, 64, 4, 4), - Config(16, 32, 32, 5, 2), - Config(16, 64, 32, 5, 2), - Config(16, 128, 32, 5, 4), - Config(16, 256, 32, 5, 4), - Config(16, 32, 64, 5, 2), - Config(16, 64, 64, 5, 2), - Config(16, 128, 64, 5, 4), - Config(16, 256, 64, 5, 4), - Config(32, 32, 32, 5, 2), - Config(32, 64, 32, 5, 2), - Config(32, 128, 32, 5, 4), - Config(32, 256, 32, 5, 4), - Config(32, 32, 64, 5, 2), - Config(32, 64, 64, 5, 2), - Config(32, 128, 64, 5, 4), - Config(32, 256, 64, 5, 4), - Config(16, 32, 32, 6, 2), - Config(16, 64, 32, 6, 2), - Config(16, 128, 32, 6, 4), - Config(16, 256, 32, 6, 4), - Config(16, 32, 64, 6, 2), - Config(16, 64, 64, 6, 2), - Config(16, 128, 64, 6, 4), - Config(16, 256, 64, 6, 4), - Config(32, 32, 32, 6, 2), - Config(32, 64, 32, 6, 2), - Config(32, 128, 32, 6, 4), - Config(32, 256, 32, 6, 4), - Config(32, 32, 64, 6, 2), - Config(32, 64, 64, 6, 2), - Config(32, 128, 64, 6, 4), - Config(32, 256, 64, 6, 4), + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), ] - self.scaled_persistent_mm_configs = [ - Config(128, 128, 64, 3, 8), - Config(128, 128, 128, 3, 8), - Config(128, 128, 128, 4, 8), - Config(128, 128, 128, 4, 4), - Config(128, 128, 128, 3, 4), - Config(128, 128, 128, 5, 4), - Config(128, 128, 128, 5, 8), - Config(128, 128, 128, 6, 8), - Config(128, 128, 64, 4, 8), + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), ] # TODO: Unify with other gemm patterns, mm_plus_mm currently follows # slightly different pattern than rest - self.mm_plus_mm_configs = [ - Config(64, 64, 32, 2, 4), - Config(64, 64, 32, 3, 8), - Config(64, 64, 32, 4, 16), - Config(64, 32, 32, 4, 8), - Config(32, 64, 32, 4, 8), - Config(128, 128, 32, 1, 8), - Config(64, 64, 64, 1, 8), - Config(32, 32, 128, 1, 8), - Config(64, 64, 16, 2, 4), - Config(32, 32, 16, 1, 2), + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), ] - self.conv_configs = [ - Config(64, 256, 16, 2, 4), - Config(256, 64, 16, 2, 4), - Config(1024, 16, 16, 1, 8), - Config(128, 128, 32, 2, 8), - Config(64, 64, 32, 2, 4), - Config(64, 256, 32, 2, 8), - Config(256, 64, 32, 2, 8), + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), ] def _finalize_mm_configs( self, - configs: list[Config], + configs: list[BaseConfig], ) -> Generator[TritonConfig, None, None]: """ Finalizes configs after scaling, applying additional constraints. """ - used = OrderedSet[Config]() + used: OrderedSet[tuple[int, ...]] = OrderedSet() max_mm_configs = config.test_configs.max_mm_configs - for block_m, block_n, block_k, num_stages, num_warps in configs: + for conf in configs: # Each warp computes a 16x16 tile = 256 elements - num_warps = min(num_warps, block_m * block_n // 256) - - if ( - Config(block_m, block_n, block_k, num_stages, num_warps) - ) not in used and (max_mm_configs is None or len(used) < max_mm_configs): - used.add(Config(block_m, block_n, block_k, num_stages, num_warps)) - yield self.triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - ) + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": num_warps, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) def _scale_mm_configs( self, m: int, n: int, k: int, - configs: Sequence[Config], + configs: list[BaseConfig], scale: float, has_int8_tensor: bool, exclude: Callable[[int, int, int], bool], - ) -> list[Config]: + ) -> list[BaseConfig]: """ Scales and filters matrix multiplication configs based on input size. """ @@ -341,7 +401,8 @@ def _scale_mm_configs( scaled_configs = [] for c in configs: - scaled_config = c._replace( + scaled_config = dataclasses.replace( + c, block_m=max(min(int(c.block_m * scale), m), min_block_size), block_n=max(min(int(c.block_n * scale), n), min_block_size), block_k=max(min(int(c.block_k * scale), k), min_block_size_k), @@ -359,7 +420,7 @@ def preprocess_mm_configs( m: int, n: int, k: int, - configs: Sequence[Config], + configs: list[BaseConfig], has_int8_tensor: bool = False, scale: int = 1, exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, @@ -430,90 +491,160 @@ def __init__(self) -> None: self.default_num_stages = get_backend_num_stages() + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + # Exhaustive search for mm configs - self.exhaustive_configs = [ - Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m, + matrix_instr_nonkdim, + waves_per_eu, + kpack, + ) for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( [16, 32, 64, 128, 256], repeat=3 ) for num_stages in [1, self.default_num_stages] for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] ] def _filter_configs( - self, configs: list[Config], new_num_stages: int - ) -> list[Config]: - filtered_configs = [ - c._replace(num_stages=self.default_num_stages) for c in configs - ] - return filtered_configs + self, configs: list[BaseConfig], new_num_stages: int + ) -> list[BaseConfig]: + # TODO: _filter_configs can be removed once backend specific configs are added + # for all methods + for c in configs: + c.num_stages = self.default_num_stages + return configs def _finalize_mm_configs( self, - configs: list[Config], + configs: list[BaseConfig], ) -> Generator[TritonConfig, None, None]: - used = OrderedSet[tuple[Config, int, int]]() + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() max_mm_configs = config.test_configs.max_mm_configs - for block_m, block_n, block_k, num_stages, num_warps in configs: - # each warp computes 16x16 tile = 256 - num_warps = min(num_warps, block_m * block_n // 256) - kpack = 2 - for matrix_instr_nonkdim in [0, 16]: - if matrix_instr_nonkdim != 0 and ( - block_m % matrix_instr_nonkdim != 0 - or block_n % matrix_instr_nonkdim != 0 - ): - # block_m and block_n must be a multiple of matrix_instr_nonkdim - continue - if ( - Config( - block_m, - block_n, - block_k, - num_stages, - num_warps, - ), - matrix_instr_nonkdim, - kpack, - ) not in used and ( - max_mm_configs is None or len(used) < max_mm_configs - ): - used.add( - ( - Config( - block_m, - block_n, - block_k, - num_stages, - num_warps, - ), - matrix_instr_nonkdim, - kpack, - ) - ) - - yield self.triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - matrix_instr_nonkdim=matrix_instr_nonkdim, - kpack=kpack, - ) - def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) - def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.exhaustive_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: filtered_configs = self._filter_configs( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e93ed88bcbda..bca3f024d134 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1377,7 +1377,7 @@ def _is_tma_compatible(x: IRNode) -> bool: return False dtype = x.get_dtype() - if dtype not in (torch.float16, torch.bfloat16): + if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): return False layout = x.get_layout() @@ -1388,6 +1388,12 @@ def _is_tma_compatible(x: IRNode) -> bool: inner_dim = layout.size[1] if transposed: inner_dim = layout.size[0] + + if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + inner_dim, 32 + ): + return False + inner_bytes = inner_dim * dtype.itemsize return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index 5c8c713b6e42..3f3e9295549b 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -105,9 +105,7 @@ def backward(ctx, *grads): # The dispatcher passes any keyword-only-args as kwargs and the # rest of the args (even if specified as kwargs) as args. def autograd_impl(keyset, *args, **keyword_only_args): - if _C.is_grad_enabled() and _pytree.tree_any_only( - Tensor, lambda x: x.requires_grad, args, not_list_of_tensor - ): + if _C.is_grad_enabled() and _C._any_requires_grad(*args): result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] else: result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 66aeccc58a0c..544bbbf61582 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -347,9 +347,10 @@ def get_module(): fn = self._backend_fns[device_type] return inspect.getmodule(fn) - utils.check_aliasing_constraint( + utils._c_check_aliasing_constraint( self._name, - utils.iter_tensors(args, kwargs), + args, + kwargs, result, get_module, ) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 8348883cee30..908280ecf292 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -373,6 +373,31 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"): storages.add(key) +def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"): + """ + custom operators' outputs must not have any aliases + This version uses C++ implementation for perf. + Only List container is supported. + Tensors in Lists with not only Tensors are checked. + """ + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result): + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + + class MutationChecker: """ Check if an operator mutated its arguments. diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 1141484db6aa..bd0a2f7b9728 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2152,7 +2152,7 @@ def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> @register_meta([aten.baddbmm.default, aten.baddbmm.out]) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(0) dim2 = batch1.size(1) @@ -2227,7 +2227,7 @@ def meta__fused_moving_avg_obs_fq_helper( @register_meta(aten.mm) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_mm(a, b): torch._check(a.dim() == 2, lambda: "a must be 2D") torch._check(b.dim() == 2, lambda: "b must be 2D") @@ -2508,7 +2508,8 @@ def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size): ) @register_meta(torch.ops.onednn.qconv2d_pointwise.default) - def meta_qconv2d_pointwise( + @register_meta(torch.ops.onednn.qconv_pointwise.default) + def meta_qconv_pointwise( x, x_scale, x_zp, @@ -2539,7 +2540,9 @@ def meta_qconv2d_pointwise( ) assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] out = x.new_empty(shape_out, dtype=output_dtype) - out = out.to(memory_format=torch.channels_last) + assert len(shape_out) in [3, 4], "only conv1d/2d are supported" + format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format + out = out.to(memory_format=format) return out @register_meta(torch.ops.onednn.qconv2d_pointwise.binary) @@ -3460,7 +3463,7 @@ def meta_convolution_backward( @register_meta([aten.addbmm.default, aten.addbmm.out]) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(1) dim2 = batch2.size(2) @@ -3636,6 +3639,21 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros): return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) +@register_meta([aten._weight_int4pack_mm_with_scales_and_zeros]) +def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + def kai_roundup(a: int, b: int) -> int: return ((a + b - 1) // b) * b @@ -6182,12 +6200,13 @@ def meta_scaled_mm( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): - def is_fp8_type(dtype): + def is_fp8_or_fp4_type(dtype): return dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz, + torch.float4_e2m1fn_x2, ) torch._check( @@ -6195,8 +6214,8 @@ def is_fp8_type(dtype): lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", ) torch._check( - is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), - lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", + is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype), + lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) if device_hint(self) == "cuda": @@ -6232,18 +6251,32 @@ def has_zero_dim(tensor_2d): m, _k = self.shape n = mat2.size(1) + is_blockwise_scaling = ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ) or ( + scale_a.dtype == torch.float8_e4m3fn + and scale_b.dtype == torch.float8_e4m3fn + ) + if scale_a.numel() == 1 and scale_b.numel() == 1: # tensorwise scaling torch._check( scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.", ) - elif ( - scale_a.dtype == torch.float8_e8m0fnu - and scale_b.dtype == torch.float8_e8m0fnu - ): + elif is_blockwise_scaling: # blockwise scaling - block_size_k = 32 + + if scale_a.dtype == torch.float8_e4m3fn: + # NVIDIA's nvfp4 recipe: + # * block size is 16 elements packed (32 unpacked) + # * _k needs to be translated to the unpacked version + block_size_k = 16 + _k = _k * 2 + else: + block_size_k = 32 + block_size_mn = 128 def ceil_div(a, b): diff --git a/torch/_ops.py b/torch/_ops.py index c6f5be583e41..4e4c346e25eb 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -470,6 +470,26 @@ def wrapper(): return wrapper() + # NOTE [HigherOrderOprator Schema] + # Each invocation of a HigherOrderOperator (hop) should have its own schema because + # the subgraphs and the arguments can be different even for the same hop. + # + # Each hop should implement its own gen_schema method, which should + # take the same input as the __call__ method and returns a FunctionSchema. + # The schema provides a unified way to check if the hop mutates its inputs, + # which can be useful in implementing optimizations. + # + # If the hop doesn't implement the gen_schema method, + # we expect it to be functional. It should not mutate its inputs and there + # are no input, output aliasing via views or direct referencing. + def gen_schema(self, *args, **kwargs): + raise NotImplementedError( + f"HigherOrderOperator {self._name} does not implement a gen_schema. " + f"This is OK as long as the hop is functional. " + f"e.g. it should not mutate its inputs and there are no input, output aliasing " + f"via views or direct referencing." + ) + def __str__(self): return f"{self.name()}" @@ -1086,7 +1106,7 @@ def _schemas(self): for overload_name in self._overload_names } - def __getattr__(self, key): + def __getattr__(self, key) -> Any: # It is not a valid op_name when __file__ is passed in if key == "__file__": return "torch.ops" @@ -1246,7 +1266,7 @@ def __init__(self, name): def __iter__(self): return iter(self._dir) - def __getattr__(self, op_name): + def __getattr__(self, op_name) -> Any: # It is not a valid op_name when __file__ is passed in if op_name == "__file__": return "torch.ops" diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 3cf2a0b52146..c9080a01ede3 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -6301,7 +6301,7 @@ def wrapper(self, other): @register_decomposition(aten.dot) -@out_wrapper() +@out_wrapper(exact_dtype=True) @_dot_check_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("self", "other"), @@ -6321,7 +6321,7 @@ def dot(self, other): @register_decomposition(aten.vdot) -@out_wrapper() +@out_wrapper(exact_dtype=True) @_dot_check_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("self", "other"), diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 00d95445c6f3..c85962f22842 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -151,6 +151,22 @@ def vector_norm( reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if (dim is None and x.numel() == 1) or ( + dim is not None and (x.ndim > 0 and all(x.shape[d] == 1 for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return x.flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + if not (is_ord_even and utils.is_float_dtype(x.dtype)): x = torch.abs(x) return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index bc7bc1ba7f82..9d85bf4c77b3 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -890,7 +890,9 @@ def infer_size(a, b): return tuple(expandedSizes) -def make_fast_binary_impl(slow_ref): +def make_fast_binary_impl( + slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +): def fast_binary_impl(mode, *args, **kwargs): def slow(msg): count_label(f"slow {msg}") @@ -957,7 +959,7 @@ def slow(msg): # compute promotion # TODO: we don't need the compute type _, common_dtype = elementwise_dtypes( - *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + *operands, type_promotion_kind=type_promotion_kind ) # check all tensors on same device @@ -1042,7 +1044,10 @@ def get_fast_op_impls(): ) register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] register_fast_op_impl(torch.ops.aten.div.Tensor)( - make_fast_binary_impl(torch._refs.div) + make_fast_binary_impl( + torch._refs.div, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) ) register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach) return FAST_OP_IMPLEMENTATIONS diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index b1c52f7e1bdf..f8cb248e0a60 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1515,14 +1515,15 @@ def _validate_cache_key( for node in subgraph_mod.graph.nodes: if node.op == "call_function": op = node.target - # Dynamo graphs can have operator.add type of operations. For these operations, it is safe to cache. - if ( - callable(op) - and getattr(op, "__module__", None) - in {"_operator", "operator"} - and not op.__name__.startswith("i") - ): - continue + + # AOTDispatcher first pass does not run make_fx on + # dynamo graphs. As a result, it can have non OpOverload + # ops. + if not isinstance(op, torch._ops.OpOverload): + raise _BypassDispatchCache( + f"{func.name()} hop with a non OpOverload input" + ) + try: self._validate_cache_key(op, [], {}) except _BypassDispatchCache as e: @@ -2307,10 +2308,9 @@ def maybe_to_real_tensor( if ( self.propagate_real_tensors and all(e.real_tensor is not None for e in flat_arg_fake_tensors) - # TODO: Handle SymFloat/SymBool and not any( ( - isinstance(a, SymInt) + isinstance(a, py_sym_types) and (syms := free_unbacked_symbols(a)) and self.shape_env is not None and any(s not in self.shape_env.unbacked_var_to_val for s in syms) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index fb272adc7ea3..368e30246091 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -260,9 +260,12 @@ def tolist(self) -> Any: def to(self, *args, **kwargs): if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export: - # If copy is specified as pos arg, it's always the second one. - if len([arg for arg in args if isinstance(arg, bool)]) <= 1: - return super().to(*args, **{**kwargs, "copy": True}) + torch.ops.aten._assert_tensor_metadata( + self, + dtype=self.dtype, + device=self.device, + layout=self.layout, + ) return super().to(*args, **kwargs) def cuda(self, device=None, *args, **kwargs): @@ -354,23 +357,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if self.export: - # We need to make sure that we don't decompose to() as usual in export mode, - # because it can get optimized away. Instead we always replace it with _to_copy(). - if func == torch.ops.aten.to.dtype_layout: - kwargs.pop("copy", None) - return self.__torch_dispatch__( - torch.ops.aten._to_copy.default, types, args, kwargs - ) - if func == torch.ops.aten.to.dtype: - schema = tuple(arg.name for arg in func._schema.arguments) - for arg, name in zip(args[1:], schema[1:]): - kwargs[name] = arg - kwargs.pop("copy", None) - return self.__torch_dispatch__( - torch.ops.aten._to_copy.default, types, args[:1], kwargs - ) - unrecognized_types = [ t for t in types @@ -527,36 +513,10 @@ def unwrap(x): *args_unwrapped, **kwargs_unwrapped, ) - # We don't allow any mutation on result of dropout or _to_copy + if self.export: - if func in ( - torch.ops.aten.dropout.default, - torch.ops.aten._to_copy.default, - ): - - def must_copy(): - """ - Return True if the output of the op must be copied, not an alias - """ - # output dtype is different from input - return ( - func == torch.ops.aten._to_copy.default - and "dtype" in kwargs - and kwargs["dtype"] != args_unwrapped[0].dtype - ) - - # `args_unwrapped` might be a tensor constant, not a functional tensor. - if must_copy() and torch._is_functional_tensor( - args_unwrapped[0] - ): - # We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think - # we have an aten op for that. - torch.ops.aten._assert_tensor_metadata.default( - torch._from_functional_tensor(args_unwrapped[0]), - dtype=args_unwrapped[0].dtype, - ) - else: - torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] + if func == torch.ops.aten.dropout.default: + torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] outs_wrapped = pytree.tree_map_only( torch.Tensor, wrap, outs_unwrapped ) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9e1956763242..076491993d46 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -6,7 +6,7 @@ from torch._torch_docs import parse_kwargs, reproducibility_notes -def add_docstr_all(method, docstr): +def add_docstr_all(method: str, docstr: str) -> None: add_docstr(getattr(torch._C.TensorBase, method), docstr) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 59fcc6213f30..1e5d7d340f5a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -11488,8 +11488,8 @@ def merge_dicts(*dicts): Default: if not provided, 0. Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, ``torch.long``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. {device} layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. @@ -11613,8 +11613,8 @@ def merge_dicts(*dicts): Default: if not provided, 0. Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, ``torch.long``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. {device} layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. diff --git a/torch/accelerator/_utils.py b/torch/accelerator/_utils.py index 3a29acd240cd..730f2a82543d 100644 --- a/torch/accelerator/_utils.py +++ b/torch/accelerator/_utils.py @@ -16,7 +16,7 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int: raise RuntimeError("Accelerator expected") if acc.type != device.type: raise ValueError( - f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}." + f"{device.type} doesn't match the current accelerator {acc}." ) device_index = device.index if device_index is None: diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index 93b1d667c08a..2931b5b9fadd 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -336,7 +336,11 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None: # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.double().reciprocal().float() + inv_scale = ( + self._scale.double().reciprocal().float() + if self._scale.device != torch.device("mps:0") + else self._scale.reciprocal() + ) found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) optimizer_state["found_inf_per_device"] = self._unscale_grads_( diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index a3672b5cb01d..673d52e8924e 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -8,6 +8,7 @@ the values observed during calibration (PTQ) or training (QAT). """ +import operator import re import warnings from abc import ABCMeta, abstractmethod @@ -24,6 +25,7 @@ is_per_tensor, validate_qmin_qmax, ) +from torch.fx import Node __all__ = [ @@ -491,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scale and zero_point are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -700,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scales and zero_points are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -995,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase): 3. Compute the scale and zero point the same way as in the :class:`~torch.ao.quantization.MinMaxObserver` """ + histogram: torch.Tensor min_val: torch.Tensor max_val: torch.Tensor @@ -1522,6 +1527,7 @@ class RecordingObserver(ObserverBase): qscheme: Quantization scheme to be used reduce_range: Reduces the range of the quantized data type by 1 bit """ + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} def __init__(self, dtype=torch.quint8): @@ -1788,7 +1794,7 @@ def get_block_size( ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" return (1, granularity.group_size) elif isinstance(granularity, PerToken): - block_size = list(input_shape) + block_size = [1] * len(input_shape) block_size[-1] = input_shape[-1] return tuple(block_size) raise ValueError(f"Unsupported Granularity: {granularity}") @@ -1850,6 +1856,84 @@ def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: and returns a tuple of scale and zero_point Tensor """ + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + """ + Converts the observer node in the graph into its quantized representation + + Args: + model: graph module to conver the observer node in + observer_node: the observer node to convert + """ + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert ( + self.original_dtype is not None + ), "Expecting original_dtype to be populated" + if hasattr(self, "is_dynamic") and self.is_dynamic: + choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + + q_node = model.graph.call_function( + torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + def _is_observer_script_module(mod, obs_type_name): """Returns true if given mod is an instance of Observer script module.""" diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py index 70ad5c0cde89..32b4a773f28f 100644 --- a/torch/ao/quantization/pt2e/_affine_quantization.py +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -9,11 +9,11 @@ from torch.ao.quantization.observer import ( AffineQuantizedObserverBase, get_block_size, + Granularity, MappingType, TorchAODType, ZeroPointDomain, ) -from torch.fx import Node ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -729,47 +729,138 @@ def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: self.zero_point_domain, ) - def convert(self, model: torch.fx.GraphModule, observer_node: Node): - print("calling convert") - from torch.ao.quantization.fx.utils import create_getattr_from_value - scale, zero_point = self.calculate_qparams() - with model.graph.inserting_before(observer_node): - assert self.block_size is not None, "Expecting block_size to be populated" - assert ( - self.original_dtype is not None - ), "Expecting original_dtype to be populated" - scale_node = create_getattr_from_value(model, model.graph, "_scale", scale) - zero_point_node = create_getattr_from_value( - model, model.graph, "_zero_point", zero_point - ) - q_node = model.graph.call_function( - torch.ops.pt2e_quant.quantize_affine, - ( - observer_node.args[0], - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {}, - ) - dq_node = model.graph.call_function( - torch.ops.pt2e_quant.dequantize_affine, - ( - q_node, - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {"output_dtype": self.original_dtype}, +class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + averaging_constant=0.01, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" ) - observer_node.replace_all_uses_with(dq_node) - model.graph.erase_node(observer_node) + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert ( + self.min_val.shape == min_val.shape + ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + assert ( + self.max_val.shape == max_val.shape + ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) + max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr( + self, "max_val" + ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input): + self.block_size = get_block_size(input.shape, self.granularity) + self.original_dtype = input.dtype + return input + + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) diff --git a/torch/ao/quantization/pt2e/lowering.py b/torch/ao/quantization/pt2e/lowering.py new file mode 100644 index 000000000000..587cee22560d --- /dev/null +++ b/torch/ao/quantization/pt2e/lowering.py @@ -0,0 +1,60 @@ +import torch +from torch._inductor.constant_folding import constant_fold +from torch._inductor.fx_passes.freezing_patterns import freezing_passes + + +__all__ = [ + "lower_pt2e_quantized_to_x86", +] + + +def lower_pt2e_quantized_to_x86( + model: torch.fx.GraphModule, + example_inputs: tuple[torch.Tensor, ...], +) -> torch.fx.GraphModule: + """Lower a PT2E-qantized model to x86 backend. + + Args: + * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. + * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + + Return: + A GraphModule lowered to x86 backend. + """ + + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] + decomp_table = torch.export.default_decompositions() + + # if we are post-autograd, we shouldn't + # decomp prim ops. + for k in list(decomp_table.keys()): + if not torch._export.utils._is_cia_op(k): + del decomp_table[k] + + return decomp_table + + def _node_replace(m): # type: ignore[no-untyped-def] + # Replace aten.t(x) with aten.permute(x, [1, 0]) + aten = torch.ops.aten + g = m.graph + for node in g.nodes: + if node.target == aten.t.default: + with g.inserting_before(node): + x = node.args[0] + dims = [1, 0] + perm_node = g.call_function(aten.permute.default, args=(x, dims)) + node.replace_all_uses_with(perm_node) + g.erase_node(node) + + g.lint() + m.recompile() + + lowered_model = ( + torch.export.export_for_training(model, example_inputs, strict=True) + .run_decompositions(_post_autograd_decomp_table()) + .module() + ) + _node_replace(lowered_model) + freezing_passes(lowered_model, example_inputs) + constant_fold(lowered_model) + return lowered_model diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index b0946d0075c9..0c96f915306d 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -27,17 +27,20 @@ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.dequantize_affine, ] _CHOOSE_QPARAMS_OPS = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.pt2e_quant.choose_qparams_affine, ] diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index ed3b30552a1f..ae23b43b9cb0 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Optional import torch +from torch._export.utils import _disable_aten_to_metadata_assertions from torch._higher_order_ops.out_dtype import out_dtype from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.ao.quantization.pt2e.export_utils import _WrapperModule @@ -798,22 +799,23 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: remove_tensor_overload_for_qdq_ops(model) - for rewrite_info in _REWRITE_INFO_LIST: - example_inputs = rewrite_info.example_inputs - pattern = rewrite_info.pattern - replacement = rewrite_info.replacement - pattern_post_trans = rewrite_info.pattern_post_trans - replacement_post_trans = rewrite_info.replacement_post_trans - pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] - remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] - replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] - remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] - if pattern_post_trans: - pattern = pattern_post_trans(pattern) - if replacement_post_trans: - replacement = replacement_post_trans(replacement) - pattern.recompile() # type: ignore[attr-defined] - replacement.recompile() # type: ignore[attr-defined] - replace_pattern(model, pattern, replacement) + with _disable_aten_to_metadata_assertions(): + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + replace_pattern(model, pattern, replacement) return model diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 47e939f7596a..86304247d151 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -355,6 +355,7 @@ def _get_aten_graph_module_for_pattern( pattern, # type: ignore[arg-type] example_inputs, kwargs, + strict=True, ).module() aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 25a5dfc4a193..3f91c2ddd13b 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -84,6 +84,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): # Operators support the int8 data type # and recipe is configured by default in X86InductorQuantizer. default_quantizable_ops = propagation_quantizable_ops | { + torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, } @@ -185,6 +186,7 @@ def _global_config_filter(nodes: list[Node]) -> bool: def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( + ([torch.nn.Conv2d, F.conv1d], torch.ops.aten.conv1d.default), ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), @@ -1156,6 +1158,7 @@ def _annotate_conv2d_unary( [torch.nn.Conv2d, torch.nn.Hardswish], [torch.nn.Conv2d, torch.nn.ReLU6], [torch.nn.Conv2d, torch.nn.SiLU], + [torch.nn.Conv1d, torch.nn.ReLU], ] for unary_pattern in unary_patterns: partitions = find_sequential_partitions(gm, unary_pattern) @@ -1168,9 +1171,9 @@ def _annotate_conv2d_unary( conv_node, unary_node = self._get_output_nodes_of_partitions( [conv_partition, unary_partition] ) - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.conv2d.default + if conv_node.op != "call_function" or conv_node.target not in ( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, ): continue if _skip_annotate([unary_node, conv_node], filter_fn): diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 426523865296..8fcb64beba3b 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import os -from collections import namedtuple -from typing import Any +from typing import Any, NamedTuple, Optional import torch @@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None): return torch._VF._make_dual(tensor, tangent, level=level) -_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"]) - - -class UnpackedDualTensor(_UnpackedDualTensor): +class UnpackedDualTensor(NamedTuple): r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor. See :func:`unpack_dual` for more details. - """ + primal: torch.Tensor + tangent: Optional[torch.Tensor] + def unpack_dual(tensor, *, level=None): r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient. diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 50745586ca63..0a8b9d1e29d3 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -629,7 +629,7 @@ def _device_memory_usage(mem_record): ) max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: - if self.use_device == "privateuseone": + if self.use_device == _get_privateuse1_backend_name(): privateuse1_time = kineto_event.privateuse1_elapsed_us() if privateuse1_time > 0: fe.append_kernel(fe.name, fe.device_index, privateuse1_time) diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 321ededbb24a..aa6a27a3dcc3 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -228,7 +228,7 @@ def assume_constant_result(fn): return torch._dynamo.assume_constant_result(fn) -def disable(fn=None, recursive=True): +def disable(fn=None, recursive=True, *, reason=None): """ This function provides a decorator to disable compilation on a function. It also provides the option of recursively disabling called functions. @@ -236,10 +236,11 @@ def disable(fn=None, recursive=True): Args: fn (optional): The function to disable recursive (optional): A boolean value indicating whether the disabling should be recursive. + reason (optional): A string value indicating the reason for disabling the function. """ import torch._dynamo - return torch._dynamo.disable(fn, recursive) + return torch._dynamo.disable(fn, recursive, reason=reason) def set_stance( diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index c67953fc45e2..3612e94a0d02 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -965,10 +965,14 @@ static PyObject* THPModule_setAllowTF32OneDNN( static PyObject* THPModule_allowTF32OneDNN( PyObject* _unused, PyObject* noargs) { +#ifdef USE_XPU if (at::globalContext().allowTF32OneDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; +#else + Py_RETURN_NONE; +#endif } static PyObject* THPModule_deterministicAlgorithms( diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index e5ee41e6fd56..498259c8fa12 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1450,6 +1450,62 @@ Tensor mm_mat2_backward( return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj()); } +Tensor _grouped_mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::SymIntArrayRef mat1_sizes, + at::SymIntArrayRef mat1_strides, + c10::Layout mat1_layout, + std::optional offs, + const Scalar& alpha) { + TORCH_CHECK( + grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1_layout == c10::kStrided, + "only strided layout supported for grouped mm"); + // if input was column-major, return grad as column-order for efficiency + if (offs.has_value() && !offs->defined()) { + offs = std::nullopt; + } + auto mat1_dim = mat1_sizes.size(); + if (mat1_strides[mat1_dim - 2] == 1 && + mat1_strides[mat1_dim - 1] == mat1_sizes[mat1_dim - 2]) { + auto grad_inp = + (at::_grouped_mm(mat2, grad.transpose(-2, -1), offs)).transpose(-2, -1); + return maybe_multiply(grad_inp, alpha.conj()); + } else { + auto grad_inp = (at::_grouped_mm(grad, mat2.transpose(-2, -1), offs)); + return maybe_multiply(grad_inp, alpha.conj()); + } +} + +Tensor _grouped_mm_mat2_backward( + const Tensor& grad, + const Tensor& mat1, + at::SymIntArrayRef mat2_sizes, + at::SymIntArrayRef mat2_strides, + c10::Layout mat2_layout, + std::optional offs, + const Scalar& alpha) { + TORCH_CHECK( + grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided && + mat2_layout == c10::kStrided, + "only strided layout supported for grouped mm"); + // if input was column-major, return grad as column-order for efficiency + auto mat2_dim = mat2_sizes.size(); + if (offs.has_value() && !offs->defined()) { + offs = std::nullopt; + } + if (mat2_strides[mat2_dim - 2] == 1 && + mat2_strides[mat2_dim - 1] == mat2_sizes[mat2_dim - 2]) { + auto grad_inp = + at::_grouped_mm(grad.transpose(-2, -1), mat1, offs).transpose(-2, -1); + return maybe_multiply(grad_inp, alpha.conj()); + } else { + auto grad_inp = at::_grouped_mm(mat1.transpose(-2, -1), grad, offs); + return maybe_multiply(grad_inp, alpha.conj()); + } +} + Tensor mm_mat1_sparse_backward( const Tensor& grad, const Tensor& mat1, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 8d01a80eb406..1bbad0ae92dd 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -306,6 +306,22 @@ at::Tensor mm_mat2_backward( at::SymIntArrayRef strides, c10::Layout layout, const at::Scalar& alpha); +at::Tensor _grouped_mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::SymIntArrayRef mat1_sizes, + at::SymIntArrayRef mat1_strides, + c10::Layout mat1_layout, + std::optional offs, + const Scalar& alpha); +at::Tensor _grouped_mm_mat2_backward( + const at::Tensor& grad, + const at::Tensor& mat1, + at::SymIntArrayRef sizes, + at::SymIntArrayRef strides, + c10::Layout layout, + std::optional offs, + const at::Scalar& alpha); at::Tensor mm_mat1_sparse_backward( const at::Tensor& grad, const at::Tensor& mat1, diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6eb3cdcdbdfc..b376c295b77a 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -955,6 +955,133 @@ static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { END_HANDLE_TH_ERRORS } +template +static bool visit( + PyObject* o, + const std::function& visit_tensor) { + if (THPVariable_Check(o)) { + auto t = THPVariable_Unpack(o); + if (visit_tensor(t)) { + return true; + } + } else if (PyList_Check(o)) { + // Check that this List is TensorList + if constexpr (skip_tensors_in_non_tensorlist) { + for (const auto i : c10::irange(PyList_GET_SIZE(o))) { + if (!THPVariable_Check(PyList_GET_ITEM(o, i))) { + return false; + } + } + } + for (const auto i : c10::irange(PyList_GET_SIZE(o))) { + if (visit( + PyList_GET_ITEM(o, i), visit_tensor)) { + return true; + }; + } + } + return false; +} + +// Visiting of tensors in args and kwargs, +// only List container is visited. +// skip_tensors_in_non_tensorlist will skip any List with non-Tensor. +// Lambda returning true means short circuit, traversal stops after that. +template +static void visit_tensors( + PyObject* args, + PyObject* kwargs, + const std::function& visit_tensor) { + if (args && PyTuple_Check(args)) { + for (const auto i : c10::irange(PyTuple_GET_SIZE(args))) { + if (visit( + PyTuple_GET_ITEM(args, i), visit_tensor)) { + return; + } + } + } + if (kwargs && PyDict_Check(kwargs)) { + auto vals = THPObjectPtr{PyDict_Values(kwargs)}; + for (const auto i : c10::irange(PyList_Size(vals))) { + if (visit( + PyList_GetItem(vals, i), visit_tensor)) { + return; + } + } + } +} + +// Returns true if any of the args, kwargs tensor leaves have requires_grad. +// Only List[Tensor] container in args is supported. +static PyObject* any_requires_grad( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + bool has_requires_grad = false; + visit_tensors(args, kwargs, [&has_requires_grad](at::Tensor& t) { + if (t.requires_grad()) { + has_requires_grad = true; + return true; + } + return false; + }); + if (has_requires_grad) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS +} + +// Checks aliasing constraint for custom ops: +// Returns true if any of outputs is alias to any of inputs or another output +// Args: +// args[0] - inputs args +// args[1] - inputs kwargs +// args[2] - outputs +// Only List container is supported. +// Tensors in Lists that has not only Tensor are checked. +static PyObject* any_output_is_alias_to_input_or_output( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* inps = PyTuple_GET_ITEM(args, 0); + PyObject* inps_kwargs = PyTuple_GET_ITEM(args, 1); + PyObject* outs = PyTuple_GET_ITEM(args, 2); + std::unordered_set s; + visit_tensors(inps, inps_kwargs, [&s](at::Tensor& t) { + if (!t.storage()) { + return false; + } + auto* cp = t.storage().data_ptr().get_context(); + if (cp) { + s.insert(cp); + } + return false; + }); + bool ret = false; + visit_tensors(outs, nullptr, [&s, &ret](at::Tensor& t) { + if (!t.storage()) { + return false; + } + auto* cp = t.storage().data_ptr().get_context(); + if (!cp) { + return false; + } + if (s.find(cp) != s.end()) { + ret = true; + return true; + } + s.insert(cp); + return false; + }); + if (ret) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS +} + static PyObject* set_multithreading_enabled( PyObject* self, PyObject* args, @@ -1326,6 +1453,14 @@ static PyMethodDef methods[] = { nullptr}, {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr}, + {"_any_requires_grad", + castPyCFunctionWithKeywords(any_requires_grad), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_any_output_is_alias_to_input_or_output", + any_output_is_alias_to_input_or_output, + METH_VARARGS, + nullptr}, {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr}, {"is_inference_mode_enabled", is_inference_mode_enabled, diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 2b1e6f2e0104..447ca88f0e84 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -8,7 +8,6 @@ #include #include #include - #include #include #include @@ -21,8 +20,6 @@ #include #include -#include - #include #include @@ -860,6 +857,22 @@ std::unique_ptr disableProfiler() { return result; } +namespace tracer = torch::profiler::impl::python_tracer; +std::unique_ptr memory_tracer; +void startMemoryProfile() { + if (memory_tracer == nullptr) { + memory_tracer = tracer::PythonMemoryTracerBase::make(); + } + memory_tracer->start(); +} + +void stopMemoryProfile() { + memory_tracer->stop(); +} + +void exportMemoryProfile(const std::string& filename) { + memory_tracer->export_memory_history(filename); +} KinetoEvent::KinetoEvent( const std::shared_ptr& result, diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index cedf58123381..2e4b89da4b79 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -185,6 +185,10 @@ TORCH_API void toggleCollectionDynamic( const bool enable, const std::set& activities); +TORCH_API void startMemoryProfile(); +TORCH_API void stopMemoryProfile(); +TORCH_API void exportMemoryProfile(const std::string& path); + /** * When a C++ thread really has no control over how the profiler was enabled, * for example, by some unreachable Python code, it can call these functions diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a98d1a8b7934..acbc7bdc0d16 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -705,10 +706,13 @@ class PythonTracer final : public python_tracer::PythonTracerBase { void recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, - PyObject* arg); + PyObject* arg, + bool start_frame = false); const std::vector interpreterThreads() const; + PyObject* get_callable_from_frame(PyFrameObject* frame); + std::atomic active_lock_{false}; bool active_{false}; @@ -787,6 +791,16 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { recordPyCall(thread_local_results_.back(), it->get(), true); + PyFrameObject* frame = it->get(); + PyObject* callable = get_callable_from_frame(frame); + if (callable) { + // If the frame has a callable, record it as a C call since + // PyEval_GetFrame only gets the python frame. We need to record this C + // call so that when exiting the profiler we don't have a mismatched C + // call. + recordCCall(thread_local_results_.back(), it->get(), callable, true); + } + auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -890,8 +904,13 @@ void PythonTracer::recordPyCall( void PythonTracer::recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, - PyObject* arg) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg)); + PyObject* arg, + bool start_frame) { + // for starting frames we duplicate callable python functions to avoid having + // empty C frames in trace when exiting + if (!start_frame) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg)); + } auto fn = reinterpret_cast(arg); // NB: For C calls a new frame is not created, so we use `frame` rather than @@ -901,6 +920,26 @@ void PythonTracer::recordCCall( queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); } +PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) { + if (frame == nullptr) { + return nullptr; + } + // Get the code object associated with the frame + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + if (code == nullptr) { + return nullptr; + } + // Get the function name (if needed) + auto name = THPUtils_unpackStringView(code->co_name).data(); + // To get the function object, you will need to look in the globals or the + // frame's f_globals + PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name); + if (func) { + Py_INCREF(func); // Make sure the returned function has a reference + } + return func; // Returns a PyObject* (the function) +} + // ============================================================================ // == Post processing ========================================================= // ============================================================================ @@ -983,9 +1022,13 @@ class PostProcess { using stack_t = std::vector>; const auto initial_size = out.size(); auto pop = [](stack_t& stack, c10::time_t t) { - TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty."); - std::get>(stack.back()->extra_fields_).end_time_ns_ = t; - stack.pop_back(); + if (!stack.empty()) { + std::get>(stack.back()->extra_fields_).end_time_ns_ = t; + stack.pop_back(); + } else { + TORCH_WARN_ONCE( + "Python replay stack is empty during pop operation! May result in incorrect stack tracing."); + } }; ska::flat_hash_map stacks; @@ -1102,6 +1145,78 @@ std::vector> PythonTracer::getEvents( return out; } +// ============================================================================ +// == Memory Tracer ====================================================== +// ============================================================================ + +// Assuming python_tracer::PythonMemoryTracerBase is defined elsewhere +class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase { + public: + explicit PythonMemoryTracer() = default; + ~PythonMemoryTracer() override = default; + void start() override; + void stop() override; + void export_memory_history(const std::string path) override; +}; + +static void toggle_memory_tracing(bool enable) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + THPObjectPtr torch_cuda_memory_module( + PyImport_ImportModule("torch.cuda.memory")); + if (!torch_cuda_memory_module) { + return; + } + THPObjectPtr snapshot_func(PyObject_GetAttrString( + torch_cuda_memory_module.get(), "_record_memory_history_impl")); + if (!snapshot_func) { + return; + } + // Call the function with arguments + PyObject* args = PyTuple_New(6); + PyTuple_SetItem(args, 0, enable ? PyUnicode_FromString("all") : Py_None); + PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context + PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks + PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries + PyTuple_SetItem(args, 4, Py_None); // device (None) + PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False) + PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); + Py_DECREF(args); + if (result == nullptr) { + return; + } + PyGILState_Release(gil_state); +} + +void PythonMemoryTracer::start() { + toggle_memory_tracing(true); +} + +void PythonMemoryTracer::export_memory_history(const std::string path) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + THPObjectPtr torch_cuda_memory_module( + PyImport_ImportModule("torch.cuda.memory")); + if (!torch_cuda_memory_module) { + return; + } + THPObjectPtr snapshot_func( + PyObject_GetAttrString(torch_cuda_memory_module.get(), "_dump_snapshot")); + if (!snapshot_func) { + return; + } + PyObject* py_filename = PyUnicode_FromString(path.c_str()); + // Call the function with arguments (e.g., a file path) + PyObject* args = PyTuple_Pack(1, py_filename); + PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); + Py_DECREF(args); + if (result == nullptr) { + return; + } + PyGILState_Release(gil_state); +} + +void PythonMemoryTracer::stop() { + toggle_memory_tracing(false); +} // ============================================================================ // == API ===================================================================== @@ -1139,6 +1254,11 @@ std::unique_ptr getTracer( torch::profiler::impl::RecordQueue* queue) { return std::make_unique(queue); } + +std::unique_ptr getMemoryTracer() { + return std::make_unique(); +} + } // namespace } // namespace torch::profiler::impl @@ -1149,5 +1269,7 @@ void init() { TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0); torch::profiler::impl::python_tracer::registerTracer( &torch::profiler::impl::getTracer); + torch::profiler::impl::python_tracer::registerMemoryTracer( + &torch::profiler::impl::getMemoryTracer); } } // namespace torch::autograd::profiler::python_tracer diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 140ac95a071a..ade983e708c1 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -37,7 +37,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { cudaStream_t stream_{}; }; -#if defined(TORCH_HIP_VERSION) +#if defined(USE_ROCM) using streamType = c10::hip::HIPStream; #else using streamType = c10::cuda::CUDAStream; diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h index c228da413a3e..46bf5ff31987 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h @@ -314,7 +314,7 @@ __device__ __inline__ Vec ld_vec(const T* addr) { template __device__ __inline__ void st_vec(T* addr, const Vec& vec) { -#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) CUDA_KERNEL_ASSERT(false); #else if constexpr (Alignment == 16) { diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu index 721d2c815875..08f61c80b1bb 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -786,7 +786,7 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( std::string group_name_; // Treat empty string and std::nullopt the same as empty string seems to be // implicitly used that way - if (group_name != "") { + if (group_name.has_value() && group_name != "") { group_name_ = *group_name; } else { if (!block->default_group_name.has_value()) { diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 0b2044a870eb..e67dfa2f60f1 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -438,7 +438,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ at::Tensor one_shot_all_reduce_out_impl( const at::Tensor& input, - const c10::optional& local_input, + const std::optional& local_input, std::string reduce_op, std::string group_name, at::Tensor out) { @@ -463,6 +463,10 @@ at::Tensor one_shot_all_reduce_out_impl( local_input->numel() <= input.numel(), "one_shot_all_reduce: local input size must be smaller than symm buffer size."); } + if (input.numel() == 0) { + TORCH_CHECK(input.scalar_type() == out.scalar_type()); + return out; + } auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); TORCH_CHECK( symm_mem != nullptr, @@ -522,7 +526,7 @@ at::Tensor one_shot_all_reduce_out( std::string group_name, at::Tensor out) { return one_shot_all_reduce_out_impl( - input, c10::nullopt, reduce_op, group_name, out); + input, std::nullopt, reduce_op, group_name, out); } at::Tensor one_shot_all_reduce_copy_out( @@ -541,7 +545,7 @@ at::Tensor one_shot_all_reduce( std::string group_name) { auto out = at::empty_like(input); return one_shot_all_reduce_out_impl( - input, c10::nullopt, reduce_op, group_name, out); + input, std::nullopt, reduce_op, group_name, out); } at::Tensor one_shot_all_reduce_copy( @@ -555,9 +559,14 @@ at::Tensor one_shot_all_reduce_copy( } constexpr size_t two_shot_all_reduce_max_num_blocks = 24; -constexpr size_t two_shot_all_reduce_max_num_threads = 512; - -template +constexpr size_t two_shot_all_reduce_max_num_threads = 1024; + +template < + typename T, + int alignment, + int k_world_size, + bool reduce_scatter = false, + bool split_last_dim = false> static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ void two_shot_all_reduce_kernel( T** input_ptrs, @@ -566,31 +575,48 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ size_t numel, uint32_t** signal_pads, size_t rank, - size_t world_size) { + size_t world_size, + size_t last_dim_size = 0) { static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - + int32_t N_last_dim = + last_dim_size / world_size; // used only for split_last_dim reduce_scatter sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = - at::round_up(numel, alignment * world_size) / world_size; - const size_t start = numel_per_rank * rank; + at::round_up(numel, numel_per_thread * world_size) / world_size; + const size_t start = split_last_dim ? last_dim_size / world_size * rank + : numel_per_rank * rank; auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; auto stride = blockDim.x * gridDim.x * numel_per_thread; for (size_t i = offset; i < numel_per_rank; i += stride) { - if (start + i >= numel) { - continue; + if constexpr (!reduce_scatter) { + // we call reduce-scatter only with evenly divisible number of elements + if (start + i >= numel) { + continue; + } + } + size_t idx = i; + if constexpr (split_last_dim) { + idx = i / N_last_dim * last_dim_size + i % N_last_dim; } auto vec = load_and_reduce( - input_ptrs, rank, world_size, input_offset + start + i); - // store to local buffer - st_vec(input_ptrs[rank] + input_offset + start + i, vec); + input_ptrs, rank, world_size, input_offset + start + idx); + // store to local buffer or to output + if constexpr (reduce_scatter) { + st_vec(output_ptr + i, vec); + } else { + st_vec(input_ptrs[rank] + input_offset + start + i, vec); + } } __syncthreads(); sync_remote_blocks(signal_pads, rank, world_size); + if constexpr (reduce_scatter) { + return; + } __syncthreads(); for (size_t i = offset; i < numel_per_rank; i += stride) { Vec tmp[k_world_size]; @@ -611,8 +637,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ if (remote_start + i >= numel) { continue; } - st_vec( - output_ptr + remote_start + i, tmp[step]); + st_vec(output_ptr + remote_start + i, tmp[step]); } } // need to make sure all blocks exit simultaneously so that the data @@ -661,7 +686,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ at::Tensor two_shot_all_reduce_impl( at::Tensor input, - c10::optional output, + std::optional output, std::string reduce_op, std::string group_name) { TORCH_CHECK( @@ -679,11 +704,28 @@ at::Tensor two_shot_all_reduce_impl( get_and_verify_alignment(input, "two_shot_all_reduce"); if (output.has_value()) { + TORCH_CHECK( + output->is_contiguous(), + "two_shot_all_reduce: output must be contiguous."); const size_t output_alignment = get_and_verify_alignment(*output, "two_shot_all_reduce"); TORCH_CHECK( alignment <= output_alignment, "two_shot_all_reduce: output alignment must be equal to or larger than input."); + TORCH_CHECK( + output->sizes() == input.sizes(), + "two_shot_all_reduce: input/output size mismatch, input.sizes(): ", + input.sizes(), + ", output.sizes(): ", + output->sizes()); + if (input.numel() == 0) { + TORCH_CHECK(output->scalar_type() == input.scalar_type()); + return *output; + } + } else { + if (input.numel() == 0) { + return input; + } } int num_blocks = 0, num_threads = 0; @@ -754,7 +796,7 @@ at::Tensor two_shot_all_reduce_( at::Tensor input, std::string reduce_op, std::string group_name) { - return two_shot_all_reduce_impl(input, c10::nullopt, reduce_op, group_name); + return two_shot_all_reduce_impl(input, std::nullopt, reduce_op, group_name); } at::Tensor two_shot_all_reduce_out( @@ -764,6 +806,146 @@ at::Tensor two_shot_all_reduce_out( at::Tensor output) { return two_shot_all_reduce_impl(input, output, reduce_op, group_name); } + +at::Tensor reduce_scatter_out( + at::Tensor input, + std::string group_name, + bool split_last_dim, + at::Tensor output) { + TORCH_CHECK( + input.is_contiguous(), "reduce_scatter: input must be contiguous."); + TORCH_CHECK( + output.is_contiguous(), "reduce_scatter: output must be contiguous."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); + TORCH_CHECK( + symm_mem != nullptr, + "reduce_scatter: input must be allocated with empty_strided_p2p()."); + + const size_t alignment = get_and_verify_alignment(input, "reduce_scatter"); + + const size_t output_alignment = + get_and_verify_alignment(input, "reduce_scatter"); + + TORCH_CHECK( + input.numel() % + (symm_mem->get_world_size() * + (alignment / input.element_size())) == + 0, + "expected number of elements to be divisible by world_size * alignment, number of elements ", + input.numel(), + " world size ", + symm_mem->get_world_size(), + "alignment ", + alignment); + + if (split_last_dim) { + TORCH_CHECK(input.dim() == output.dim()); + bool are_equal_except_last = std::equal( + input.sizes().begin(), input.sizes().end() - 1, output.sizes().begin()); + TORCH_CHECK( + are_equal_except_last, + "reduce_scatter expected input and output to have same sizes except in the last dimension"); + TORCH_CHECK( + output.size(-1) == input.size(-1) / symm_mem->get_world_size(), + "reduce_scatter expected output last dim size to be input last dim size / world_size"); + + TORCH_CHECK( + input.size(-1) % + (symm_mem->get_world_size() * + (alignment / input.element_size())) == + 0, + "expected last dimension to be divisible by world_size * alignment, last dimension ", + input.size(-1), + " world size ", + symm_mem->get_world_size(), + "alignment ", + alignment); + } else { + TORCH_CHECK(input.dim() == 1, "reduce_scatter expected 1D input"); + TORCH_CHECK(output.dim() == 1, "reduce_scatter expected 1D output"); + TORCH_CHECK(output.numel() == input.numel() / symm_mem->get_world_size()); + } + if (input.numel() == 0) { + TORCH_CHECK(input.scalar_type() == output.scalar_type()); + return output; + } + + TORCH_CHECK( + output_alignment >= alignment, + "reduce_scatter: output alignment should be not smaller than input alignment"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + symm_mem->get_world_size(), + two_shot_all_reduce_max_num_blocks, + two_shot_all_reduce_max_num_threads, + num_blocks, + num_threads); + if (split_last_dim) { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel< + scalar_t, + k_alignment, + k_world_size, + true, + true> + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + output.data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size(), + input.size(-1)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + } else { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel< + scalar_t, + k_alignment, + k_world_size, + true, + false> + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + output.data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size(), + input.size(-1)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + } + return output; +} } // namespace #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 @@ -899,6 +1081,7 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); + m.impl("reduce_scatter_out", ::reduce_scatter_out); m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm); #endif diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index e15f153d70c7..e134e39cab78 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -24,7 +24,7 @@ namespace c10d { // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). -DEFINE_CONSTANT(version_val, "2.4") +DEFINE_CONSTANT(version_val, "2.5") DEFINE_CONSTANT(entries_key, "entries") DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state") DEFINE_CONSTANT(version_key, "version") diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index af09ba39470c..32c4c4f88ac0 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -39,12 +39,14 @@ C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( GlooDeviceRegistry, ::gloo::transport::Device, const std::string& /* interface */, - const std::string& /* hostname */) + const std::string& /* hostname */, + bool /* lazyInit */) #if GLOO_HAVE_TRANSPORT_TCP static std::shared_ptr<::gloo::transport::Device> makeTCPDevice( const std::string& interfaceName, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeTCPDevice(): interface or hostname " @@ -56,7 +58,11 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPDevice( } else { attr.hostname = hostname; } - return ::gloo::transport::tcp::CreateDevice(attr); + if (lazyInit) { + return ::gloo::transport::tcp::CreateLazyDevice(attr); + } else { + return ::gloo::transport::tcp::CreateDevice(attr); + } } // Registry priority is per key identifier. We register TCP to `LINUX` for @@ -69,12 +75,15 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice) #if GLOO_HAVE_TRANSPORT_TCP_TLS static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( const std::string& interface, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interface.empty() || !hostname.empty(), "GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname " "can't be empty"); + TORCH_CHECK(!lazyInit, "TCP_TLS transport does not support lazy init"); + ::gloo::transport::tcp::attr attr; if (!interface.empty()) { attr.iface = interface; @@ -105,12 +114,15 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice) #if GLOO_HAVE_TRANSPORT_UV static std::shared_ptr<::gloo::transport::Device> makeUVDevice( const std::string& interfaceName, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeUVDevice(): interface or hostname " "can't be empty"); + TORCH_CHECK(!lazyInit, "UV transport does not support lazy init"); + ::gloo::transport::uv::attr attr; if (!interfaceName.empty()) { attr.iface = interfaceName; @@ -131,23 +143,27 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice) namespace { std::shared_ptr<::gloo::transport::Device> makeGlooDevice( const std::string& interfaceName, - const std::string& hostName) { + const std::string& hostName, + bool lazyInit) { static auto transportName = c10::utils::get_env("GLOO_DEVICE_TRANSPORT"); if (transportName.has_value()) { return GlooDeviceRegistry()->Create( - transportName.value().c_str(), interfaceName, hostName); + transportName.value().c_str(), interfaceName, hostName, lazyInit); } #ifdef __linux__ - return GlooDeviceRegistry()->Create("LINUX", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "LINUX", interfaceName, hostName, lazyInit); #endif #ifdef __APPLE__ - return GlooDeviceRegistry()->Create("APPLE", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "APPLE", interfaceName, hostName, lazyInit); #endif #ifdef _WIN32 - return GlooDeviceRegistry()->Create("WIN32", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "WIN32", interfaceName, hostName, lazyInit); #endif return nullptr; @@ -155,8 +171,8 @@ std::shared_ptr<::gloo::transport::Device> makeGlooDevice( } // anonymous namespace std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: - makeDeviceForInterface(const std::string& interfaceName) { - auto device = makeGlooDevice(interfaceName, ""); + makeDeviceForInterface(const std::string& interfaceName, bool lazyInit) { + auto device = makeGlooDevice(interfaceName, "", lazyInit); if (!device) { TORCH_CHECK(false, "makeDeviceForInterface(): unsupported gloo device"); } @@ -164,8 +180,8 @@ std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: } std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: - makeDeviceForHostname(const std::string& hostname) { - auto device = makeGlooDevice("", hostname); + makeDeviceForHostname(const std::string& hostname, bool lazyInit) { + auto device = makeGlooDevice("", hostname, lazyInit); if (!device) { TORCH_CHECK(false, "makeDeviceForHostname(): unsupported gloo device"); } diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp index 1221e9d033f2..a7220f0d81c7 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp @@ -14,18 +14,21 @@ class TORCH_API GlooDeviceFactory { public: // Create new device instance for specific interface. static std::shared_ptr<::gloo::transport::Device> makeDeviceForInterface( - const std::string& interface); + const std::string& interface, + bool lazyInit); // Create new device instance for specific hostname or address. static std::shared_ptr<::gloo::transport::Device> makeDeviceForHostname( - const std::string& hostname); + const std::string& hostname, + bool lazyInit); }; TORCH_DECLARE_SHARED_REGISTRY( GlooDeviceRegistry, ::gloo::transport::Device, const std::string&, /* interface */ - const std::string& /* hostname */); + const std::string&, /* hostname */ + bool /* lazyInit */); } // namespace c10d diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index dff8a5f78775..faec5bc449ac 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -92,7 +92,9 @@ std::shared_ptr NCCLComm::create_scalable( int numRanks, int rank, std::vector& commIds, + at::DeviceIndex deviceIndex, ncclConfig_t& config) { + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex); auto comm = std::make_shared(); comm->nonBlocking_ = config.blocking == 0; LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: " @@ -112,6 +114,7 @@ std::shared_ptr NCCLComm::create_scalable( // in the log file and in the replay tool. comm->ncclId_ = commIds[0]; comm->rank_ = rank; + comm->deviceIndex_ = deviceIndex; comm->initialized_ = !comm->nonBlocking_; return comm; } @@ -150,6 +153,10 @@ ncclComm_t NCCLComm::getNcclComm() { return ncclComm_; } +at::DeviceIndex NCCLComm::getDeviceIndex() { + return deviceIndex_; +} + // Wait for the communicator to be ready. This is a blocking function. // Arguments: // longInterval: if true, wait with sleep of an interval; otherwise, wait diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index c7cd0a30924e..89bf15fc6479 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -221,6 +221,7 @@ class NCCLComm { int numRanks, int rank, std::vector& commIds, + at::DeviceIndex deviceIndex, ncclConfig_t& config); #endif // NCCL_HAS_INIT_RANK_SCALABLE #endif // NCCL_HAS_CONFIG @@ -239,6 +240,7 @@ class NCCLComm { #endif ncclUniqueId getNcclId(); + at::DeviceIndex getDeviceIndex(); // Must not be copyable NCCLComm(const NCCLComm&) = delete; diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 6251bfa1817d..0480f1b9191d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) { .def("wait", [](const c10::intrusive_ptr& self) { self->wait(); }); m.class_("ReduceOp").def(torch::init<>()); m.def( - "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work"); + "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); + "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); m.def( - "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); + "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); + "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work"); + "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work"); + "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work"); + "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work"); + "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work"); + "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()"); m.def( @@ -118,6 +118,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) const c10::intrusive_ptr& reduce_op, \ int64_t root_rank, \ int64_t root_tensor, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ @@ -127,7 +128,8 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) *reduce_op.get(), \ root_rank, \ root_tensor, \ - std::chrono::milliseconds(timeout)}); \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ } IMPL_REDUCE(CPU) @@ -169,12 +171,13 @@ IMPL_BROADCAST(PrivateUse1) const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ const std::optional& sparse_indices, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \ tensor_vec, \ AllreduceOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(tensor_vec), work); \ } @@ -188,11 +191,13 @@ IMPL_ALLREDUCE(PrivateUse1) at::TensorList tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \ opts.reduceOp = *reduce_op.get(); \ opts.timeout = std::chrono::milliseconds(timeout); \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allreduce_coalesced(tensor_vec, opts); \ } @@ -209,12 +214,13 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) const std::vector>& output_tensors, \ at::TensorList input_tensors, \ const c10::intrusive_ptr& process_group, \ + bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - AllgatherOptions{std::chrono::milliseconds(timeout)}); \ + AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \ return std:: \ tuple>, c10::intrusive_ptr>( \ output_tensors, work); \ @@ -249,12 +255,16 @@ IMPL__ALLGATHER_BASE(PrivateUse1) c10::intrusive_ptr allgather_coalesced_##DEV( \ const std::vector>& output_lists, \ const at::TensorList& input_list, \ - const c10::intrusive_ptr& process_group) { \ + const c10::intrusive_ptr& process_group, \ + bool asyncOp) { \ auto input_list_vec = input_list.vec(); \ + auto opts = AllgatherOptions{}; \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allgather_coalesced( \ const_cast>&>(output_lists), \ - input_list_vec); \ + input_list_vec, \ + opts); \ } IMPL_ALLGATHER_COALESCED(CPU) @@ -265,11 +275,14 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) c10::intrusive_ptr allgather_into_tensor_coalesced_##DEV( \ at::TensorList outputs, \ at::TensorList inputs, \ - const c10::intrusive_ptr& process_group) { \ + const c10::intrusive_ptr& process_group, \ + bool asyncOp) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ + auto opts = AllgatherOptions{}; \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ - ->allgather_into_tensor_coalesced(output_vec, input_vec); \ + ->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \ } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) @@ -283,6 +296,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) const std::vector>& input_tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto work = \ @@ -290,7 +304,9 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) output_tensors_vec, \ const_cast>&>(input_tensors), \ ReduceScatterOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ output_tensors_vec, work); \ } @@ -329,6 +345,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) at::TensorList inputs, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ @@ -337,7 +354,9 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) output_vec, \ input_vec, \ ReduceScatterOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) @@ -350,13 +369,15 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ int64_t root_rank, \ + bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ ->gather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \ + GatherOptions{ \ + root_rank, std::chrono::milliseconds(timeout), asyncOp}); \ } IMPL_GATHER(CPU) @@ -391,13 +412,14 @@ IMPL_SCATTER(PrivateUse1) const at::TensorList& output_tensors, \ const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ + bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \ output_tensors_vec, \ input_tensors_vec, \ - AllToAllOptions{std::chrono::milliseconds(timeout)}); \ + AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(output_tensors_vec), work); \ } @@ -406,21 +428,22 @@ IMPL_ALLTOALL(CPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) -#define IMPL_ALLTOALL_BASE(DEV) \ - c10::intrusive_ptr alltoall_base_##DEV( \ - at::Tensor& output, \ - at::Tensor& input, \ - const c10::intrusive_ptr& process_group, \ - std::vector output_split_sizes, \ - std::vector input_split_sizes, \ - int64_t timeout) { \ - return process_group->getBackend(c10::DeviceType::DEV) \ - ->alltoall_base( \ - output, \ - input, \ - output_split_sizes, \ - input_split_sizes, \ - AllToAllOptions{std::chrono::milliseconds(timeout)}); \ +#define IMPL_ALLTOALL_BASE(DEV) \ + c10::intrusive_ptr alltoall_base_##DEV( \ + at::Tensor& output, \ + at::Tensor& input, \ + const c10::intrusive_ptr& process_group, \ + std::vector output_split_sizes, \ + std::vector input_split_sizes, \ + bool asyncOp, \ + int64_t timeout) { \ + return process_group->getBackend(c10::DeviceType::DEV) \ + ->alltoall_base( \ + output, \ + input, \ + output_split_sizes, \ + input_split_sizes, \ + AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ } IMPL_ALLTOALL_BASE(CPU) @@ -428,15 +451,18 @@ IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) // NOLINTBEGIN(performance-unnecessary-value-param) -#define IMPL_BARRIER(DEV) \ - c10::intrusive_ptr barrier##DEV( \ - at::Tensor /* unused */, \ - const c10::intrusive_ptr& process_group, \ - const std::vector& device_ids, \ - int64_t timeout) { \ - return process_group->getBackend(c10::DeviceType::DEV) \ - ->barrier( \ - BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \ +#define IMPL_BARRIER(DEV) \ + c10::intrusive_ptr barrier##DEV( \ + at::Tensor /* unused */, \ + const c10::intrusive_ptr& process_group, \ + const std::vector& device_ids, \ + bool asyncOp, \ + int64_t timeout) { \ + auto opts = BarrierOptions{}; \ + opts.device_ids = device_ids; \ + opts.timeout = std::chrono::milliseconds(timeout); \ + opts.asyncOp = asyncOp; \ + return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \ } IMPL_BARRIER(CPU) @@ -464,6 +490,7 @@ allreduce_sparse_cuda_( const c10::intrusive_ptr& process_group, const c10::intrusive_ptr& reduce_op, const std::optional& sparse_indices, + bool asyncOp, int64_t timeout) { auto tensor_vec = tensors.vec(); auto work = process_group->getBackend(c10::DeviceType::CUDA) @@ -472,6 +499,7 @@ allreduce_sparse_cuda_( AllreduceOptions{ *reduce_op, std::chrono::milliseconds(timeout), + asyncOp, sparse_indices}); return std::tuple, c10::intrusive_ptr>( diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index b3f3d9bdd72d..4ce67c9f5798 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -224,6 +224,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, const std::optional& sparse_indices, + bool, int64_t)>(); auto work = std::get<1>(op.call( @@ -231,6 +232,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.sparseIndices, + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -277,6 +281,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t, int64_t, + bool, int64_t)>(); auto work = op.call( tensors, @@ -284,6 +289,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::make_intrusive(opts.reduceOp), opts.rootRank, opts.rootTensor, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_coalesced_", "") - .typed( - const std::vector>&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_coalesced_", "") + .typed( + const std::vector>&, + const at::TensorList&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool)>(); auto work = op.call( outputTensorLists, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor_list : outputTensorLists) { @@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { .typed( const at::TensorList, const at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool)>(); auto work = op.call( outputTensors, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor : outputTensors) { @@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, + bool, int64_t)>(); auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -546,6 +561,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = op.call( @@ -553,6 +569,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -577,6 +594,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, std::vector, std::vector, + bool, int64_t)>(); auto work = op.call( outputBuffer, @@ -584,6 +602,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), outputSplitSizes, inputSplitSizes, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::Tensor, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const std::vector&, + bool, int64_t)>(); auto work = op.call( tensor, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.device_ids, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { c10d::register_work(tensor, work); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 345b2741dc97..077bf311284f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -415,6 +415,10 @@ const auto kLoopbackAddress = "127.0.0.1"; } // namespace +bool getDefaultGlooLazyInit() { + return ::c10d::getCvarBool(TORCH_GLOO_LAZY_INIT, false); +} + // static void ProcessGroupGloo::AsyncWork::execute( const c10::intrusive_ptr& work) { @@ -687,23 +691,24 @@ bool doesHostnameResolveToUsableAddress(const std::string& hostname) { } // namespace std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDeviceForInterface(const std::string& interface_name) { - return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name); + createDeviceForInterface(const std::string& interface_name, bool lazyInit) { + return ::c10d::GlooDeviceFactory::makeDeviceForInterface( + interface_name, lazyInit); } std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDeviceForHostname(const std::string& hostname) { + createDeviceForHostname(const std::string& hostname, bool lazyInit) { TORCH_CHECK( doesHostnameResolveToUsableAddress(hostname), "Cannot resolve ", hostname, " to a (local) address"); - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname, lazyInit); } #if defined(__linux__) || defined(_WIN32) std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDefaultDevice() { + createDefaultDevice(bool lazyInit) { // Use the hostname to resolve the network address to // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. @@ -716,7 +721,8 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: // Use this machine's hostname if it resolves to an address. if (doesHostnameResolveToUsableAddress(hostname.data())) { - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data()); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname( + hostname.data(), lazyInit); } // Otherwise, use the loopback address. @@ -724,13 +730,13 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: "Unable to resolve hostname to a (local) address. ", "Using the loopback address as fallback. ", "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); - return createDeviceForHostname(kLoopbackAddress); + return createDeviceForHostname(kLoopbackAddress, lazyInit); } #endif #ifdef __APPLE__ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDefaultDevice() { + createDefaultDevice(bool lazyInit) { // Use the hostname to resolve the network address to // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. @@ -743,7 +749,8 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: // Use this machine's hostname if it resolves to an address. if (doesHostnameResolveToUsableAddress(hostname.get())) { - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get()); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname( + hostname.get(), lazyInit); } // Otherwise, use the loopback address. @@ -751,7 +758,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: "Unable to resolve hostname to a (local) address. ", "Using the loopback address as fallback. ", "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); - return createDeviceForHostname(kLoopbackAddress); + return createDeviceForHostname(kLoopbackAddress, lazyInit); } #endif @@ -785,10 +792,25 @@ ProcessGroupGloo::ProcessGroupGloo( contexts_.reserve(options_->devices.size()); for (const auto i : c10::irange(options_->devices.size())) { auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); - auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); + +#ifdef GLOO_SHARED_STORE + auto underlyingStore = store_; +#else + auto& underlyingStore = *store_; +#endif + + auto store = std::make_shared<::gloo::rendezvous::PrefixStore>( + std::to_string(i), underlyingStore); + +#ifdef GLOO_SHARED_STORE + auto connectStore = store; +#else + auto& connectStore = *store; +#endif + context->setTimeout(options_->timeout); try { - context->connectFullMesh(store, options_->devices[i]); + context->connectFullMesh(connectStore, options_->devices[i]); } catch (const std::runtime_error& e) { auto err = e.what(); // TORCH_CHECK to print the cpp stacktrace. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index b44cba9f35a4..917544d9e113 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -28,6 +28,13 @@ namespace c10d { constexpr const char* GLOO_BACKEND_NAME = "gloo"; +// Control whether or not connections are established in a full mesh or lazily +// as needed. +static std::vector TORCH_GLOO_LAZY_INIT = {"TORCH_GLOO_LAZY_INIT"}; + +// Returns default value for lazyInit. +bool TORCH_API getDefaultGlooLazyInit(); + // ProcessGroupGloo implements Gloo bindings for c10d. // // All functions on this class are expected to be called in the same @@ -244,24 +251,20 @@ class TORCH_API ProcessGroupGloo : public Backend { // Create new device instance for specific interface. static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface( - const std::string& interface); + const std::string& interface, + bool lazyInit = false); // Create new device instance for specific hostname or address. static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname( - const std::string& hostname); + const std::string& hostname, + bool lazyInit = false); // Create new device instance. // It tries to resolve this machine's hostname and bind to that address. // If that fails (i.e. the hostname doesn't resolve to an address), it // falls back to binding to the loopback address. - static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); - - // Create ProcessGroupGloo instance. - static c10::intrusive_ptr createProcessGroupGloo( - const c10::intrusive_ptr& store, - int rank, - int size, - std::chrono::milliseconds timeout); + static std::shared_ptr<::gloo::transport::Device> createDefaultDevice( + bool lazyInit = false); explicit ProcessGroupGloo( const c10::intrusive_ptr& store, @@ -367,7 +370,7 @@ class TORCH_API ProcessGroupGloo : public Backend { void enableCollectivesTiming() override; - const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const { + const std::shared_ptr<::gloo::rendezvous::Store>& _getStore() const { return store_; } @@ -393,7 +396,7 @@ class TORCH_API ProcessGroupGloo : public Backend { } protected: - std::unique_ptr<::gloo::rendezvous::Store> store_; + std::shared_ptr<::gloo::rendezvous::Store> store_; const c10::intrusive_ptr options_; // Every Gloo context represents a set of connections to its peers. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 863bc1c4491c..ecfb2b5d10d4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -200,20 +200,6 @@ inline std::string getKeyFromDevice(at::Device& device) { return std::to_string(device.index()); } -inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { - // initialize the device index to -1, which is an invalid value. - int index = -1; - try { - index = std::stoi(deviceKey); - } catch (const std::invalid_argument& e) { - LOG(ERROR) << c10::str( - "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); - } catch (const std::out_of_range& e) { - LOG(ERROR) << "Out of range: " << e.what(); - } - return static_cast(index); -} - std::string getKeySendRecv(int myRank, int peer) { int lowRank = myRank < peer ? myRank : peer; int highRank = myRank < peer ? peer : myRank; @@ -289,6 +275,28 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { } } +// When TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK is set, all tensors (no +// matter how they have been allocated) are registered with all NCCL comms. +bool shouldAllCommunicatorsRegisterAllTensors() { +#ifdef NCCL_HAS_COMM_REGISTER + static const bool flag = [] { + const bool flag = + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + if (flag && + c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + expandable_segments()) { + LOG(INFO) + << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + return false; + } + return flag; + }(); + return flag; +#else + return false; +#endif // NCCL_HAS_COMM_REGISTER +} + } // namespace // Map from each communicator to its device index. @@ -303,7 +311,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; static std::mutex ncclCommDevIdxMapMutex; -static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -316,12 +323,12 @@ static void cacheAllocatorRegisterHook( } std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + if (te.device_ == ncclComm->getDeviceIndex()) { + if (shouldAllCommunicatorsRegisterAllTensors()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + } } } } @@ -335,16 +342,30 @@ static void cacheAllocatorDeregisterHook( } std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + if (te.device_ == ncclComm->getDeviceIndex()) { + if (shouldAllCommunicatorsRegisterAllTensors()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + } } } } +static void attachAllocatorHooks() { + static c10::once_flag flag; + c10::call_once(flag, [] { + // Attaching hooks fails if CUDACachingAllocator is not initialized, so + // Init for CUDA is called (and is a no-op if CUDA is already + // initialized). + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorRegisterHook); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorDeregisterHook); + }); +} + static std:: unordered_map> getNCCLCommDumpMap() { @@ -359,11 +380,12 @@ static std:: std::vector> allNCCLComms; // within the critical section, we don't want to dump while holding the lock // as dump might hang - ncclCommDevIdxMapMutex.lock(); - for (auto& [ncclComm, _] : ncclCommDevIdxMap) { - allNCCLComms.push_back(ncclComm); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + allNCCLComms.push_back(ncclComm); + } } - ncclCommDevIdxMapMutex.unlock(); for (auto& ncclComm : allNCCLComms) { std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); @@ -454,6 +476,36 @@ std::ostream& operator<<( return output << workInfo; } +/* Implementation of TensorShelf class */ + +void TensorShelf::stash(std::vector& tensors) { + std::lock_guard lock(mutex_); + tVector_.insert(tVector_.end(), tensors.begin(), tensors.end()); +} + +void TensorShelf::stash(TensorShelf& other) { + std::vector& otherVec = other.get(); + this->stash(otherVec); +} + +void TensorShelf::unstash() { + this->clear(); +} + +bool TensorShelf::empty() { + std::lock_guard lock(mutex_); + return tVector_.empty(); +} + +void TensorShelf::clear() { + std::lock_guard lock(mutex_); + tVector_.clear(); +} + +std::vector& TensorShelf::get() { + return tVector_; +} + ProcessGroupNCCL::WorkNCCL::WorkNCCL( std::string pgUID, std::string pgDesc, @@ -496,6 +548,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( } futureWorkResult_ = c10::make_intrusive(c10::AnyEnumType::get()); + // other functions expect an initialized ptr + stashed_for_allocator_safety_ = std::make_shared(); } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) @@ -517,6 +571,11 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) numelIn_(w.numelIn_), numelOut_(w.numelOut_), store_(w.store_), + // Note: the `work` returned to user and the `work` enqueued to watchdog + // share the pointer to the tensor stash. At least one of them should + // clean the tensor stash, the earlier the better, i.e. user calling + // `work.wait` than watchdog detecting work completion. + stashed_for_allocator_safety_(w.stashed_for_allocator_safety_), futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), @@ -714,10 +773,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); // Block the current stream on the NCCL stream ncclEndEvent_->block(currentStream); - - if (avoidRecordStreams_) { - stashed_for_allocator_safety_->clear(); - } + // Unstage the stashed tensors so that CachingAllocator can recycle them + // THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE + stashed_for_allocator_safety_->unstash(); } // Same as calling synchronize() when blockingWait_ is false @@ -781,7 +839,7 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { // upgrade. Once a NCCL version is qualified, this code should not be needed // at runtime. #ifdef PGNCCL_ENABLE_HASH - if (distDebugLevel_ >= DebugLevel::Detail) { + if (enableCollectiveHashDebug_.load()) { auto numel = getTensorsNumel(*outputs_); auto hashValue = hashTensors(*outputs_); PRINT_COLLECTIVE_HASH_SIGNATURE( @@ -802,9 +860,10 @@ void ProcessGroupNCCL::WorkNCCL::abort() { // Abort all communicators of this work ncclComm_->abort(); - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm_); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.erase(ncclComm_); + } } ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; @@ -921,7 +980,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000); - enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); + enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); // store_ usually is wrapped with PrefixStore and the prefix is different // across different ProcessGroupNCCL(PG) instances. We need to get the // underlying non-PrefixStore for sharing global information shared across @@ -933,18 +992,10 @@ ProcessGroupNCCL::ProcessGroupNCCL( enableTiming_.store( getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); #endif // ENABLE_NCCL_ERROR_CHECKING - avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); -#ifdef NCCL_HAS_COMM_REGISTER - useTensorRegisterAllocatorHook_ = - getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - expandable_segments()) { - useTensorRegisterAllocatorHook_ = false; - LOG(INFO) - << logPrefix() - << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); } -#endif // NCCL_HAS_COMM_REGISTER if (blockingWait_) { LOG(INFO) @@ -997,7 +1048,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug #ifdef NCCL_HAS_COMM_REGISTER << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " - << useTensorRegisterAllocatorHook_ + << shouldAllCommunicatorsRegisterAllTensors() #endif // NCCL_HAS_COMM_REGISTER << ", TORCH_NCCL_ENABLE_MONITORING: " << monitorThreadEnabled_.load() @@ -1018,22 +1069,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( // action is called. In the following hooks, we register a newly allocated // segment when SEGMENT_ALLOC action occurs, and deregister a segment when // SEGMENT_FREE action occurs. - // We attach hooks only once at the first PG creation. - // Attaching hooks fails if CUDACachingAllocator is not initialized, so - // Init for CUDA is called (and is a no-op if CUDA is already - // initialized). - if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { - at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorRegisterHook); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorDeregisterHook); - allocatorHooksAttached = true; + if (shouldAllCommunicatorsRegisterAllTensors()) { + // This call is idempotent. + attachAllocatorHooks(); } // Enable Desync Debugger per user setting if (desyncDebug_) { - desyncDebugger_.init(rank, size, store_); + desyncDebugger_.init(rank, size, globalRank(), getUid(), store_); } } @@ -1295,7 +1338,7 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( e.what()); debugLog.strings["status"] = "EXCEPTION"; - debugLog.strings["exception"] = e.what(); + debugLog.strings["exception_msg"] = e.what(); LOG(ERROR) << errorMsg; } catch (...) { errorMsg = c10::str( @@ -1303,7 +1346,7 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( "Unknown exception thrown when waiting for future ", futDescription); debugLog.strings["status"] = "EXCEPTION"; - debugLog.strings["exception"] = "Unknown exception"; + debugLog.strings["exception_msg"] = "Unknown exception"; LOG(ERROR) << errorMsg; } } else { @@ -1361,12 +1404,12 @@ bool ProcessGroupNCCL::abortComms( // communicators. Note that ncclCommDevIdxMap is a global container which may // contain other PG's communicators, thus we need to only erase communicators // for the current PG. - ncclCommDevIdxMapMutex.lock(); - for (auto& it : devNCCLCommMap_) { - auto& ncclComm = it.second; - ncclCommDevIdxMap.erase(ncclComm); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& [_, ncclComm] : devNCCLCommMap_) { + ncclCommDevIdxMap.erase(ncclComm); + } } - ncclCommDevIdxMapMutex.unlock(); std::lock_guard lock(mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); @@ -1790,8 +1833,6 @@ void ProcessGroupNCCL::heartbeatMonitor() { if (logger) { logger->log(debugLog); } - // Indicate to watchdog thread that we have finished dumping. - promiseFlightRecorderDump_.set_value(); } // GIL deadlock check. @@ -1918,9 +1959,13 @@ void ProcessGroupNCCL::ncclCommWatchdog() { void ProcessGroupNCCL::DesyncDebugger::init( int rank, int size, + int globalRank, + int pgId, c10::intrusive_ptr store) { rank_ = rank; size_ = size; + globalRank_ = globalRank; + pgId_ = pgId; store_ = std::move(store); enabled_ = true; traceKeyStart_ = getTraceStartKey("NCCL", rank); @@ -1932,21 +1977,38 @@ void ProcessGroupNCCL::DesyncDebugger::run() { if (!enabled_) return; auto logPrefix = c10::str("Rank ", rank_); + ::c10d::C10dLoggingData log; + log.integers["pg_id"] = pgId_; + log.integers["rank"] = rank_; + log.integers["global_rank"] = globalRank_; + log.integers["world_size"] = size_; + // Use this to differentiate between flight recorder and desync debug report. + log.strings["flight_recorder_version"] = "-1"; + try { std::string desyncMsg = retrieveDesyncReport(store_, "NCCL", rank_, size_); + log.strings["status"] = "SUCCESS"; LOG(ERROR) << logPrefix << desyncMsg; } catch (const std::exception& e) { + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = e.what(); enabled_ = false; LOG(ERROR) << logPrefix << " Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " << " Please file an issue. Error: " << e.what(); } catch (...) { enabled_ = false; + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = "Unknown exception"; LOG(ERROR) << logPrefix << " Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." << " Please file an issue."; } + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(log); + } } // Log work start to store. @@ -2095,27 +2157,10 @@ void ProcessGroupNCCL::broadcastDumpSignal() { // broadcast dump signal to all other global ranks. broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank()); // signal the local rank to start dumping - if (shouldDump_.load()) { - // already signaled dump, skipping signal again and wait for the dump - // future. - return; - } - LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; - // signal the monitor thread on PG0 to start dumping - shouldDump_.store(true); - // Give time for dumping before throwing exception - auto start = std::chrono::steady_clock::now(); - auto status = promiseFlightRecorderDump_.get_future().wait_for( - std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); - if (status == std::future_status::timeout) { - LOG(WARNING) << logPrefix() << "timed out after waiting for " - << waitTimeoutDumpInMilSec_ << "ms" - << " flight recorder dumps to finish."; - } else if (status == std::future_status::ready) { - auto end = std::chrono::steady_clock::now(); - LOG(INFO) << logPrefix() << "slept for " << computeDeltaMS(start, end) - << "ms" - << " giving time for flight recorder dumps to finish."; + if (!shouldDump_.load()) { + LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); } } @@ -2290,6 +2335,13 @@ void ProcessGroupNCCL::watchdogHandler() { // recorder behavior is independent of desync Debug. if (dumpOnTimeoutOrEx_) { broadcastDumpSignal(); + // Give time for dumping before throwing exception for all ranks. + // It is hard to presume or control what the pattern of watchdog might + // look like, so it is better to let all ranks universally sleep for a + // short period of time, in this case, 60 seconds, which is also the + // maximum time we leave for FR dump. + std::this_thread::sleep_for( + std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); } if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { @@ -2323,6 +2375,23 @@ void ProcessGroupNCCL::watchdogHandler() { // Clean up completed work if (work.isCompleted()) { + // In case user didn't call `work.wait()` with async collectives, + // watchdog would unstage the stashed tensors when detecting completion + // of the collective, to prevent ProcessGroupNCCL from holding reference + // to those tensors forever. + // work.stashed_for_allocator_safety_->unstash(); + // Update: it seems directly unstashing from watchdog thread would cause + // some rare problems. We thus move the unstashing to main thread, + // triggered by a next user call, see `workEnqueue`. But `work` is going + // to be destructed, so we transfer the work's shelf to a shelves + // structure owned by the PG. + if (!work.stashed_for_allocator_safety_->empty()) { + std::lock_guard lock(shelvesMutex_); + // We are just pushing back a shared_ptr here, so the cost should be + // minimal + shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); + } + // Work status logging for desync debug desyncDebugger_.logWorkEnd(work); @@ -2650,9 +2719,10 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { // Clear used device indices. usedDeviceIdxs_.clear(); - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.erase(ncclComm); + } } std::shared_ptr ProcessGroupNCCL::initNCCLComm( @@ -2819,8 +2889,8 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( << "ProcessGroupNCCL all-gather unique IDs through store took " << timerDeltaMs << " ms"; #if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG) - ncclComm = - NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config); + ncclComm = NCCLComm::create_scalable( + numRanks, rank, ncclIDs, deviceIndex, options_->config); #else C10_THROW_ERROR( DistBackendError, @@ -2912,7 +2982,7 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( // Now ncclComms are fully initialized. // Register all active CUDA memory segments in cache allocator to // the new NCCL communicators - if (useTensorRegisterAllocatorHook_) { + if (shouldAllCommunicatorsRegisterAllTensors()) { auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); // Register the segment to a new NCCL communicator if on the same device for (const auto& segmentInfo : snapshot.segments) { @@ -2930,9 +3000,10 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( // on the same device. // NOTE: we need remove the communicator from this map when it is // destroyed, otherwise may register onto an invalid communicator. - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.emplace(ncclComm, device.index()); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.emplace(ncclComm, device.index()); + } } it = devNCCLCommMap_.find(deviceKey); @@ -3057,6 +3128,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( enableTiming_.load(), cudaEventCacheEnabled_.load(), dist_debug_level_); + if (record) { bool isP2P = isP2POp(opType); // Ideally record every work that we enqueue, rather than every work we @@ -3136,6 +3208,17 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( const c10::intrusive_ptr& work) { + // We clean up the TensorShelf's in case user hasn't called `work.wait()`. + // This has nothing to do with new work enqueue. We are just using a place + // that would be triggered by a next user call. + { + std::lock_guard lock(shelvesMutex_); + for (auto& shelf : shelvesToUnstash_) { + shelf->unstash(); + } + shelvesToUnstash_.clear(); + } + // in blockingWait_ mode, we don't need watchdog thread, so no need to enqueue // the work if (!terminateProcessGroup_.load() && !blockingWait_) { @@ -3172,6 +3255,7 @@ void ProcessGroupNCCL::startCoalescing() { coalescedDevice_.set_index(-1); coalescedComm_ = nullptr; + coalescedTensors_.clear(); coalescing_state_ |= CoalActive; groupStart(); } @@ -3197,6 +3281,9 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { // `getKeyFromDevice` is how we get keys for both collectives and batch P2P const auto key = getKeyFromDevice(device); auto ncclStream = ncclStreams_.at(key); + auto opProfilerTitle = optype != OpType::COALESCED + ? "nccl:" + opTypeToString(optype) + "_coalesced" + : "nccl:coalesced"; // Create Work object c10::cuda::CaptureStatus capture_status = @@ -3208,16 +3295,18 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { rank_, optype, coalescing_state_ & CoalP2P, - "nccl:coalesced", + opProfilerTitle.c_str(), {}, {}, enqueue); work->ncclComm_ = comm; work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams_; work->store_ = store_; assignTimeoutToWork(work, options_); + // Hand over references to tensors during coalescing to work's stash + work->stashed_for_allocator_safety_->stash(coalescedTensors_); + // Record start before ncclGroupEnd if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -3233,19 +3322,17 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? work->ncclEndEvent_->record(ncclStream); - if (avoidRecordStreams_) { - // other functions expect an initialized ptr if avoidRecordStreams_ is set - work->stashed_for_allocator_safety_ = - std::make_shared>(); - } - if (enqueue) { workEnqueue(work); } + // Reset coalescing state coalescing_state_ = 0; coalescedComm_ = nullptr; - return work; + coalescedTensors_.clear(); + // If in async mode, return work; otherwise, kernel is enqueued on current + // stream, no need to return work + return coalescedAsync_ ? work : nullptr; } c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { @@ -3278,11 +3365,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; nanCheck &= enableNanCheck_; auto device = getDevice(inputs[0]); @@ -3323,13 +3409,17 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + coalescedAsync_ = asyncOp; } - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); - - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // in asyncOp=false [default] mode, we use currentStream as ncclStream + // otherwise, we use separate ncclStream and let it sync on currentStream + auto ncclStream = asyncOp ? ncclStreams_.at(key) + : at::cuda::getCurrentCUDAStream(device.index()); + if (asyncOp) { + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + } bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; @@ -3339,9 +3429,19 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); + // If we are performing sync operations, i.e. equeuing kernel onto "current" + // stream, we don't need to do anything for tensor lifetime management. + // Otherwise, we need to stage the tensors will `work.wait()`. + if (asyncOp) { + // First select which shelf to stash onto: to `work` if single collective; + // to an inflight shelf if coalescing. + if (coalescing_state_) { + coalescedTensors_.stash(inputs); + coalescedTensors_.stash(outputs); + } else { + work->stashed_for_allocator_safety_->stash(inputs); + work->stashed_for_allocator_safety_->stash(outputs); + } } if (nanCheck) { @@ -3351,7 +3451,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } // Start event should only be recorded before the ncclGroupStart() - if (work->timingEnabled_) { + if (work->timingEnabled_ && !coalescing_state_) { work->ncclStartEvent_->record(ncclStream); } @@ -3367,21 +3467,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. - if (!avoidRecordStreams) { - for (const auto& input : inputs) { - if (!input.is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - input.values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - input.indices().storage().data_ptr(), ncclStream); - } - } - } // Not all collectives have the same signature, e.g, all-reduce take in a Tensor // as the input and output while all-to-all take in a vector of Tensors as input @@ -3433,7 +3518,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as @@ -3451,7 +3535,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( workEnqueue(work); } - return work; + return asyncOp ? work : nullptr; } template @@ -3460,11 +3544,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::vector& outputs, Fn fn, OpType opType, - const char* profilingTitle, - bool avoidRecordStreams) { - // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; - + bool asyncOp, + const char* profilingTitle) { // Currently, the API permits one scenario where inputs.size() and // outputs.size() are > 0. // 1. If the call was a _coalesced call, all inputs must be on the same @@ -3510,13 +3591,17 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + coalescedAsync_ = asyncOp; } - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); - - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // in asyncOp=false [default] mode, we use currentStream as ncclStream + // otherwise, we use separate ncclStream and let it sync on currentStream + auto ncclStream = asyncOp ? ncclStreams_.at(key) + : at::cuda::getCurrentCUDAStream(device.index()); + if (asyncOp) { + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + } auto work = initWork( device, @@ -3531,9 +3616,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); + // If we are performing sync operations, i.e. equeuing kernel onto "current" + // stream, we don't need to do anything for tensor lifetime management. + // Otherwise, we need to stage the tensors will `work.wait()`. + if (asyncOp) { + work->stashed_for_allocator_safety_->stash(inputs); + work->stashed_for_allocator_safety_->stash(outputs); } // Start event should only be recorded before the ncclGroupStart() (which @@ -3548,7 +3636,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // upgrade. Once a NCCL version is qualified, this code should not be needed at // runtime. #ifdef PGNCCL_ENABLE_HASH - if (enableCollecticeHashDebug_.load()) { + if (enableCollectiveHashDebug_.load()) { auto numel = getTensorsNumel(inputs); auto hashValue = hashTensors(inputs); PRINT_COLLECTIVE_HASH_SIGNATURE( @@ -3559,27 +3647,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( { torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking()); for (const auto i : c10::irange(inputs.size())) { - // Both `inputs' and `outputs' are created on a worker stream and used in - // different ncclStreams. Hence, both must record the ncclStream to - // prevent being freed before the collective finishes. - // - // We only record `inputs' here, and leave recording `outputs' to `fn' for - // operations where `inputs' and `outputs' are not the same. - // - // See [Sync Streams]. - if (!avoidRecordStreams) { - if (!inputs[i].is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].indices().storage().data_ptr(), ncclStream); - } - } #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( fn(inputs[i], outputs[i], comm, ncclStream), @@ -3620,7 +3687,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as @@ -3651,7 +3717,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // it, since interactions with it by usercode won't behave normally - they // won't observe work completion, for instance. Will this lead to silent // problems during capture? - return work; + return asyncOp ? work : nullptr; } template @@ -3669,13 +3735,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // to wait() on the returned handle, so ProcessGroupNCCL can't know // when it's safe to release the input back to the allocator, // and the present call has no way to know it's not an isend. - // Therefore, we warn and fall back to the typical recordStream logic: - if (avoidRecordStreams_) { - TORCH_WARN_ONCE( - "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " - "collectives."); - } - + // Therefore, we warn and fall back to the typical recordStream logic. + // TODO( kwen2501 ): revisit this when we have a better solution. auto device = getDevice(tensor); at::cuda::OptionalCUDAGuard gpuGuard(device); @@ -3730,6 +3791,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + // For now, P2P ops are always put on internal stream + coalescedAsync_ = true; } // Used many times below, so we stash the unordered_map lookup @@ -3901,8 +3964,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3913,8 +3976,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( pre, post, opType, + asyncOp, profilingTitle, - avoidRecordStreams, nanCheck); } @@ -3924,8 +3987,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::Tensor& output, Fn fn, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3938,8 +4001,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, opType, + asyncOp, profilingTitle, - avoidRecordStreams, nanCheck); } @@ -3991,6 +4054,8 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( auto recvIndices = indices[0] * colSize; // prevent output and recvIndices from being freed + // TODO: not changing the lifetime management of outputs this time, + // revisit later c10::cuda::CUDACachingAllocator::recordStream( output.storage().data_ptr(), stream); c10::cuda::CUDACachingAllocator::recordStream( @@ -4022,6 +4087,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( } }, OpType::_ALLREDUCE_SPARSE, + opts.asyncOp, "nccl:all_reduce_sparse"); return work; #else @@ -4056,6 +4122,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( stream.stream()); }, OpType::ALLREDUCE, + opts.asyncOp, profilingTitle); } @@ -4156,6 +4223,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:allreduce_coalesced"); } @@ -4187,12 +4255,10 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( globalRankStride_, // globalRankStride_ this->getSize()); // worldSize - // avoidRecordStreams_ note: collective() will stash tensors. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank + opts.rootTensor; bool nanCheck = (root == rank_); + // avoidRecordStreams_ note: collective() will stash tensors. return collective( tensor, tensor, @@ -4209,8 +4275,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( stream.stream()); }, OpType::BROADCAST, + opts.asyncOp, "nccl:broadcast", - avoidRecordStreams, nanCheck); } @@ -4249,8 +4315,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, + opts.asyncOp, "nccl:_broadcast_oop", - /*avoidRecordStreams=*/false, nanCheck); } @@ -4309,6 +4375,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( stream.stream()); }, OpType::REDUCE, + opts.asyncOp, "nccl:reduce"); } @@ -4350,6 +4417,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( stream.stream()); }, OpType::REDUCE, + opts.asyncOp, "nccl:_reduce_oop"); } @@ -4393,10 +4461,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } + // See [We actually don't need to stash anything here]. return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -4412,27 +4477,27 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( // - inputTensors is stashed onto work->stashed_for_allocator_safety_ // in collective(). // - outputFlattened is stashed onto work->outputs_ in collective(). - // - User-facing outputTensors should be held by the user until after - // waiting on work_, or the call makes no sense. - // So all participating tensors are accounted for, and won't be - // released back to their allocation streams until after work_ is - // waited on. }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { + // User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. We do a stashing here + // in case user doesn't hold the outputTensors in downstream code, + // which can cause an early recyle by the CachingAllocator, which can + // lead to segfault or data corruption. + if (opts.asyncOp) { + work->stashed_for_allocator_safety_->stash(outputTensors_); + } // Copy the flattened output tensors to the outputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(outputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - outputTensors_[j].storage().data_ptr(), ncclStream); - } + // See [We actually don't need to stash anything here]. outputTensors_[j].copy_( outputFlattened[static_cast(j)], true); } }, OpType::ALLGATHER, + opts.asyncOp, "nccl:all_gather"); } else { const auto num_reduces = outputTensors_.size(); @@ -4440,7 +4505,8 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( for (const int64_t i : c10::irange(static_cast(num_reduces))) { auto& output = outputTensors_[i]; auto& input = (i == rank_) ? inputTensor : output; - auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout}; + auto broadcastOpts = + BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp}; _broadcast_oop(output, input, broadcastOpts); } auto work = endCoalescing(OpType::ALLGATHER); @@ -4496,6 +4562,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:all_gather_into_tensor_coalesced"); } @@ -4541,10 +4608,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } const auto ncclDataType = getNcclDataType(input.scalar_type()); const auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4559,27 +4622,18 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // We only need to stash inputTensors. - // - inputFlattened is stashed onto - // work->stashed_for_allocator_safety_ - // in collective(). - // - User-facing outputTensors is stashed onto work->outputs_ in - // collective(), - // and should also be held by the user until after waiting on - // work_. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), and should also be held by the user until after + // waiting on work_. + if (opts.asyncOp) { + work->stashed_for_allocator_safety_->stash(inputTensors_); } - // Copy the input tensors to the flattened inputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(inputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - inputTensors_[j].storage().data_ptr(), ncclStream); - } inputFlattened[static_cast(j)].copy_( inputTensors_[j], true); } @@ -4587,6 +4641,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( [&](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::REDUCE_SCATTER, + opts.asyncOp, "nccl:reduce_scatter"); } else { const auto num_reduces = inputTensors_.size(); @@ -4598,7 +4653,8 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( opts.reduceOp, static_cast(i), static_cast(0), - opts.timeout}; + opts.timeout, + opts.asyncOp}; _reduce_oop(output, input, reduceOpts); } auto work = endCoalescing(OpType::REDUCE_SCATTER); @@ -4652,7 +4708,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `reduce_scatter_tensor`. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( inputTensor, @@ -4661,10 +4716,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4678,8 +4729,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( stream.stream()); }, OpType::_REDUCE_SCATTER_BASE, - "nccl:_reduce_scatter_base", - avoidRecordStreams); + opts.asyncOp, + "nccl:_reduce_scatter_base"); } c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( @@ -4716,10 +4767,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4733,6 +4780,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:reduce_scatter_tensor_coalesced"); } @@ -4764,7 +4812,7 @@ c10::DeviceIndex ProcessGroupNCCL::guessDeviceId() const { devIdx, " as device used by this process is currently unknown. ", "This can potentially cause a hang if this rank to GPU mapping is incorrect. ", - "You can pecify device_id in init_process_group() to force use of a particular device."); + "You can specify device_id in init_process_group() to force use of a particular device."); return static_cast(devIdx); } @@ -4811,13 +4859,28 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); // All reduce to achieve the barrier - auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier"); + AllreduceOptions arOpts = AllreduceOptions(); + arOpts.asyncOp = opts.asyncOp; + auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts); + + if (opts.asyncOp) { + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + // If user specified async, the work should not be nullptr + TORCH_CHECK(ncclWork); + // Put a marker here so that `work.wait()` issue by users does + // barrier-specific thing: CPU sync + ncclWork->isBarrierOp_ = true; + return work; + } - // Work will take over barrierTensors - auto ncclWork = dynamic_cast(work.get()); - TORCH_CHECK(ncclWork); - ncclWork->isBarrierOp_ = true; - return work; + // Otherwise, we are in sync mode, we directly wait here. + // (It is a CPU wait for barrier) + auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); + // No work to return + return nullptr; } c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( @@ -4825,7 +4888,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { + const AllToAllOptions& opts) { check_gpu_single_tensor(outputTensor); check_gpu_single_tensor(inputTensor); if (outputSplitSizes.empty() && inputSplitSizes.empty()) { @@ -4856,16 +4919,12 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } torch::cuda::nccl::all2all_single_equal_split( input, output, this->getSize(), comm, stream); return ncclSuccess; }, OpType::ALLTOALL_BASE, + opts.asyncOp, "nccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); @@ -4907,10 +4966,6 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10d::computeLengthsAndOffsets( outputSplitSizes, output, &recv_lengths, &recv_offsets); // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } torch::cuda::nccl::all2all_single_unequal_split( input.data_ptr(), send_lengths.data(), @@ -4925,6 +4980,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( return ncclSuccess; }, OpType::ALLTOALL_BASE, + opts.asyncOp, "nccl:all_to_all"); } } @@ -4932,7 +4988,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& outputTensors, std::vector& inputTensors, - const AllToAllOptions& /* unused */) { + const AllToAllOptions& opts) { int64_t input_total_numel = 0; int64_t output_total_numel = 0; @@ -4977,18 +5033,11 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( return ncclSuccess; }, [&](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // inputTensor0 and outputTensor0 are stashed redundantly by - // collective(), but that's ok. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors.begin(), inputTensors.end()); - v->insert(v->end(), outputTensors.begin(), outputTensors.end()); - } - }, + c10::intrusive_ptr& work) {}, [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::ALLTOALL, + opts.asyncOp, "nccl:all_to_all"); } @@ -5186,14 +5235,6 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( ncclComm_t comm, at::cuda::CUDAStream& stream) { const auto root = opts.rootRank; - if (getRank() == root) { - if (!avoidRecordStreams_) { - for (auto const& output : outputs) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - } - } torch::cuda::nccl::gather( inputTensor, outputs, comm, stream, static_cast(root)); return ncclSuccess; @@ -5203,6 +5244,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::GATHER, + opts.asyncOp, "nccl:gather"); } @@ -5271,8 +5313,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // avoidRecordStreams_ note: collective() will stash outputTensors and // inputs, which == inputTensors[0] on the root rank where it matters. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank; bool nanCheck = (rank_ == root); @@ -5284,14 +5324,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (getRank() == root) { - if (!avoidRecordStreams) { - for (auto const& input : inputs) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), stream); - } - } - } torch::cuda::nccl::scatter( inputs, outputTensor, comm, stream, static_cast(root)); return ncclSuccess; @@ -5301,8 +5333,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::SCATTER, + opts.asyncOp, "nccl:scatter", - avoidRecordStreams, nanCheck); } @@ -5358,7 +5390,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `all_gather_into_tensor`. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( input_tensor, @@ -5367,10 +5398,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -5380,8 +5407,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( stream.stream()); }, OpType::_ALLGATHER_BASE, - "nccl:_all_gather_base", - avoidRecordStreams); + opts.asyncOp, + "nccl:_all_gather_base"); } // Create a memory allocator for NCCL. This allocator is used to allocate memory diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index f65d5955c8dd..82961db0ec17 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -195,11 +195,15 @@ struct DumpPipe { TORCH_CHECK( unlink(filename.c_str()) != -1 || errno == ENOENT, "Error removing existing named pipe ", - filename); + filename, + ", Error: ", + std::strerror(errno)); TORCH_CHECK( mkfifo(filename.c_str(), 0666) != -1, "Error creating named pipe ", - filename); + filename, + ", Error: ", + std::strerror(errno)); fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); LOG(INFO) << "Pipe file " << filename << " has been opened, write to it to trigger NCCL Debug Dump."; @@ -235,6 +239,34 @@ struct DumpPipe { }; #endif +// A shelf for stashing tensors between op call and `work.wait()`. +// Used in case of async ops. +class TensorShelf { + public: + // Stash tensors so that CachingAllocator cannot recycle them prematurely. + void stash(std::vector& tensors); + // Stash tensors from another shelf. + void stash(TensorShelf& other); + // Unstage the stashed tensors so that CachingAllocator can recycle them. + // Same as `clear()`. + void unstash(); + // Whether shelf is empty. + bool empty(); + // Clear the shelf. + void clear(); + + protected: + // Get the inner tensor vector. Use with caution as it is not protected by + // mutex. + std::vector& get(); + + private: + std::vector tVector_; + // Need a mutex to protect `tVector_` because it can be potentially accessed + // from both main thread and watchdog thread. + std::mutex mutex_; +}; + // ProcessGroupNCCL implements NCCL bindings for c10d. // // All functions of the class are expected to be called in the same order @@ -382,9 +414,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Clone of blockingWait_ from ProcessGroupNCCL. bool blockingWait_{false}; - // Clone of avoidRecordStreams_ from ProcessGroupNCCL. - bool avoidRecordStreams_{false}; - // Clone of opTimeout_ from ProcessGroupNCCL. std::chrono::milliseconds opTimeout_{}; @@ -448,7 +477,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // caching allocator safety without any recordStream calls. // For in-place collectives, some refs stashed here may alias outputs_, // but that doesn't do any harm. - std::shared_ptr> stashed_for_allocator_safety_; + std::shared_ptr stashed_for_allocator_safety_; // The future returned by getFuture. c10::intrusive_ptr future_; @@ -530,7 +559,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { class DesyncDebugger { public: // Initialize and enable DesyncDebugger - void init(int rank, int size, c10::intrusive_ptr store); + void init( + int rank, + int size, + int globalRank, + int pgId, + c10::intrusive_ptr store); // Run desync debug. This function is called by watchdog at time of timeout. void run(); @@ -549,6 +583,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { // From ProcessGroupNCCL int rank_; int size_; + int globalRank_; + int pgId_; // Reference to the store so that we can log start/end event. c10::intrusive_ptr store_; @@ -889,8 +925,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { at::Tensor& output, Fn fn, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -901,8 +937,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -913,8 +949,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -923,8 +959,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::vector& output, Fn fn, OpType opType, - const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool asyncOp, + const char* profilingTitle = nullptr); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -1137,9 +1173,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // timeout for the dump to finish. int waitTimeoutDumpInMilSec_; - // promise to coordinate flight recorder dump. - std::promise promiseFlightRecorderDump_; - // Interval of check coordinated signals in ProcessGroupNCCL from other ranks // e.g., trigger the dump of the debugging info for timeout when notified. int coordCheckIntervalMilSec_; @@ -1233,14 +1266,26 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Stores communicators for all collectives run inside a coalescing block std::shared_ptr coalescedComm_ = nullptr; + // Whether the coalesced calls are sync or async. + bool coalescedAsync_; + + // keeps track of input and output tensors when coalescing is in flight. Will + // hand over these tensors to WorkNCCL's stash when coalescing is ended. + TensorShelf coalescedTensors_; + + // Some ops may have completed, but user still hasn't called `work.wait()`. + // When watchdog detects this, it transfers the TensorShelf from `work` to + // this `shelves` structure. Next time we execute ProcessGroupNCCL's methods + // on main thread, we clear the `shelves` in one shot. This is mainly because + // watchdog (a side thread) unstashing the shelf directly seems to cause some + // problem. + std::vector> shelvesToUnstash_; + std::mutex shelvesMutex_; + // Whether or not wait() and synchronize() are blocking operations that wait // for the operation to complete. bool blockingWait_ = false; - // Whether or not to hook the cache allocator to register all allocated - // tensors - bool useTensorRegisterAllocatorHook_ = false; - // Whether or not the workCleanupThread is used to perform async error // handling. ErrorHandlingMode asyncErrorHandling_ = NoHandling; @@ -1277,7 +1322,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Flag to enable the print of hash value of input/output of collectives for // verification. - std::atomic enableCollecticeHashDebug_{}; + std::atomic enableCollectiveHashDebug_{}; // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set bool avoidRecordStreams_ = false; diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 76eb7205a398..f68681de1698 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -250,6 +250,10 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "two_shot_all_reduce_out(Tensor(a!) input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)"); + // note this implementation also modified the input tensor + m.def( + "reduce_scatter_out(Tensor(a!) input, str group_name, bool split_last_dim, Tensor(b!) output) -> Tensor(b!)"); + // An mm that supports consuming asynchronous input. It guarantees the // following rasterization order, and that the corresponding signal arrives // before an input chunk is consumed. diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 5d15708c953e..8fec5dd0e9e2 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -122,6 +122,7 @@ struct BroadcastOptions { struct AllreduceOptions { ReduceOp reduceOp = ReduceOp::SUM; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; std::optional sparseIndices = std::nullopt; }; @@ -132,6 +133,7 @@ struct ReduceOptions { int64_t rootRank = 0; int64_t rootTensor = 0; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct AllgatherOptions { @@ -142,6 +144,7 @@ struct AllgatherOptions { struct GatherOptions { int64_t rootRank = 0; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct ScatterOptions { @@ -158,12 +161,14 @@ struct ReduceScatterOptions { struct AllToAllOptions { std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct BarrierOptions { std::vector device_ids; std::chrono::milliseconds timeout = kUnsetTimeout; std::optional device; + bool asyncOp = true; }; struct DistributedBackendOptions { diff --git a/torch/csrc/distributed/c10d/cuda/utils.cpp b/torch/csrc/distributed/c10d/cuda/utils.cpp index 7884be53a1a7..0072fab983f6 100644 --- a/torch/csrc/distributed/c10d/cuda/utils.cpp +++ b/torch/csrc/distributed/c10d/cuda/utils.cpp @@ -1,3 +1,5 @@ +#include + #include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index ddd75d234449..f1bd5fb14cf1 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -999,20 +999,23 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp) - .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout); + .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp); py::class_<::c10d::AllreduceCoalescedOptions>( module, "AllreduceCoalescedOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp) - .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout); + .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp); py::class_<::c10d::ReduceOptions>(module, "ReduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp) .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank) .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor) - .def_readwrite("timeout", &::c10d::ReduceOptions::timeout); + .def_readwrite("timeout", &::c10d::ReduceOptions::timeout) + .def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp); py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions") .def(py::init<>()) @@ -1022,7 +1025,8 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::GatherOptions>(module, "GatherOptions") .def(py::init<>()) .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank) - .def_readwrite("timeout", &::c10d::GatherOptions::timeout); + .def_readwrite("timeout", &::c10d::GatherOptions::timeout) + .def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp); py::class_<::c10d::ScatterOptions>(module, "ScatterOptions") .def(py::init<>()) @@ -1040,11 +1044,13 @@ This class does not support ``__members__`` property.)"); .def(py::init<>()) .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) .def_readwrite("timeout", &::c10d::BarrierOptions::timeout) - .def_readwrite("device", &::c10d::BarrierOptions::device); + .def_readwrite("device", &::c10d::BarrierOptions::device) + .def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp); py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") .def(py::init<>()) - .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); + .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp); py::class_<::c10d::DistributedBackendOptions>( module, "_DistributedBackendOptions") @@ -2843,24 +2849,36 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). processGroupGloo .def_static( "create_device", - [](const std::string& hostname, const std::string& interface) + [](const std::string& hostname, + const std::string& interface, + std::optional lazyInit_) -> std::shared_ptr<::gloo::transport::Device> { + bool lazyInit = + lazyInit_.value_or(::c10d::getDefaultGlooLazyInit()); + if (!hostname.empty()) { return ::c10d::ProcessGroupGloo::createDeviceForHostname( - hostname); + hostname, lazyInit); } if (!interface.empty()) { return ::c10d::ProcessGroupGloo::createDeviceForInterface( - interface); + interface, lazyInit); } throw std::invalid_argument( "Specify either `hostname` or `interface` argument."); }, py::arg("hostname") = "", - py::arg("interface") = "") + py::arg("interface") = "", + py::arg("lazy_init") = std::nullopt) .def_static( "create_default_device", - &::c10d::ProcessGroupGloo::createDefaultDevice); + [](std::optional lazyInit_) { + bool lazyInit = + lazyInit_.value_or(::c10d::getDefaultGlooLazyInit()); + + return ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit); + }, + py::arg("lazy_init") = std::nullopt); processGroupGloo .def( @@ -2892,20 +2910,22 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::gil_scoped_release nogil{}; auto options = ::c10d::ProcessGroupGloo::Options::create(); + bool lazyInit = ::c10d::getDefaultGlooLazyInit(); // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set. char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); if (ifnameEnv && strlen(ifnameEnv) > 1) { for (const auto& iface : ::c10d::split(',', ifnameEnv)) { options->devices.push_back( - ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); + ::c10d::ProcessGroupGloo::createDeviceForInterface( + iface, lazyInit)); } } else { // If no hostname is specified, this function looks up // the machine's hostname and returns a device instance // associated with the address that the hostname resolves to. options->devices.push_back( - ::c10d::ProcessGroupGloo::createDefaultDevice()); + ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit)); } options->timeout = timeout; diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 6795857ed9f5..bbf66afeb18d 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -20,6 +20,8 @@ #include +#include + #ifdef USE_CUDA #include #endif @@ -525,11 +527,11 @@ static PyTypeObject TensorGuardsType = { PyVarObject_HEAD_INIT(nullptr, 0) struct AutocastState { static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES; - std::array enabled; - std::array dtype; + std::array enabled{}; + std::array dtype{}; bool cache_enabled; - AutocastState() : enabled{}, dtype{} { + AutocastState() { for (size_t i = 0; i < DEVICES.size(); i++) { enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]); dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]); @@ -552,6 +554,20 @@ struct AutocastState { } return true; } + + template + friend void to_json(T& json_j, const AutocastState& json_t) { + json_j["enabled"] = json_t.enabled; + json_j["dtype"] = json_t.dtype; + json_j["cached_enabled"] = json_t.cache_enabled; + } + + template + friend void from_json(const T& json_j, AutocastState& json_t) { + json_t.enabled = json_j.at("enabled"); + json_t.dtype = json_j.at("dtype"); + json_t.cache_enabled = json_j.at("cached_enabled"); + } }; // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is @@ -623,6 +639,40 @@ struct GlobalStateGuard { return os.str(); } + template + friend void to_json(T& json_j, const GlobalStateGuard& json_t) { + json_j["grad_mode"] = json_t._grad_mode; + json_j["autocast_state"] = json_t._autocast_state; + json_j["torch_function"] = json_t._torch_function; + json_j["torch_function_all_disabled"] = json_t._torch_function_all_disabled; + json_j["deterministic_algorithms"] = json_t._deterministic_algorithms; + json_j["deterministic_algorithms_warn_only"] = + json_t._deterministic_algorithms_warn_only; + json_j["allow_tf32"] = json_t._allow_tf32; + json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce; + json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce; + json_j["num_threads"] = json_t._num_threads; + json_j["default_dtype"] = json_t._default_dtype.toScalarType(); + } + + template + friend void from_json(const T& json_j, GlobalStateGuard& json_t) { + json_t._grad_mode = json_j.at("grad_mode"); + json_t._autocast_state = json_j.at("autocast_state"); + json_t._torch_function = json_j.at("torch_function"); + json_t._torch_function_all_disabled = + json_j.at("torch_function_all_disabled"); + json_t._deterministic_algorithms = json_j.at("deterministic_algorithms"); + json_t._deterministic_algorithms_warn_only = + json_j.at("deterministic_algorithms_warn_only"); + json_t._allow_tf32 = json_j.at("allow_tf32"); + json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce"); + json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce"); + json_t._num_threads = json_j.at("num_threads"); + json_t._default_dtype = + caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype")); + } + bool _grad_mode; AutocastState _autocast_state; bool _torch_function; @@ -663,6 +713,25 @@ PyObject* GlobalStateGuard_reason( return PyUnicode_FromString(self->reason().c_str()); } +PyObject* GlobalStateGuard_dump( + GlobalStateGuard* self, + PyObject* args, + PyObject* kwargs) { + return PyUnicode_FromString(nlohmann::json(*self).dump().c_str()); +} + +PyObject* GlobalStateGuard_load( + GlobalStateGuard* self, + PyObject* args, + PyObject* kwargs) { + char* json; + if (!PyArg_ParseTuple(args, "s", &json)) { + throw std::runtime_error("Cannot parse as json string."); + } + nlohmann::json::parse(json).get_to(*self); + Py_RETURN_NONE; +} + // NOLINTNEXTLINE(*array*) static PyMethodDef GlobalStateGuard_methods[] = { {"check", @@ -673,6 +742,14 @@ static PyMethodDef GlobalStateGuard_methods[] = { (PyCFunction)(void*)GlobalStateGuard_reason, METH_NOARGS, "Return string reason for guard check failing"}, + {"dump", + (PyCFunction)(void*)GlobalStateGuard_dump, + METH_NOARGS, + "Return serialized json format"}, + {"load", + (PyCFunction)(void*)GlobalStateGuard_load, + METH_VARARGS, + "Parse serialized json format"}, {nullptr}}; static PyTypeObject GlobalStateGuardType = { PyVarObject_HEAD_INIT(nullptr, 0) }; @@ -1977,8 +2054,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { py::object py_addr_keep_alive, py::object verbose_code_parts) : RelationalGuard(std::move(verbose_code_parts)), - _py_addr_keep_alive(std::move(py_addr_keep_alive)), - _args_seen{0} { + _py_addr_keep_alive(std::move(py_addr_keep_alive)) { _nargs_int = PyLong_AsSize_t(nargs_int.ptr()); _nargs_float = PyLong_AsSize_t(nargs_float.ptr()); _nargs = _nargs_int + _nargs_float; @@ -2048,7 +2124,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { bool result = check_nopybind(value); if (!result) { - std::string msg = "Shape guard failed with values: "; + std::string msg = "\"Shape guard failed with values: "; for (auto v : _args_int) { msg += std::to_string(v) + ","; } @@ -2056,6 +2132,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { msg += std::to_string(v) + ","; } msg.pop_back(); + msg += "\""; auto msgs = py::list(); for (auto code_part : verbose_code_parts()) { msgs.append(code_part); @@ -2072,7 +2149,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { private: py::object _py_addr_keep_alive; - size_t _args_seen, _nargs_float, _nargs_int, _nargs; + size_t _args_seen{0}, _nargs_float, _nargs_int, _nargs; std::vector _args_int; std::vector _args_float; std::function _guard_check_fn; @@ -3496,7 +3573,6 @@ class GetAttrGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -3515,7 +3591,7 @@ class GetAttrGuardAccessor : public GuardAccessor { private: // no need of py::object here because the attr_name is already passed on to // the base class as accessor_key which is a py::object. - PyObject* _attr_name; + PyObject* _attr_name{nullptr}; }; /** @@ -3571,7 +3647,6 @@ class GetGenericDictGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetGenericDictGuardAccessor( GuardManager* guard_manager, GetGenericDictGuardAccessor* from) @@ -3639,7 +3714,6 @@ class GetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetItemGuardAccessor(GuardManager* guard_manager, GetItemGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -3658,7 +3732,7 @@ class GetItemGuardAccessor : public GuardAccessor { private: // no need of py::object here because the attr_name is already passed on to // the base class as accessor_key which is a py::object. - PyObject* _attr_name; + PyObject* _attr_name{nullptr}; }; /** @@ -3757,7 +3831,6 @@ class FrameLocalsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FrameLocalsGuardAccessor( GuardManager* guard_manager, FrameLocalsGuardAccessor* from) @@ -3778,12 +3851,12 @@ class FrameLocalsGuardAccessor : public GuardAccessor { } private: - PyObject* _key; - int _framelocals_idx; + PyObject* _key{nullptr}; + int _framelocals_idx{-1}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. - bool _is_immutable_object; + bool _is_immutable_object{false}; }; /** @@ -3847,7 +3920,6 @@ class DictGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) DictGetItemGuardAccessor( GuardManager* guard_manager, DictGetItemGuardAccessor* from) @@ -3867,11 +3939,11 @@ class DictGetItemGuardAccessor : public GuardAccessor { } private: - PyObject* _key; + PyObject* _key{nullptr}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. - bool _is_immutable_object; + bool _is_immutable_object{false}; }; /** @@ -3924,7 +3996,6 @@ class ListGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ListGetItemGuardAccessor( GuardManager* guard_manager, ListGetItemGuardAccessor* from) @@ -3943,7 +4014,7 @@ class ListGetItemGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -3996,7 +4067,6 @@ class TupleGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TupleGetItemGuardAccessor( GuardManager* guard_manager, TupleGetItemGuardAccessor* from) @@ -4016,7 +4086,7 @@ class TupleGetItemGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; enum class TensorProperty { @@ -4143,7 +4213,6 @@ class TensorPropertyGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorPropertyGuardAccessor( GuardManager* guard_manager, TensorPropertyGuardAccessor<_prop>* from) @@ -4163,7 +4232,7 @@ class TensorPropertyGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -4210,7 +4279,6 @@ class IndexedGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) IndexedGuardAccessor(GuardManager* guard_manager, IndexedGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4227,7 +4295,7 @@ class IndexedGuardAccessor : public GuardAccessor { } private: - py::int_ _index; + py::int_ _index{-1}; }; /** @@ -4287,7 +4355,6 @@ class GradGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GradGuardAccessor(GuardManager* guard_manager, GradGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4361,7 +4428,6 @@ class FuncDefaultsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FuncDefaultsGuardAccessor( GuardManager* guard_manager, FuncDefaultsGuardAccessor* from) @@ -4437,7 +4503,6 @@ class FuncKwDefaultsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FuncKwDefaultsGuardAccessor( GuardManager* guard_manager, FuncKwDefaultsGuardAccessor* from) @@ -4494,7 +4559,6 @@ class GlobalsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GlobalsGuardAccessor(GuardManager* guard_manager, GlobalsGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4513,7 +4577,7 @@ class GlobalsGuardAccessor : public GuardAccessor { private: // no need of py::object here because the globals_dict is already passed on to // the base class as accessor_key which is a py::object. - PyObject* _globals_dict; + PyObject* _globals_dict{nullptr}; }; /** @@ -4554,7 +4618,6 @@ class TypeGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TypeGuardAccessor(GuardManager* guard_manager, TypeGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4623,7 +4686,6 @@ class TupleIteratorGetItemAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TupleIteratorGetItemAccessor( GuardManager* guard_manager, TupleIteratorGetItemAccessor* from) @@ -4643,7 +4705,7 @@ class TupleIteratorGetItemAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -4739,7 +4801,6 @@ class GlobalWeakRefGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GlobalWeakRefGuardAccessor( GuardManager* guard_manager, GlobalWeakRefGuardAccessor* from) @@ -4758,7 +4819,7 @@ class GlobalWeakRefGuardAccessor : public GuardAccessor { } private: - PyObject* _global_name; + PyObject* _global_name{nullptr}; }; /** @@ -4830,7 +4891,6 @@ class WeakRefCallGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) WeakRefCallGuardAccessor( GuardManager* guard_manager, WeakRefCallGuardAccessor* from) @@ -4910,7 +4970,6 @@ class CallFunctionNoArgsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CallFunctionNoArgsGuardAccessor( GuardManager* guard_manager, CallFunctionNoArgsGuardAccessor* from) @@ -4982,7 +5041,6 @@ class PythonLambdaGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) PythonLambdaGuardAccessor( GuardManager* guard_manager, PythonLambdaGuardAccessor* from) diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 425a28393113..d3244441da16 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -1,14 +1,11 @@ #include -#include #include #include #include -#include namespace { -using NodeSortKey = c10::SmallVector; struct NodeBase; // Thrown to exit out of a C++ function and return an error to Python. @@ -166,22 +163,7 @@ struct NodeBase { PyObject* users; PyObject* _repr_fn; PyObject* meta; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)]; - - inline NodeSortKey& sort_key() { - return *reinterpret_cast(sort_key_buf); - } - - // Equivalent to: - // p, n = self._prev, self._next - // p._next, n._prev = n, p - inline void remove_from_list() { - NodeBase* p = this->_prev; - NodeBase* n = this->_next; - p->_next = n; - n->_prev = p; - } + PyObject* _sort_key; }; static PyObject* NodeBase_new( @@ -191,8 +173,6 @@ static PyObject* NodeBase_new( PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; - new (reinterpret_cast(self)->sort_key_buf) - NodeSortKey(); // placement new does not allocate return self; } @@ -221,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { self->users = PyDict_New(); self->_repr_fn = Py_NewRef(Py_None); self->meta = PyDict_New(); + self->_sort_key = PyTuple_New(0); return 0; } @@ -240,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = { {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr}, {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr}, {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr}, + {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr}, {nullptr} /* Sentinel */ }; @@ -257,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { Py_VISIT(self->users); Py_VISIT(self->_repr_fn); Py_VISIT(self->meta); + Py_VISIT(self->_sort_key); return 0; } @@ -274,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) { Py_CLEAR(self->users); Py_CLEAR(self->_repr_fn); Py_CLEAR(self->meta); + Py_CLEAR(self->_sort_key); return 0; } static void NodeBase_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); - reinterpret_cast(self)->sort_key().~NodeSortKey(); (void)NodeBase_clear((NodeBase*)self); Py_TYPE(self)->tp_free(self); } @@ -338,191 +321,15 @@ static PyObject* NodeBase__update_args_kwargs( } } -static PyObject* NodeBase__remove_from_list( - PyObject* self, - PyObject* _ignored) { - reinterpret_cast(self)->remove_from_list(); - Py_RETURN_NONE; -} - -static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) { - if (self_ == arg) { - Py_RETURN_NONE; - } - if (!is_node(arg)) { - PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node"); - return nullptr; - } - NodeBase* self = reinterpret_cast(self_); - NodeBase* x = reinterpret_cast(arg); - if (self->graph != x->graph) { - PyErr_SetString( - PyExc_AssertionError, - "Attempting to move a Node into a different Graph"); - return nullptr; - } - - x->remove_from_list(); - NodeBase* p = self->_prev; - p->_next = x; - x->_prev = p; - x->_next = self; - self->_prev = x; - - // Now compute x.sort_key() - const NodeSortKey& psk = x->_prev->sort_key(); - const NodeSortKey& nsk = x->_next->sort_key(); - if (psk.size() > nsk.size()) { - // prefix = psk[: len(nsk)+1] - size_t slice_len = nsk.size() + 1; - NodeSortKey prefix(psk.begin(), psk.begin() + slice_len); - // last element is idx => increment by 1 - prefix.back()++; - x->sort_key() = std::move(prefix); - } else if (psk.size() < nsk.size()) { - // prefix = nsk[: len(psk)+1] - size_t slice_len = psk.size() + 1; - NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len); - // last element is idx => decrement by 1 - prefix.back()--; - x->sort_key() = std::move(prefix); - } else { - // same length => add a 0 - x->sort_key() = psk; - x->sort_key().emplace_back(0); - } - Py_RETURN_NONE; -} - -// __lt__(self, other): Return self.sort_key < other.sort_key -static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) { - // METH_O => one argument: 'other' - if (!is_node(other)) { - Py_RETURN_NOTIMPLEMENTED; - } - const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); - const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); - bool less = std::lexicographical_compare( - lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); - if (less) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} - -// __gt__(self, other): Return self.sort_key() > other.sort_key -static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) { - if (!is_node(other)) { - Py_RETURN_NOTIMPLEMENTED; - } - const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); - const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); - // "a > b" is equivalent to "b < a" - bool greater = std::lexicographical_compare( - rhs.begin(), rhs.end(), lhs.begin(), lhs.end()); - if (greater) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} - -static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) { - if (self == other) { - Py_RETURN_TRUE; - } - return NodeBase___gt__(self, other); -} - -// __le__(self, other): Return not (self > other) -static PyObject* NodeBase___le__(PyObject* self, PyObject* other) { - if (self == other) { - Py_RETURN_TRUE; - } - return NodeBase___lt__(self, other); -} - -// Convert the NodeBase::sort_key vector into a Python tuple of ints -// Only used by pickle/__getstate__ -static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) { - NodeBase* node = reinterpret_cast(self); - const NodeSortKey& vec = node->sort_key(); - Py_ssize_t n = static_cast(vec.size()); - THPObjectPtr tuple(PyTuple_New(n)); - if (!tuple) { - return nullptr; // Out of memory - } - for (Py_ssize_t i = 0; i < n; i++) { - PyTuple_SET_ITEM(tuple.get(), i, PyLong_FromSsize_t(vec[i])); - } - return tuple.release(); -} - -// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g. -// node._sort_key = (1,2,3) Only used by pickle/__setstate__ -static int NodeBase_set_sort_key( - PyObject* self, - PyObject* value, - void* /*closure*/) { - NodeBase* node = reinterpret_cast(self); - if (!PyTuple_Check(value)) { - PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints"); - return -1; - } - Py_ssize_t size = PyTuple_GET_SIZE(value); - NodeSortKey new_vec; - new_vec.reserve(size); - for (Py_ssize_t i = 0; i < size; i++) { - int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i)); - if (val == -1 && PyErr_Occurred()) { - return -1; - } - new_vec.emplace_back(val); - } - node->sort_key() = std::move(new_vec); - return 0; -} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef NodeBase_methods[] = { {"_update_args_kwargs", (PyCFunction)(void*)(NodeBase__update_args_kwargs), METH_FASTCALL, "Internal method: do not call directly."}, - {"_remove_from_list", - (PyCFunction)(void*)(NodeBase__remove_from_list), - METH_NOARGS, - "Internal method: do not call directly."}, - {"_prepend", - (PyCFunction)(void*)(NodeBase__prepend), - METH_O, - "Internal method: do not call directly."}, - {"__lt__", - (PyCFunction)(void*)NodeBase___lt__, - METH_O, - "Return True if self.sort_key < other.sort_key"}, - {"__gt__", - (PyCFunction)(void*)NodeBase___gt__, - METH_O, - "Return True if self.sort_key > other.sort_key"}, - {"__ge__", - (PyCFunction)(void*)NodeBase___ge__, - METH_O, - "Return True if self.sort_key >= other.sort_key"}, - {"__le__", - (PyCFunction)(void*)NodeBase___le__, - METH_O, - "Return True if self.sort_key <= other.sort_key"}, {nullptr, nullptr, 0, nullptr} // Sentinel }; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -static PyGetSetDef NodeBase_getset[] = { - {"_sort_key", // attribute name in Python - (getter)NodeBase_get_sort_key, // C getter function - (setter)NodeBase_set_sort_key, // C setter function - (char*)"The sort key as a tuple of ints", // docstring - nullptr}, - {nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel -}; - PyTypeObject NodeBaseType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ @@ -554,7 +361,7 @@ PyTypeObject NodeBaseType = { nullptr, /* tp_iternext */ NodeBase_methods, /* tp_methods */ NodeBase_members, /* tp_members */ - NodeBase_getset, /* tp_getset */ + nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ diff --git a/torch/csrc/inductor/aoti_include/common.h b/torch/csrc/inductor/aoti_include/common.h index e942e48823fa..e0e61ac0615d 100644 --- a/torch/csrc/inductor/aoti_include/common.h +++ b/torch/csrc/inductor/aoti_include/common.h @@ -9,8 +9,6 @@ #include #include -using half = at::Half; -using bfloat16 = at::BFloat16; // Round up to the nearest multiple of 64 [[maybe_unused]] inline int64_t align(int64_t nbytes) { diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 10ea643ae18b..9123c942754f 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -109,7 +109,7 @@ consider rebuild your model with the latest AOTInductor."); if (file_exists(json_filename)) { proxy_executor_ = std::make_unique( - json_filename, device_str == "cpu"); + json_filename, device_str); proxy_executor_handle_ = reinterpret_cast(proxy_executor_.get()); } else { diff --git a/torch/csrc/inductor/aoti_runtime/constant_type.h b/torch/csrc/inductor/aoti_runtime/constant_type.h new file mode 100644 index 000000000000..053eed728fb0 --- /dev/null +++ b/torch/csrc/inductor/aoti_runtime/constant_type.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. + +namespace torch::aot_inductor { + +enum ConstantType : uint8_t { + Unknown = 0, + Parameter = 1, + Buffer = 2, + TensorConstant = 3, + FoldedConstant = 4, +}; + +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index 617548a53a3c..83dd1c4e7437 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -20,6 +20,7 @@ #else #include #endif +#include #define AOTI_RUNTIME_CHECK(EXPR, MSG) \ do { \ @@ -89,15 +90,9 @@ RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { } // anonymous namespace namespace torch::aot_inductor { -enum ConstantType : uint8_t { - Unknown = 0, - Parameter = 1, - Buffer = 2, - TensorConstant = 3, - FoldedConstant = 4, -}; -using ConstantMap = std::unordered_map; +using ConstantMap = + std::unordered_map; // valid device strs are: cpu, cuda, cuda:0, cuda:1, ... // Update the list here if more devices are supported in the future diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 42f6157f5eef..408a9274417c 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -349,7 +349,8 @@ class AOTInductorModelContainer { tensor = it->second; } - constants_map_to_update->insert_or_assign(constant_name, tensor); + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor)); } // Update the inactive constant array. update_array_from_map( @@ -437,7 +438,8 @@ class AOTInductorModelContainer { // Now place the tensor to constants_map. Note at this point the ownership // of the tensor_handle will be taken over. - constants_map_to_update->insert_or_assign(constant_name, tensor_handle); + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor_handle)); } // Update the inactive constant array. update_array_from_map( diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index 2f23826be77f..9e2f5c160f73 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -135,6 +135,122 @@ class RAIIAtenTensorHandle { std::unique_ptr handle_; }; +class MaybeOwningAtenTensorHandle { + public: + MaybeOwningAtenTensorHandle() : handle_(nullptr), raii_handle_() {} + // We skip copy constructor as MaybeOwningAtenTensorHandle might be RAII which + // makes it undefined. + MaybeOwningAtenTensorHandle(const MaybeOwningAtenTensorHandle& other) = + delete; + MaybeOwningAtenTensorHandle& operator=( + const MaybeOwningAtenTensorHandle& other) = delete; + + // Move constructor and move assignment operator + MaybeOwningAtenTensorHandle(MaybeOwningAtenTensorHandle&& other) = default; + MaybeOwningAtenTensorHandle& operator=(MaybeOwningAtenTensorHandle&& other) = + default; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + MaybeOwningAtenTensorHandle(RAIIAtenTensorHandle&& other) + : raii_handle_(std::move(other)) { + handle_ = raii_handle_.get(); + } + MaybeOwningAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) { + raii_handle_ = std::move(other); + handle_ = raii_handle_.get(); + return *this; + } + + // By default, steal the ownership from raw AtenTensorHandle + MaybeOwningAtenTensorHandle(AtenTensorHandle handle) : raii_handle_(handle) { + handle_ = raii_handle_.get(); + } + + // If user_managed is true, we do not steal the ownership. + MaybeOwningAtenTensorHandle(AtenTensorHandle handle, bool user_managed) { + if (user_managed) { + handle_ = handle; + } else { + raii_handle_ = RAIIAtenTensorHandle(handle); + handle_ = raii_handle_.get(); + } + } + + ~MaybeOwningAtenTensorHandle() { + // This is no-op if we don't hold raii_handle with the + // MaybeOwningAtenTensorHandle. + raii_handle_.reset(); + } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { + return handle_; + } + + AtenTensorHandle release() { + if (raii_handle_) { + return raii_handle_.release(); + } else { + AtenTensorHandle handle = handle_; + handle_ = nullptr; + return handle; + } + } + + AtenTensorHandle get() const { + return handle_; + } + + void reset() { + handle_ = nullptr; + raii_handle_.reset(); + } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_, d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(handle_, d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_, &storage_offset)); + return storage_offset; + } + + void* data_ptr() const { + void* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &result)); + return result; + } + + int64_t* sizes() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t* strides() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + // handle_ is the underlying AtenTensorHandle of raii_handle_ if raii_handle_ + // exists. Otherwise it would just be the AtenTensorHandle passed in by users. + AtenTensorHandle handle_; + RAIIAtenTensorHandle raii_handle_; +}; + // Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle inline std::vector steal_from_raw_handles_to_raii_handles( AtenTensorHandle* handles, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index f56f6eca7449..be187c0118a2 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -783,11 +783,6 @@ int32_t aoti_torch_dtype() = delete; return aoti_torch_dtype_##typename(); \ } -namespace c10 { -struct BFloat16; -struct Half; -} // namespace c10 - DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) diff --git a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h index 86f09416f9fe..c7b713bf7f87 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h @@ -170,7 +170,7 @@ aoti_torch_cpu__qlinear_pointwise_binary_tensor( const char* unary_post_op_algorithm, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qconv2d_pointwise_tensor( +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qconv_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, AtenTensorHandle act_zero_point, diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp index 99d9045a63b0..fc25970c00b3 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -18,6 +18,19 @@ bool has_key( return map.find(key) != map.end(); } +c10::Device normalize_device(const c10::Device& device) { + // cpu device doesn't have an index + // cuda device must have an index + if (device.is_cpu()) { + return c10::Device(c10::DeviceType::CPU); + } else if (device.is_cuda()) { + return c10::Device( + c10::DeviceType::CUDA, device.has_index() ? device.index() : 0); + } else { + TORCH_CHECK(false, "Unsupported device type", device); + } +} + #ifdef _WIN32 const std::string k_separator = "\\"; #else @@ -211,12 +224,11 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( serialized_arg_val["index"].is_number()) { auto index = serialized_arg_val["index"].get(); device_string += ":" + std::to_string(index); - device_->set_index(static_cast(index)); } c10::Device device(device_string); - if (device.type() != device_->type()) { + if (device != *device_) { VLOG(1) << "ProxyExecutor is using " << *device_ << " for " << op_kernel->target_ << " argument #" << index << ", which is different from the one serialized in thrift: " @@ -579,15 +591,12 @@ std::unique_ptr OSSProxyExecutor:: OSSProxyExecutor::OSSProxyExecutor( const std::string& json_path, - bool is_cpu, + const std::string& device_str, std::optional> custom_objs) { - if (is_cpu) { - device_ = std::make_unique(c10::DeviceType::CPU); - } else { - int device_idx = -1; - device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); - } - + // CUDA device must have an index as a kernel may require + // an explicit device index. e.g., merge_pooled_embeddings + c10::Device normalized_device = normalize_device(c10::Device(device_str)); + device_ = std::make_unique(normalized_device); // If custom_objs is provided, use it instead of loading from // custom_objs_config.json If custom_objs is not provided, try to load from // custom_objs_config.json @@ -617,7 +626,7 @@ OSSProxyExecutor::OSSProxyExecutor( for (auto& [customObjName, file_name] : custom_objs_json.items()) { std::string customObjPath = folder_path + k_separator + file_name.get(); - LOG(INFO) << "Loading custom object to FbProxyExecutor from: " + LOG(INFO) << "Loading custom object to OSSProxyExecutor from: " << customObjPath; std::ifstream custom_obj_file(customObjPath, std::ios::binary); diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h index d20ef2e52186..551c89a3b793 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -12,26 +12,11 @@ namespace torch::aot_inductor { -enum class DynamicArgType : int { - TensorType = 0, - ListTensorType = 1, - ListOptionalTensorType = 2, - IntType = 3, - ListIntType = 4, - NoneType = 5, -}; - inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) { os << static_cast(arg_type); return os; } -inline bool isTensorType(DynamicArgType arg_type) { - return arg_type == DynamicArgType::TensorType || - arg_type == DynamicArgType::ListTensorType || - arg_type == DynamicArgType::ListOptionalTensorType; -} - struct OSSDynamicArg { OSSDynamicArg( int arg_index, @@ -118,7 +103,7 @@ class OSSProxyExecutor : public ProxyExecutor { public: explicit OSSProxyExecutor( const std::string& json_path, - bool is_cpu, + const std::string& device_str, std::optional> custom_objs = std::nullopt); diff --git a/torch/csrc/inductor/aoti_torch/proxy_executor.h b/torch/csrc/inductor/aoti_torch/proxy_executor.h index 6943bca5df49..5ce5d0d4f69c 100644 --- a/torch/csrc/inductor/aoti_torch/proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/proxy_executor.h @@ -6,6 +6,21 @@ namespace torch::aot_inductor { +enum DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, + NoneType = 5, +}; + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + class ProxyExecutor { public: ProxyExecutor() = default; diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index 153ee9e0ddbe..9d1bb914db5c 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -372,7 +372,7 @@ AOTITorchError aoti_torch_cpu__qlinear_pointwise_binary_tensor( }); } -AOTITorchError aoti_torch_cpu__qconv2d_pointwise_tensor( +AOTITorchError aoti_torch_cpu__qconv_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, AtenTensorHandle act_zero_point, diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h index 3f77347f5274..2b59855cbc6e 100644 --- a/torch/csrc/inductor/cpp_wrapper/common.h +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -3,16 +3,33 @@ #include #include #include +#include #include #define PYBIND11_SIMPLE_GIL_MANAGEMENT #include -namespace py = pybind11; + +// Include some often-used cpp_wrapper headers, for precompiling. +#include +#include +#include +#include +#include + +namespace py = pybind11; // NOLINT(misc-unused-alias-decls) class RAIIPyObject { public: - RAIIPyObject() : obj_(nullptr) {} - RAIIPyObject(PyObject* obj) : obj_(obj) {} + RAIIPyObject() = default; + // steals a reference to a PyObject + RAIIPyObject(PyObject* obj) : obj_{obj} {} + RAIIPyObject(const RAIIPyObject& other) : obj_{other.obj_} { + Py_XINCREF(obj_); + } + RAIIPyObject(RAIIPyObject&& other) noexcept { + // refcount doesn't change, and obj_ is currently nullptr + std::swap(obj_, other.obj_); + } ~RAIIPyObject() { Py_XDECREF(obj_); } @@ -24,6 +41,16 @@ class RAIIPyObject { } return *this; } + RAIIPyObject& operator=(RAIIPyObject&& other) noexcept { + // refcount to the current object decreases, but refcount to other.obj_ is + // the same + Py_XDECREF(obj_); + obj_ = std::exchange(other.obj_, nullptr); + return *this; + } + operator bool() const noexcept { + return obj_; + } operator PyObject*() { return obj_; } @@ -32,7 +59,7 @@ class RAIIPyObject { } private: - PyObject* obj_; + PyObject* obj_{nullptr}; }; #include diff --git a/torch/csrc/jit/codegen/fuser/arg_spec.h b/torch/csrc/jit/codegen/fuser/arg_spec.h index 7239e0391b8f..923aa324aa7a 100644 --- a/torch/csrc/jit/codegen/fuser/arg_spec.h +++ b/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -16,7 +16,6 @@ namespace torch::jit::fuser { // Note: the device to run on is included in the arg spec because kernels // are compiled per-device. struct TORCH_API ArgSpec { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ArgSpec(at::TensorList inputs, const int _device) : descs_{c10::fmap(inputs)}, hash_code_{c10::get_hash(_device, inputs.size(), descs_)}, diff --git a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp index d07e1fd2309e..47454c6eca25 100644 --- a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp +++ b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp @@ -31,7 +31,7 @@ dnnl::engine& Engine::getEngine() { static dnnl::graph::allocator alloc{ pytorch_default_allocator, pytorch_default_deallocator}; static dnnl::engine cpu_engine = dnnl::graph::make_engine_with_allocator( - dnnl::engine::kind::cpu, /* device_id = */ 0, alloc); + dnnl::engine::kind::cpu, /* index = */ 0, alloc); return cpu_engine; } diff --git a/torch/csrc/jit/ir/attributes.h b/torch/csrc/jit/ir/attributes.h index fb2c44350d2d..f6e8f2148078 100644 --- a/torch/csrc/jit/ir/attributes.h +++ b/torch/csrc/jit/ir/attributes.h @@ -86,7 +86,6 @@ template struct VectorAttributeValue : public AttributeValue { using ConstructorType = std::vector; using ValueType = std::vector; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { @@ -144,7 +143,6 @@ struct TORCH_API GraphAttr : public AttributeValue { struct TORCH_API GraphsAttr : public AttributeValue { using ConstructorType = std::vector>; using ValueType = std::vector>; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GraphsAttr(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 44087074e891..fc780c26c3dd 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1490,7 +1490,6 @@ struct WithCurrentScope { ScopePtr prev_scope_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) inline Value::Value(Node* node_, size_t offset_) : node_(node_), offset_(offset_), @@ -1651,7 +1650,6 @@ struct TORCH_API OperatorSet { }; template -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct OperatorMap { // Type aliasing using OpMapType = typename std::pair, T>; @@ -1659,12 +1657,10 @@ struct OperatorMap { using MapType = std::unordered_map; OperatorMap() = default; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit OperatorMap( std::initializer_list, T>> init) { insert(init); } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit OperatorMap(std::initializer_list> init) { insert(init); } @@ -1760,7 +1756,6 @@ struct OperatorMap { }; template -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct FunctionSchemaMap { // Type aliasing using FuncSchemaMapType = typename std::pair; diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 6c1bfd0ec3ec..089a0df564a0 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -646,6 +647,9 @@ mobile::Module _load_for_mobile( std::optional device, ExtraFilesMap& extra_files, uint64_t module_load_options) { +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) + torch::initialize_torch_libraries(); +#endif auto observer = torch::observerConfig().getModuleObserver(); if (observer) { extra_files.insert(std::make_pair("model_path", filename)); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 5911064b22f2..5c46e936a4ec 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1936,6 +1936,13 @@ void initJITBindings(PyObject* module) { self.addArgumentValues(value_map); }); py::class_(m, "FunctionSchema") + .def(py::init< + std::string, + std::string, + std::vector, + std::vector, + bool, + bool>()) .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) .def_property_readonly( @@ -1993,6 +2000,13 @@ void initJITBindings(PyObject* module) { .def_property_readonly( "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); py::class_(m, "Argument") + .def(py::init< + std::string, + const TypePtr&, + std::optional, + std::optional, + bool, + std::optional>()) .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) .def_property_readonly( @@ -2032,6 +2046,7 @@ void initJITBindings(PyObject* module) { return self.kwarg_only(); }); py::class_(m, "_AliasInfo") + .def(py::init, std::set>()) .def_property_readonly( "is_write", [](AliasInfo& self) { return self.isWrite(); }) .def_property_readonly( diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 8295b9d6c378..d1039216de3e 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -43,7 +43,6 @@ struct ExecutionPlan { // They are only valid only right after you call getDebugState() and should // never be used again once another GraphExecutor function is called. -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct GraphExecutorState { const Graph* graph = nullptr; ExecutionPlan fallback; // XXX: members of this field are optional diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index e6a71dc0a0b9..6ae9f52a0cda 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -111,7 +111,6 @@ struct Suspend : public std::exception { return "Suspend"; } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit Suspend(c10::intrusive_ptr future_) : future(std::move(future_)) {} diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 905c69a47966..02e64d196151 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -18,9 +18,7 @@ TORCH_DECLARE_bool(torch_jit_enable_expanded_stacks); TORCH_DECLARE_bool(torch_jit_expanded_stacks_mangled); -namespace torch::jit { - -namespace interpreter { +namespace torch::jit::interpreter { template Ttarget safe_narrow_cast(Tsource v) { @@ -64,7 +62,7 @@ struct NodeSourceInfo { const char* func_name_{nullptr}; const char* file_name_{nullptr}; size_t line_{0}; - NodeSourceInfo() {} + NodeSourceInfo() = default; }; struct CodeImpl { @@ -1060,5 +1058,4 @@ struct MobileCodeImpl : CodeImpl { bool emit_promoted_ops_; }; -} // namespace interpreter -} // namespace torch::jit +} // namespace torch::jit::interpreter diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 2e609f18ecc0..bde3825f5ea3 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -60,7 +60,6 @@ const std::array kJitOnlyOperatorTags = { // the concrete operator nature. struct TORCH_API Operator { private: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct C10Operator final { c10::OperatorHandle handle_; Operation op_; @@ -69,7 +68,6 @@ struct TORCH_API Operator { std::string schema_string_; mutable std::optional alias_analysis_; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct JitOnlyOperator final { // The only valid transition for schema_ is from right->left, i.e. // when the schema gets parsed. diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index ec736d006be0..0e2a89544b56 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -38,7 +38,6 @@ #include #include #include -#include #ifdef FBCODE_CAFFE2 #include @@ -953,11 +952,11 @@ BlockRunner::BlockRunner( } } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) BlockRunner::BlockRunner(BlockRunner&&) noexcept = default; BlockRunner::~BlockRunner() = default; +// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) void BlockRunner::set_arg(const size_t idx, std::vector&& args) { DCHECK(idx < args.size()); Input(idx + first_input_is_self_) = std::move(args[idx]); diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 04a0862f9795..e8a3bdbc42ff 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -815,10 +815,8 @@ class TORCH_API BlockRunner { std::vector nodes_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API StaticNodeInfo { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) StaticNodeInfo( Node* n, ProcessedFunction* fn, @@ -873,6 +871,9 @@ class TORCH_API ProcessedNodeMetadata { // if the contained type (BlockRunner) is not copyable ProcessedNodeMetadata(const ProcessedNodeMetadata&) = delete; ProcessedNodeMetadata& operator=(const ProcessedNodeMetadata&) = delete; + ProcessedNodeMetadata(ProcessedNodeMetadata&&) = delete; + ProcessedNodeMetadata&& operator=(ProcessedNodeMetadata&&) = delete; + ~ProcessedNodeMetadata() = default; std::vector& block_runners() { return block_runners_; @@ -895,10 +896,8 @@ class TORCH_API ProcessedNodeMetadata { torch::jit::TaskLauncher* launcher_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API ProcessedNode { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ProcessedNode() = default; ProcessedNode(const StaticNodeInfo& other, IValue* values) @@ -917,6 +916,7 @@ class TORCH_API ProcessedNode { ProcessedNode(const ProcessedNode&) = delete; ProcessedNode& operator=(const ProcessedNode& other) = delete; ProcessedNode& operator=(ProcessedNode&&) = default; + ~ProcessedNode() = default; void run(); @@ -1025,10 +1025,10 @@ class TORCH_API ProcessedNode { [[nodiscard]] bool verify_inputs_dont_overlap_outputs(bool force_check) const; - Node* node_; - const ProcessedFunction* fn_; + Node* node_{nullptr}; + const ProcessedFunction* fn_{nullptr}; ProcessedNodeInputs inputs_; - uint16_t outputs_offset_; + uint16_t outputs_offset_{0}; bool overlap_detected_{false}; IValue* values_ = nullptr; // unowned // Metadata for ProcessedNode. diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 60fca2f87066..d5586a5b9cd7 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1344,7 +1344,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator { namespace { -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct ToArgs { std::optional dtype; c10::Layout layout; diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 3694cd194179..3cd75cedada7 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -394,6 +394,15 @@ struct OptionalArgCaster { } else if (py::isinstance(element)) { auto values = arg.cast>(); setValue(f, idx, values); + } else if (THPVariable_Check(element.ptr())) { + /* List of tensors, most often to overcome the limits of 32-args per + * kernel */ + auto tensorlist = py::cast>(arg); + std::vector tl_ptrs; + for (auto& t : tensorlist) { + tl_ptrs.push_back(at::native::mps::get_tensor_gpu_address(t)); + } + f.setArg(idx, tl_ptrs); } else { TORCH_CHECK(false, "Unexpected argument types"); } diff --git a/torch/csrc/profiler/kineto_client_interface.cpp b/torch/csrc/profiler/kineto_client_interface.cpp index fd145f4c4fa6..89c824cd578f 100644 --- a/torch/csrc/profiler/kineto_client_interface.cpp +++ b/torch/csrc/profiler/kineto_client_interface.cpp @@ -58,6 +58,20 @@ class LibKinetoClient : public libkineto::ClientInterface { (void)disableProfiler(); } + void start_memory_profile() override { + LOG(INFO) << "Starting on-demand memory profile"; + startMemoryProfile(); + } + + void stop_memory_profile() override { + LOG(INFO) << "Stopping on-demand memory profile"; + stopMemoryProfile(); + } + + void export_memory_profile(const std::string& path) override { + exportMemoryProfile(path); + } + private: // Temporarily disable shape collection until // we re-roll out the feature for on-demand cases diff --git a/torch/csrc/profiler/orchestration/python_tracer.cpp b/torch/csrc/profiler/orchestration/python_tracer.cpp index e570a69cb696..73bdf3ccb017 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.cpp +++ b/torch/csrc/profiler/orchestration/python_tracer.cpp @@ -3,6 +3,7 @@ namespace torch::profiler::impl::python_tracer { namespace { MakeFn make_fn; +MakeMemoryFn memory_make_fn; struct NoOpPythonTracer : public PythonTracerBase { NoOpPythonTracer() = default; @@ -17,6 +18,15 @@ struct NoOpPythonTracer : public PythonTracerBase { return {}; } }; + +struct NoOpMemoryPythonTracer : public PythonMemoryTracerBase { + NoOpMemoryPythonTracer() = default; + ~NoOpMemoryPythonTracer() override = default; + void start() override {} + void stop() override {} + void export_memory_history(const std::string path) override {} +}; + } // namespace void registerTracer(MakeFn make_tracer) { @@ -29,4 +39,15 @@ std::unique_ptr PythonTracerBase::make(RecordQueue* queue) { } return make_fn(queue); } + +void registerMemoryTracer(MakeMemoryFn make_memory_tracer) { + memory_make_fn = make_memory_tracer; +} + +std::unique_ptr PythonMemoryTracerBase::make() { + if (memory_make_fn == nullptr) { + return std::make_unique(); + } + return memory_make_fn(); +} } // namespace torch::profiler::impl::python_tracer diff --git a/torch/csrc/profiler/orchestration/python_tracer.h b/torch/csrc/profiler/orchestration/python_tracer.h index 580bf523e7f5..725c6d8a5c95 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.h +++ b/torch/csrc/profiler/orchestration/python_tracer.h @@ -56,5 +56,21 @@ struct TORCH_API PythonTracerBase { using MakeFn = std::unique_ptr (*)(RecordQueue*); TORCH_API void registerTracer(MakeFn make_tracer); + +/** + * Memory Tracer Implementation + */ +struct TORCH_API PythonMemoryTracerBase { + static std::unique_ptr make(); + virtual ~PythonMemoryTracerBase() = default; + + virtual void start() = 0; + virtual void stop() = 0; + virtual void export_memory_history(const std::string path) = 0; +}; + +using MakeMemoryFn = std::unique_ptr (*)(); +TORCH_API void registerMemoryTracer(MakeMemoryFn make_memory_tracer); + } // namespace python_tracer } // namespace torch::profiler::impl diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f348069b4fbb..8ba2f37d99b5 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>> +// checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>> // clang-format off #pragma once @@ -54,9 +54,9 @@ class ForwardRef { public: ForwardRef(): ptr_(std::make_unique()) {} - ForwardRef(ForwardRef&&) = default; + ForwardRef(ForwardRef&&); ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {} - ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(ForwardRef&&); ForwardRef& operator=(const ForwardRef& other) { ptr_ = std::make_unique(*other.ptr_); return *this; @@ -3216,6 +3216,7 @@ class ExternKernelNode { class ExternKernelNodes { private: std::vector nodes; + std::optional protocol = std::nullopt; public: @@ -3227,6 +3228,14 @@ class ExternKernelNodes { nodes = std::move(def); } + const std::optional& get_protocol() const { + return protocol; + } + + void set_protocol(std::optional def) { + protocol = std::move(def); + } + friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t); friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t); }; @@ -3315,11 +3324,13 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& n inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t) { nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; + nlohmann_json_j["protocol"] = nlohmann_json_t.protocol; } inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t) { ExternKernelNodes nlohmann_json_default_obj; nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); + nlohmann_json_t.protocol = nlohmann_json_j.value("protocol", nlohmann_json_default_obj.protocol); } inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) { @@ -3688,6 +3699,9 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; } // namespace _export } // namespace torch diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 43ef85ad8fce..9c73f9ca2b9e 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -135,6 +135,16 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return getPyObj().attr("guard_size_oblivious")(file, line).cast(); } + bool guard_or_false(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_false")(file, line).cast(); + } + + bool guard_or_true(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_true")(file, line).cast(); + } + int64_t int_() override { py::gil_scoped_acquire acquire; return getPyObj().attr("int_")().cast(); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 7e1c6c15b175..cb5c4d0919d2 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1210,8 +1210,7 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) + handle = _get_amdsmi_handler(device) # amdsmi_get_gpu_vram_usage returns mem usage in megabytes mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] mem_bytes = mem_mega_bytes * 1024 * 1024 @@ -1219,16 +1218,12 @@ def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) - handle = amdsmi.amdsmi_get_processor_handles()[device] + handle = _get_amdsmi_handler(device) return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"] def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) - handle = amdsmi.amdsmi_get_processor_handles()[device] + handle = _get_amdsmi_handler(device) return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index b7b71bdf4b2b..3bd508f4c2c9 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -15,6 +15,7 @@ FileSystemBase, FileSystemReader, FileSystemWriter, + SerializationFormat, ) @@ -90,6 +91,9 @@ def exists(self, path: Union[str, os.PathLike]) -> bool: def rm_file(self, path: Union[str, os.PathLike]) -> None: self.fs.rm(path) + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + return self.fs.ls(path) + # TODO: add the dcp.async_save mixin class FsspecWriter(FileSystemWriter): @@ -115,6 +119,7 @@ def __init__( per_thread_copy_ahead: int = 10_000_000, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, **kwargs, ) -> None: """ @@ -139,6 +144,7 @@ def __init__( per_thread_copy_ahead, overwrite=overwrite, _extensions=_extensions, + serialization_format=serialization_format, ) self.fs = FileSystem() self.path = self.fs.init_path(path, **kwargs) diff --git a/torch/distributed/checkpoint/_hf_storage.py b/torch/distributed/checkpoint/_hf_storage.py index 7b8f2d656e01..ef5f5bacb95b 100644 --- a/torch/distributed/checkpoint/_hf_storage.py +++ b/torch/distributed/checkpoint/_hf_storage.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import dataclasses import json +import os import queue from typing import Optional @@ -11,6 +12,7 @@ _FqnToFileMapping, _HuggingFaceLoadPlanner, ) +from torch.distributed.checkpoint.filesystem import SerializationFormat from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, Metadata, @@ -64,7 +66,11 @@ def __init__( if HfFileSystem.protocol not in fsspec.available_protocols(): fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) - super().__init__(path=path, token=token) + super().__init__( + path=path, + token=token, + serialization_format=SerializationFormat.SAFETENSORS, + ) self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping def prepare_local_plan(self, plan: SavePlan) -> SavePlan: @@ -99,7 +105,7 @@ def write_data( (self.fs.concat_path(self.path, file_name), file_name, write_items) ) - return super()._write_data(planner, file_queue, safe_tensors=True) + return super()._write_data(planner, file_queue) def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: metadata_to_write = {} @@ -201,15 +207,40 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: return fut def read_metadata(self) -> Metadata: - path = self.fs.concat_path(self.path, _metadata_fn) - with self.fs.create_stream(path, "r") as metadata_file: - metadata = json.load(metadata_file) + metadata_path = self.fs.concat_path(self.path, _metadata_fn) state_dict_metadata: dict[str, STORAGE_TYPES] = {} - for key in metadata["weight_map"].keys(): - state_dict_metadata[key] = BytesStorageMetadata() + storage_data: dict[str, str] = {} + + if not self.fs.exists(metadata_path): + # if metadata file doesn't exist, create it from the safetensors file + from safetensors.torch import safe_open # type: ignore[import-not-found] + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(os.path.basename(file)) + + if len(safetensors_files) != 1: + raise ValueError( + f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files" + ) + storage_data = {} + with safe_open(safetensors_files[0], framework="pt") as f: + for k in f.keys(): + state_dict_metadata[k] = BytesStorageMetadata() + storage_data[k] = safetensors_files[0] + else: + with self.fs.create_stream(metadata_path, "r") as metadata_file: + metadata = json.load(metadata_file) + + for key in metadata["weight_map"].keys(): + state_dict_metadata[key] = BytesStorageMetadata() + storage_data = metadata["weight_map"] + metadata = Metadata( - state_dict_metadata=state_dict_metadata, storage_data=metadata["weight_map"] + state_dict_metadata=state_dict_metadata, + storage_data=storage_data, ) if getattr(metadata, "storage_meta", None) is None: diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 89b82e7bc127..0c3db0416e90 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -13,6 +13,7 @@ from collections.abc import Generator, Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from io import UnsupportedOperation from pathlib import Path from typing import Any, Callable, cast, IO, Optional, Union @@ -49,7 +50,13 @@ from torch.futures import Future -__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"] +__all__ = [ + "FileSystemWriter", + "FileSystemReader", + "FileSystem", + "FileSystemBase", + "SerializationFormat", +] _metadata_fn: str = ".metadata" @@ -72,6 +79,11 @@ class _StoragePrefix: prefix: str +class SerializationFormat(Enum): + TORCH_SAVE = "torch_save" + SAFETENSORS = "safetensors" + + DEFAULT_SUFFIX = ".distcp" @@ -298,7 +310,7 @@ def _write_item( data: Union[io.BytesIO, torch.Tensor], write_item: WriteItem, storage_key: str, - safe_tensors: bool = False, + serialization_format: SerializationFormat, ) -> WriteResult: offset = stream.tell() @@ -312,12 +324,14 @@ def _write_item( else: assert isinstance(data, torch.Tensor) assert data.device == torch.device("cpu") - if not safe_tensors: + if serialization_format == SerializationFormat.TORCH_SAVE: torch.save(data, transform_to) transform_to.close() - if not safe_tensors or isinstance(data, io.BytesIO): + if serialization_format == SerializationFormat.TORCH_SAVE or isinstance( + data, io.BytesIO + ): length = stream.tell() - offset else: length = data.numel() * data.element_size() @@ -349,7 +363,7 @@ def _write_files_from_queue( inflight_threshhold: int, use_fsync: bool, thread_count: int, - safe_tensors: bool, + serialization_format: SerializationFormat, ) -> None: try: while True: @@ -397,7 +411,7 @@ def _write_files_from_queue( data, write_item, storage_key, - safe_tensors, + serialization_format, ) ) @@ -411,12 +425,12 @@ def _write_files_from_queue( tensor, write_item, storage_key, - safe_tensors, + serialization_format, ) ) tensor_dict[write_item.index.fqn] = tensor - if safe_tensors: + if serialization_format == SerializationFormat.SAFETENSORS: from safetensors.torch import save # type: ignore[import-not-found] stream.write(save(tensor_dict)) @@ -465,6 +479,9 @@ def exists(self, path: Union[str, os.PathLike]) -> bool: ... @abstractmethod def rm_file(self, path: Union[str, os.PathLike]) -> None: ... + @abstractmethod + def ls(self, path: Union[str, os.PathLike]) -> list[str]: ... + class FileSystem(FileSystemBase): @contextmanager @@ -525,6 +542,11 @@ def rm_file(self, path: Union[str, os.PathLike]) -> None: path = Path(path) path.unlink() + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + if not isinstance(path, Path): + path = Path(path) + return [str(p) for p in path.iterdir()] + class _FileSystemWriter(StorageWriter): """ @@ -549,6 +571,7 @@ def __init__( per_thread_copy_ahead: int = 10_000_000, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, *args: Any, **kwargs: Any, ) -> None: @@ -576,6 +599,7 @@ def __init__( self.save_id = _generate_uuid() self.overwrite = overwrite self.transforms = _StorageWriterTransforms(_extensions) + self.serialization_format = serialization_format def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: if checkpoint_id: @@ -638,7 +662,6 @@ def _write_data( self, planner: SavePlanner, file_queue: queue.Queue, - safe_tensors: bool = False, ) -> Future[list[WriteResult]]: result_queue: queue.Queue = queue.Queue() @@ -655,7 +678,7 @@ def _write_data( self.per_thread_copy_ahead, self.sync_files, self.thread_count, - safe_tensors, + self.serialization_format, ), ) t.start() @@ -670,7 +693,7 @@ def _write_data( inflight_threshhold=self.per_thread_copy_ahead, use_fsync=self.sync_files, thread_count=self.thread_count, - safe_tensors=safe_tensors, + serialization_format=self.serialization_format, ) for t in threads: @@ -892,6 +915,7 @@ def __init__( cache_staged_state_dict: bool = False, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, ) -> None: """ Initialize the writer pointing to `path`. @@ -919,6 +943,7 @@ def __init__( per_thread_copy_ahead=per_thread_copy_ahead, overwrite=overwrite, _extensions=_extensions, + serialization_format=serialization_format, ) BlockingAsyncStager.__init__( self, diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 0615721228b0..dd9c27f6542c 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -307,6 +307,16 @@ def broadcast( raise final_result return cast(T, final_result) + def barrier(self) -> None: + """ + Add a synchronization point across all processes when using distributed. + If torch.distributed is initialized, this function will invoke a barrier across the global process group. + If torch.distributed is not initialized, this function is a no-op. + """ + if not self.use_dist: + return + dist.barrier(group=self.group) + def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: if index.offset is None: diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 339afeffdc7f..668dbf49a0d0 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -2501,7 +2501,7 @@ class _CoalescingManager: def __init__(self) -> None: self.works: list[Work] = [] - def append(self, work: Work): + def append(self, work: Optional[Work] = None): if work: self.works.append(work) @@ -2514,7 +2514,7 @@ def wait(self): def _coalescing_manager( group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, - async_ops: Optional[bool] = False, + async_ops: bool = False, ): """ Context manager used to coalesce collectives or P2P operations when possible. @@ -2553,6 +2553,7 @@ def _coalescing_manager( group._start_coalescing(device) cm = _CoalescingManager() yield cm + work = None op_list = _world.pg_coalesce_state.pop(group) if op_list: # Collectives supporting "Fast Path" coalescing are captured. @@ -2566,6 +2567,7 @@ def _coalescing_manager( tensors = [op.tensor for op in op_list] all_reduce_opts = AllreduceCoalescedOptions() all_reduce_opts.reduceOp = not_none(op_list[0].redop) + all_reduce_opts.asyncOp = async_ops work = group.allreduce_coalesced(tensors, all_reduce_opts) elif op0 == all_gather_into_tensor: inputs = [] @@ -2573,6 +2575,8 @@ def _coalescing_manager( for op in op_list: inputs.append(op.tensor) outputs.append(not_none(op.dst_tensor)) + all_gather_opts = AllgatherOptions() + all_gather_opts.asyncOp = async_ops work = group.allgather_into_tensor_coalesced(outputs, inputs) elif op0 == reduce_scatter_tensor: inputs = [] @@ -2582,6 +2586,7 @@ def _coalescing_manager( outputs.append(not_none(op.dst_tensor)) reduce_opts = ReduceScatterOptions() reduce_opts.reduceOp = not_none(op_list[0].redop) + reduce_opts.asyncOp = async_ops work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) else: raise AssertionError( @@ -2594,9 +2599,12 @@ def _coalescing_manager( work = group._end_coalescing(device) if async_ops: - cm.append(work) # type: ignore[possibly-undefined] - else: - work.wait() # type: ignore[possibly-undefined] + cm.append(work) + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level class _TimeEstimator: @@ -2772,8 +2780,11 @@ def broadcast( work = group.broadcast([tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2853,6 +2864,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): opts = AllreduceOptions() opts.reduceOp = op + opts.asyncOp = async_op if group is None: group = _get_default_group() @@ -2869,8 +2881,11 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2929,13 +2944,17 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): opts = AllreduceCoalescedOptions() opts.reduceOp = op + opts.asyncOp = async_op group = group or _get_default_group() work = group.allreduce_coalesced(tensors, opts) if async_op: return work.get_future() - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2980,11 +2999,15 @@ def reduce( opts = ReduceOptions() opts.reduceOp = op opts.rootRank = group_dst + opts.asyncOp = async_op work = group.reduce([tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def _object_to_tensor(obj, device, group): @@ -3783,12 +3806,17 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) group = group or _get_default_group() - work = group.allgather([tensor_list], [tensor]) + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather([tensor_list], [tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -3891,8 +3919,11 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4002,12 +4033,17 @@ def all_gather_coalesced( ] group = group or _get_default_group() - work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts) if async_op: return work.get_future() - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def _validate_output_list_for_rank(my_rank, dst, gather_list): @@ -4093,12 +4129,16 @@ def gather( opts = GatherOptions() opts.rootRank = group_dst + opts.asyncOp = async_op work = group.gather(output_tensors, input_tensors, opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4199,8 +4239,11 @@ def scatter( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4232,14 +4275,18 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal opts = ReduceScatterOptions() opts.reduceOp = op + opts.asyncOp = async_op group = group or _get_default_group() work = group.reduce_scatter([output], [input_list], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4336,8 +4383,11 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @deprecated( @@ -4490,6 +4540,7 @@ def all_to_all_single( return opts = AllToAllOptions() + opts.asyncOp = async_op _check_single_tensor(output, "output") _check_single_tensor(input, "input") _ensure_all_tensors_same_dtype(output, input) @@ -4509,8 +4560,11 @@ def all_to_all_single( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4611,6 +4665,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False return opts = AllToAllOptions() + opts.asyncOp = async_op _check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list") _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) @@ -4627,8 +4682,11 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4659,6 +4717,7 @@ def barrier( opts = BarrierOptions() opts.device = torch.device(_get_object_coll_device(group)) + opts.asyncOp = async_op if device_ids is not None: if isinstance(device_ids, list): opts.device_ids = device_ids @@ -4672,8 +4731,11 @@ def barrier( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def monitored_barrier( diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index e149005ffc2c..c9c36654e882 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -95,17 +95,17 @@ def get_all_gather_streams( # See [Note: Overlapping all-gather copy-in and all-gather] class AllGatherState(NamedTuple): all_gather_result: AllGatherResult - event: torch.Event # all-gather copy-out + event: Optional[torch.Event] # all-gather copy-out class ReduceScatterState(NamedTuple): reduce_scatter_input: torch.Tensor - event: torch.Event # reduce-scatter event + event: Optional[torch.Event] # reduce-scatter event class AllReduceState(NamedTuple): all_reduce_input: torch.Tensor - event: torch.Event # all-reduce event + event: Optional[torch.Event] # all-reduce event class FSDPParamGroup: @@ -310,11 +310,11 @@ def wait_for_unshard(self): self._wait_all_gather_streams_on_event(all_gather_copy_out_event) self._all_gather_result = None # free unless saved in `all_gather_state` - def _wait_all_gather_streams_on_event(self, event: torch.Event): + def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): # Calling `unshard` before lazy init means streams are not initialized - if hasattr(self.comm_ctx, "all_gather_copy_in_stream"): + if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None: self.comm_ctx.all_gather_copy_in_stream.wait_event(event) - if hasattr(self.comm_ctx, "all_gather_stream"): + if hasattr(self.comm_ctx, "all_gather_stream") and event is not None: self.comm_ctx.all_gather_stream.wait_event(event) def reshard(self): @@ -414,11 +414,14 @@ def post_backward(self, *unused: Any): if len(fsdp_params_with_grad) == 0: return with record_function(self._with_fqn("FSDP::post_backward_reduce")): - if self.comm_ctx.reduce_scatter_state is not None: + if ( + self.comm_ctx.reduce_scatter_state is not None + and self.comm_ctx.reduce_scatter_state.event is not None + ): self.device_handle.current_stream().wait_event( self.comm_ctx.reduce_scatter_state.event ) - self.comm_ctx.reduce_scatter_state = None + self.comm_ctx.reduce_scatter_state = None all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None all_reduce_stream: torch.cuda.Stream if all_reduce_pg is None and self._all_reduce_hook_stream is not None: @@ -458,7 +461,8 @@ def post_backward(self, *unused: Any): reduce_scatter_input, reduce_scatter_event ) if all_reduce_input is not None: - assert all_reduce_event is not None + if self.device.type != "cpu": + assert all_reduce_event is not None self._all_reduce_state = AllReduceState( all_reduce_input, all_reduce_event ) @@ -484,9 +488,12 @@ def _wait_for_post_backward(self): if self._post_reduce_event is not None: self.device_handle.current_stream().wait_event(self._post_reduce_event) self._post_reduce_event = None - if self._all_reduce_state is not None: + if ( + self._all_reduce_state is not None + and self._all_reduce_state.event is not None + ): self.device_handle.current_stream().wait_event(self._all_reduce_state.event) - self._all_reduce_state = None + self._all_reduce_state = None def _backward_prefetch(self) -> None: if self._training_state == TrainingState.PRE_BACKWARD: diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index 5d11f0359f1f..601a77185e40 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -59,7 +59,11 @@ def disable_if_config_true(func): @functools.wraps(func) def fsdp_hook_wrapper(*args, **kwargs): if torch._dynamo.config.skip_fsdp_hooks: - return torch._dynamo.disable(func, recursive=True)(*args, **kwargs) + return torch._dynamo.disable( + func, + recursive=True, + reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set", + )(*args, **kwargs) else: return func(*args, **kwargs) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 416965e80ba3..4e1b9676d7ca 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1003,9 +1003,7 @@ def _trace_with_export( logger.info("Tracing model ...") try: ep = torch.export.export_for_training( - mod, - example_args, - example_kwargs, + mod, example_args, example_kwargs, strict=True ) except Exception as e: raise RuntimeError( diff --git a/torch/distributed/tensor/_ops/__init__.py b/torch/distributed/tensor/_ops/__init__.py index dec4665b1c8b..7cfaa668a183 100644 --- a/torch/distributed/tensor/_ops/__init__.py +++ b/torch/distributed/tensor/_ops/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from ._conv_ops import * # noqa: F403 from ._embedding_ops import * # noqa: F403 -from ._experimental_ops import * # noqa: F403 from ._math_ops import * # noqa: F403 from ._matrix_ops import * # noqa: F403 from ._pointwise_ops import * # noqa: F403 diff --git a/torch/distributed/tensor/_ops/_experimental_ops.py b/torch/distributed/tensor/_ops/_experimental_ops.py deleted file mode 100644 index 59e907dc5ba1..000000000000 --- a/torch/distributed/tensor/_ops/_experimental_ops.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# implement matrix related ops for distributed tensor - -import torch -from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) -from torch.distributed.tensor._ops.utils import register_op_strategy -from torch.distributed.tensor.placement_types import Replicate - - -aten = torch.ops.aten - - -@register_op_strategy(aten.slice_backward.default) -def slice_backward_rules(op_schema: OpSchema) -> StrategyType: - """ - slice_backward is a new_zeros + slice_scatter, we only allow replication - on the input/output for now since new_zeros would produce replication - """ - mesh = op_schema.get_mesh_from_args(validate=False) - replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replicate_spec)]) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index d100aaea4ad7..9b73f36d855f 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -20,6 +20,7 @@ from torch.distributed.tensor._ops._embedding_ops import _MaskPartial from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, + generate_redistribute_costs, is_tensor_dim_sharded, is_tensor_evenly_shardable, is_tensor_partial, @@ -237,7 +238,7 @@ def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: @register_op_strategy(aten.select.int, schema_info=RuntimeSchemaInfo(1)) -def gen_select_strategy(op_schema: OpSchema) -> StrategyType: +def select_int_strategy(op_schema: OpSchema) -> StrategyType: """ In this select op, first determine the input specs, then determine the output specs. - Input specs: @@ -299,6 +300,38 @@ def gen_select_strategy(op_schema: OpSchema) -> StrategyType: return select_strategy +@register_op_strategy( + aten.select_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def select_backward_strategy(op_schema: OpSchema) -> OpStrategy: + # func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + assert isinstance(dim, int) + output_strategies: list[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + input_spec = placement_strategy.output_spec + output_spec_placements: list[Placement] = [] + for placement in input_spec.placements: + if isinstance(placement, Shard): + shard_dim = placement.dim + if shard_dim >= dim: + # NOTE: shard_dim is guaranteed to exist because + # grad_input has one more dim than grad_output + output_spec_placements.append(Shard(shard_dim + 1)) + else: + output_spec_placements.append(Shard(shard_dim)) + else: + output_spec_placements.append(placement) + output_specs = DTensorSpec(input_spec.mesh, tuple(output_spec_placements)) + output_strategies.append( + PlacementStrategy(output_specs=output_specs, input_specs=(input_spec,)) + ) + return OpStrategy(output_strategies) + + @register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: """Forward all shardings except the slice dimension.""" @@ -349,6 +382,33 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: return slice_strategy +@register_op_strategy( + aten.slice_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: + # func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + output_strategies: list[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + output_spec = placement_strategy.output_spec + new_placements: list[Placement] = [] + for placement in output_spec.placements: + # Redistribute to replicate only if the dim is sharded and matches the slice dim + if isinstance(placement, Shard) and placement.dim == dim: + new_placements.append(Replicate()) + else: + new_placements.append(placement) + new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements)) + redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)] + placement_strategy.redistribute_cost = redistribute_cost + new_strategy = PlacementStrategy(output_specs=new_spec) + output_strategies.append(new_strategy) + return OpStrategy(output_strategies) + + def unshard_tensor_dim( placements: Sequence[Placement], dim: int ) -> tuple[Placement, ...]: diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index c5bb22a92b7d..0e186da56152 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -77,6 +77,8 @@ def __init__(self) -> None: aten.reshape.default: 1, aten.view.default: 1, aten._unsafe_view.default: 1, + aten.select_backward.default: 1, + aten.slice_backward.default: 1, } def register_sharding_prop_rule( diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index 11bdb4ec2ef2..3102b84c11d1 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -21,12 +21,10 @@ ) -aten = ( - torch.ops.aten -) # pyre-ignore[5]: Globally accessible variable `aten` has no type specified. +aten = torch.ops.aten -class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ +class LocalShardsWrapper(torch.Tensor): """ A wrapper class to hold local shards of a DTensor. This class is used largely for checkpointing purposes and implicity subtypes @@ -41,18 +39,39 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new def __new__( cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]] ) -> "LocalShardsWrapper": - assert len(local_shards) > 0 - assert len(local_shards) == len(local_offsets) assert all( tensor.device == local_shards[0].device for tensor in local_shards[1:] ) + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + # we calculate the total tensor size by "concat" on second tensor dimension cat_tensor_shape = list(local_shards[0].size()) - if len(local_shards) > 1: # column-wise sharding + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding for shard in local_shards[1:]: cat_tensor_shape[1] += shard.size()[1] + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) wrapper_shape = torch.Size(cat_tensor_shape) chunks_meta = [ @@ -78,9 +97,7 @@ def __new__( # necessary for ops dispatching from this subclass to its local shards @classmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] kwargs = kwargs or {} dispatcher = { @@ -91,21 +108,18 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): aten.equal.default: cls.handle_equal, aten.detach.default: cls.handle_detach, aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, } if func in dispatcher: - return dispatcher[func]( - args, kwargs - ) # pyre-ignore [29] - `Variable[_VT]` is not a function. + return dispatcher[func](args, kwargs) else: raise NotImplementedError( f"{func} is not supported for LocalShardsWrapper!" ) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_all_gather_into_tensor(args, kwargs): + def handle_all_gather_into_tensor(args, kwargs) -> torch.Tensor: dim = args[0].local_sizes()[0][1] cat_tensor = torch.cat( [t.view(-1) for t in args[0].local_shards()], dim=0 @@ -115,15 +129,11 @@ def handle_all_gather_into_tensor(args, kwargs): ) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_wait_tensor(args, kwargs): + def handle_wait_tensor(args, kwargs) -> torch.Tensor: return torch.ops._c10d_functional.wait_tensor(args[0]) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_to_copy(args, kwargs): + def handle_to_copy(args, kwargs) -> torch.Tensor: res_shards_list = [ aten._to_copy.default(shard, *args[1:], **kwargs) for shard in args[0].local_shards() @@ -131,20 +141,41 @@ def handle_to_copy(args, kwargs): return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_view(args, kwargs): - # TODO, do we need to change the shape of associated offsets? - res_shards_list = [ - aten.view.default(shard, args[1], **kwargs) - for shard in args[0].local_shards() - ] + def handle_view(args, kwargs) -> "LocalShardsWrapper": + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_equal(args, kwargs): + def handle_equal(args, kwargs) -> bool: """ LocalShardsWrapper equal impl also checks for equality of storage metadata and the order of shards @@ -161,9 +192,7 @@ def handle_equal(args, kwargs): return True @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_detach(args, kwargs): + def handle_detach(args, kwargs) -> "LocalShardsWrapper": self_ls = args[0] deatched_local_shards = [ aten.detach.default(shard) for shard in self_ls.local_shards() @@ -173,9 +202,7 @@ def handle_detach(args, kwargs): return self_ls @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_clone(args, kwargs): + def handle_clone(args, kwargs) -> "LocalShardsWrapper": self_ls = args[0] desired_memory_format = kwargs.get("memory_format", None) if desired_memory_format and desired_memory_format != torch.preserve_format: @@ -188,19 +215,27 @@ def handle_clone(args, kwargs): ] return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + @staticmethod + def handle_new_empty(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + @property def device(self) -> torch._C.device: # type: ignore[override] - return self._local_shards[0].device + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) @property def is_meta(self) -> bool: # type: ignore[override] - return self._local_shards[0].is_meta + return self._local_shards[0].is_meta if self._local_shards else True - # pyre-ignore[14] def is_pinned(self) -> bool: # type: ignore[override] return self._storage_meta.properties.pin_memory - # pyre-ignore[14] def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": self._storage_meta.properties.requires_grad = requires_grad [shard.requires_grad_(requires_grad) for shard in self._local_shards] @@ -233,7 +268,7 @@ def local_offsets(self) -> list[torch.Size]: @property def local_chunks(self) -> list[ChunkStorageMetadata]: """ - Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the + Returns a :class:`list[ChunkStorageMetadata]` object corresponding to the metadata for each tensor shard """ return self._storage_meta.chunks @@ -245,9 +280,14 @@ def storage_metadata(self) -> TensorStorageMetadata: """ return self._storage_meta - def __create_write_items__( - self, fqn: str, object: Any - ) -> list[WriteItem]: # pyre-ignore[2] + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__(self, fqn: str, object: Any) -> list[WriteItem]: """ For compatibility with DCP, we support creation of WriteItems such that they can be saved properly. @@ -293,6 +333,12 @@ def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: if chunk.offsets == index.offset: return shard + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + raise ValueError( f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" ) @@ -303,12 +349,9 @@ def _get_tensor_size_bytes(self) -> int: object_size += shard.nelement() * shard.element_size() return object_size - # pyre-fixme[3]: Return type must be annotated. - def __hash__(self): + def __hash__(self) -> int: return id(self) - # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. - # pyre-fixme[3]: Return type must be annotated. def __repr__(self) -> str: # type: ignore[override] return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 61705610f08f..34b000a34910 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Sequence from typing import cast @@ -15,6 +16,59 @@ ) +def _explicit_order_placements( + mesh_shape: ShapeType, placements: Sequence[Placement] +) -> Sequence[tuple[int, Placement]]: + """ + Replace Strided Shards with regular shards in an adjusted order. + + Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. + + ex. + [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> + [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] + + """ + if not len(placements) == len(mesh_shape): + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." + ) + ordered = [] + deferred_strided_placements = defaultdict(list) + strided_part_ended_for_dim = set() + for mesh_dim, p in enumerate(placements): + if isinstance(p, _StridedShard): + # validate the stride is the correct multiple of the meshdim and the earlier shard + deferred_strided_placements[p.dim].append((mesh_dim, p)) + + else: + ordered.append((mesh_dim, p)) + if isinstance(p, Shard): + if p.dim in strided_part_ended_for_dim: + raise NotImplementedError( + f"Strided sharding does not allow Shard() to appear after " + f"the strided part has ended. {p} at mesh dim {mesh_dim} in " + f"{placements} violates this assumption." + ) + + if p.dim in deferred_strided_placements: + strided_part_ended_for_dim.add(p.dim) + strided_placements = deferred_strided_placements.pop(p.dim) + aggregate_size = mesh_shape[mesh_dim] + while len(strided_placements) > 0: + strided_mesh_dim, strided = strided_placements.pop() + if not strided.split_factor == aggregate_size: + raise RuntimeError( + f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" + f" == aggregate mesh size ({aggregate_size})" + ) + aggregate_size *= mesh_shape[strided_mesh_dim] + ordered.append((strided_mesh_dim, Shard(p.dim))) + + return ordered + + def compute_local_shape_and_global_offset( global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] ) -> tuple[tuple[int, ...], tuple[int, ...]]: diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index 9fe378c51b0d..5e4881de4387 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -5,6 +5,7 @@ ColwiseParallel, ParallelStyle, PrepareModuleInput, + PrepareModuleInputOutput, PrepareModuleOutput, RowwiseParallel, SequenceParallel, @@ -15,6 +16,7 @@ "ColwiseParallel", "ParallelStyle", "PrepareModuleInput", + "PrepareModuleInputOutput", "PrepareModuleOutput", "RowwiseParallel", "SequenceParallel", diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index e5ce3371ff96..3580a924d183 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -23,6 +23,7 @@ "SequenceParallel", "ColwiseParallel", "PrepareModuleInput", + "PrepareModuleInputOutput", "PrepareModuleOutput", ] @@ -698,3 +699,114 @@ def __repr__(self) -> str: tmpstr += f"use_local_output={self.use_local_output}" tmpstr += ")" return tmpstr + + +class PrepareModuleInputOutput(ParallelStyle): + """ + Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module + to DTensors at runtime according to ``input_layouts`` (and output_layouts, respectively), and perform layout redistribution + according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of + :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. + input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_input (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor + >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated + >>> # as Replicated DTensor and then redistributed to Sharded DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInputOutput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...), + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0), + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, tuple[Optional[Placement]]] + ] = None, + input_kwarg_layouts: Optional[dict[str, Placement]] = None, + desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + use_local_input: bool = False, + output_layouts: Union[Placement, tuple[Placement]], + desired_output_layouts: Union[Placement, tuple[Placement]], + use_local_output: bool = True, + ): + self.prepare_module_input = PrepareModuleInput( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_input, + ) + self.prepare_module_output = PrepareModuleOutput( + output_layouts=output_layouts, + desired_output_layouts=desired_output_layouts, + use_local_output=use_local_output, + ) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + self.prepare_module_input._apply(module, device_mesh) + self.prepare_module_output._apply(module, device_mesh) + + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, " + tmpstr += ( + f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, " + ) + tmpstr += ( + f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, " + ) + tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, " + tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, " + tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, " + tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}" + tmpstr += ")" + return tmpstr diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index ceb9f170fd3e..7b3302359e03 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -472,6 +472,9 @@ def _split_tensor( f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" ) + # num_chunks represents the size of this StridedShard mesh dim, while self.split_factor + # represents the aggregate num chunks for other shardings applied logically earlier than this strided shard. + # (e.g. in FSDP+TP case, num_chunks is size(dp dim), split_factor is size(tp dim)) total_split = num_chunks * self.split_factor assert tensor.size(self.dim) % total_split == 0, ( "_StridedShard currently only allows even sharding but got tensor size" diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 105038641bcc..659f9a20b10e 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -10,7 +12,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Bernoulli"] @@ -41,7 +43,12 @@ class Bernoulli(ExponentialFamily): has_enumerate_support = True _mean_carrier_measure = 0 - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -50,6 +57,7 @@ def __init__(self, probs=None, logits=None, validate_args=None): is_scalar = isinstance(probs, _Number) (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index e030b648a88e..e06a28ca5aa4 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -36,7 +38,12 @@ class Beta(ExponentialFamily): support = constraints.unit_interval has_rsample = True - def __init__(self, concentration1, concentration0, validate_args=None): + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: if isinstance(concentration1, _Number) and isinstance(concentration0, _Number): concentration1_concentration0 = torch.tensor( [float(concentration1), float(concentration0)] diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 6cbfae150844..90461784c06d 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -50,7 +52,13 @@ class Binomial(Distribution): } has_enumerate_support = True - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: Union[Tensor, int] = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -62,6 +70,7 @@ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: + assert logits is not None # helps mypy ( self.total_count, self.logits, diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 715429c66552..1c8fed2636ad 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -51,7 +53,12 @@ class Categorical(Distribution): arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -61,6 +68,7 @@ def __init__(self, probs=None, logits=None, validate_args=None): raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs / probs.sum(-1, keepdim=True) else: + assert logits is not None # helps mypy if logits.dim() < 1: raise ValueError("`logits` parameter must be at least one-dimensional.") # Normalize diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 582c08ebb858..84c1d34bda79 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, nan, Tensor @@ -34,7 +35,12 @@ class Cauchy(Distribution): support = constraints.real has_rsample = True - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index f175bc44f69e..fa23115fc035 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor from torch.distributions import constraints from torch.distributions.gamma import Gamma @@ -25,7 +27,11 @@ class Chi2(Gamma): arg_constraints = {"df": constraints.positive} - def __init__(self, df, validate_args=None): + def __init__( + self, + df: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: super().__init__(0.5 * df, 0.5, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index b1e8eddfb0ec..14d0d6a9c177 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -13,7 +14,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number, _size +from torch.types import _Number, _size, Number __all__ = ["ContinuousBernoulli"] @@ -52,7 +53,11 @@ class ContinuousBernoulli(ExponentialFamily): has_rsample = True def __init__( - self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + lims: tuple[float, float] = (0.499, 0.501), + validate_args: Optional[bool] = None, ) -> None: if (probs is None) == (logits is None): raise ValueError( @@ -68,6 +73,7 @@ def __init__( raise ValueError("The parameter probs has invalid values") self.probs = clamp_probs(self.probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index f656a0582e89..414ad6efe47e 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.autograd import Function @@ -54,7 +56,11 @@ class Dirichlet(ExponentialFamily): support = constraints.simplex has_rsample = True - def __init__(self, concentration, validate_args=None): + def __init__( + self, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: if concentration.dim() < 1: raise ValueError( "`concentration` parameter must be at least one-dimensional." diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 75ea50d24860..b2895cb3b0d7 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -44,7 +44,7 @@ def __init__( batch_shape: torch.Size = torch.Size(), event_shape: torch.Size = torch.Size(), validate_args: Optional[bool] = None, - ): + ) -> None: self._batch_shape = batch_shape self._event_shape = event_shape if validate_args is not None: diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index 8ca2636e1f52..d15cb1f7a258 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -46,7 +48,11 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return self.rate.pow(-2) - def __init__(self, rate, validate_args=None): + def __init__( + self, + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: (self.rate,) = broadcast_all(rate) batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size() super().__init__(batch_shape, validate_args=validate_args) diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 053686c6de07..4755bd0d8bde 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -31,7 +33,12 @@ class FisherSnedecor(Distribution): support = constraints.positive has_rsample = True - def __init__(self, df1, df2, validate_args=None): + def __init__( + self, + df1: Union[Tensor, float], + df2: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.df1, self.df2 = broadcast_all(df1, df2) self._gamma1 = Gamma(self.df1 * 0.5, self.df1) self._gamma2 = Gamma(self.df2 * 0.5, self.df2) diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 5e0fe3fc7823..9df91ebee640 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -52,7 +54,12 @@ def mode(self) -> Tensor: def variance(self) -> Tensor: return self.concentration / self.rate.pow(2) - def __init__(self, concentration, rate, validate_args=None): + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.concentration, self.rate = broadcast_all(concentration, rate) if isinstance(concentration, _Number) and isinstance(rate, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index b8b05142db5b..b5ceac39e94e 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -10,7 +12,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Geometric"] @@ -45,7 +47,12 @@ class Geometric(Distribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.nonnegative_integer - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -53,11 +60,13 @@ def __init__(self, probs=None, logits=None, validate_args=None): if probs is not None: (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy (self.logits,) = broadcast_all(logits) probs_or_logits = probs if probs is not None else logits if isinstance(probs_or_logits, _Number): batch_shape = torch.Size() else: + assert probs_or_logits is not None # helps mypy batch_shape = probs_or_logits.size() super().__init__(batch_shape, validate_args=validate_args) if self._validate_args and probs is not None: diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 623cc7edbda6..6d097c9324e2 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -33,7 +34,12 @@ class Gumbel(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) finfo = torch.finfo(self.loc.dtype) if isinstance(loc, _Number) and isinstance(scale, _Number): diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index da17c40da2ed..572ae080ac3e 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, Tensor @@ -33,8 +34,13 @@ class HalfCauchy(TransformedDistribution): arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True + base_dist: Cauchy - def __init__(self, scale, validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Cauchy(0, scale, validate_args=False) super().__init__(base_dist, AbsTransform(), validate_args=validate_args) diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 5850f883e908..21e1b9d2c506 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, Tensor @@ -33,8 +34,13 @@ class HalfNormal(TransformedDistribution): arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True + base_dist: Normal - def __init__(self, scale, validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(0, scale, validate_args=False) super().__init__(base_dist, AbsTransform(), validate_args=validate_args) diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 0442a4c1b483..b66406681bb8 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs +from typing import Generic, Optional, TypeVar import torch -from torch import Tensor +from torch import Size, Tensor from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _sum_rightmost @@ -11,7 +12,10 @@ __all__ = ["Independent"] -class Independent(Distribution): +D = TypeVar("D", bound=Distribution) + + +class Independent(Distribution, Generic[D]): r""" Reinterprets some of the batch dims of a distribution as event dims. @@ -42,17 +46,21 @@ class Independent(Distribution): """ arg_constraints: dict[str, constraints.Constraint] = {} + base_dist: D def __init__( - self, base_distribution, reinterpreted_batch_ndims, validate_args=None - ): + self, + base_distribution: D, + reinterpreted_batch_ndims: int, + validate_args: Optional[bool] = None, + ) -> None: if reinterpreted_batch_ndims > len(base_distribution.batch_shape): raise ValueError( "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}" ) - shape = base_distribution.batch_shape + base_distribution.event_shape - event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape) + shape: Size = base_distribution.batch_shape + base_distribution.event_shape + event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape) batch_shape = shape[: len(shape) - event_dim] event_shape = shape[len(shape) - event_dim :] self.base_dist = base_distribution diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index aaee976b7f17..de432a34434e 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -38,8 +40,14 @@ class InverseGamma(TransformedDistribution): } support = constraints.positive has_rsample = True - - def __init__(self, concentration, rate, validate_args=None): + base_dist: Gamma + + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Gamma(concentration, rate, validate_args=validate_args) neg_one = -base_dist.rate.new_ones(()) super().__init__( diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index d38efb631e86..53c09ab9870d 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -45,7 +47,12 @@ class Kumaraswamy(TransformedDistribution): support = constraints.unit_interval has_rsample = True - def __init__(self, concentration1, concentration0, validate_args=None): + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.concentration1, self.concentration0 = broadcast_all( concentration1, concentration0 ) diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 39ef9b1efdb7..0d50712fb26f 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -46,7 +48,12 @@ def variance(self) -> Tensor: def stddev(self) -> Tensor: return (2**0.5) * self.scale - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index a18f2ed9f52a..d2c29a9286de 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -9,8 +9,10 @@ """ import math +from typing import Optional, Union import torch +from torch import Tensor from torch.distributions import Beta, constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all @@ -61,7 +63,12 @@ class LKJCholesky(Distribution): arg_constraints = {"concentration": constraints.positive} support = constraints.corr_cholesky - def __init__(self, dim, concentration=1.0, validate_args=None): + def __init__( + self, + dim: int, + concentration: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: if dim < 2: raise ValueError( f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}." diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index a048f94286c8..2c6dbc6bf55c 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor from torch.distributions import constraints from torch.distributions.normal import Normal @@ -32,8 +34,14 @@ class LogNormal(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.positive has_rsample = True - - def __init__(self, loc, scale, validate_args=None): + base_dist: Normal + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(loc, scale, validate_args=validate_args) super().__init__(base_dist, ExpTransform(), validate_args=validate_args) diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index a8f7c099d1e8..729e3a67419f 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor -from torch.distributions import constraints +from torch.distributions import constraints, Independent from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import StickBreakingTransform @@ -36,8 +38,14 @@ class LogisticNormal(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.simplex has_rsample = True + base_dist: Independent[Normal] - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(loc, scale, validate_args=validate_args) if not base_dist.batch_shape: base_dist = base_dist.expand([1]) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index c6f739a595a3..968e4634ba62 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch from torch import Tensor @@ -93,7 +94,13 @@ class LowRankMultivariateNormal(Distribution): support = constraints.real_vector has_rsample = True - def __init__(self, loc, cov_factor, cov_diag, validate_args=None): + def __init__( + self, + loc: Tensor, + cov_factor: Tensor, + cov_diag: Tensor, + validate_args: Optional[bool] = None, + ) -> None: if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") event_shape = loc.shape[-1:] diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 1fc2c1052d03..79a7029e1d72 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional import torch from torch import Tensor @@ -59,7 +60,7 @@ def __init__( self, mixture_distribution: Categorical, component_distribution: Distribution, - validate_args=None, + validate_args: Optional[bool] = None, ) -> None: self._mixture_distribution = mixture_distribution self._component_distribution = component_distribution diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 85a227f5c403..41d8ded53fd6 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import inf, Tensor from torch.distributions import Categorical, constraints @@ -59,7 +61,13 @@ def mean(self) -> Tensor: def variance(self) -> Tensor: return self.total_count * self.probs * (1 - self.probs) - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: int = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if not isinstance(total_count, int): raise NotImplementedError("inhomogeneous total_count is not supported") self.total_count = total_count diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 849ee4170015..c15a84815b06 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch from torch import Tensor @@ -133,12 +134,12 @@ class MultivariateNormal(Distribution): def __init__( self, - loc, - covariance_matrix=None, - precision_matrix=None, - scale_tril=None, - validate_args=None, - ): + loc: Tensor, + covariance_matrix: Optional[Tensor] = None, + precision_matrix: Optional[Tensor] = None, + scale_tril: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") if (covariance_matrix is not None) + (scale_tril is not None) + ( @@ -167,6 +168,7 @@ def __init__( ) self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: + assert precision_matrix is not None # helps mypy if precision_matrix.dim() < 2: raise ValueError( "precision_matrix must be at least two-dimensional, " diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index e5b0e128efe6..f28222f92f78 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch import torch.nn.functional as F from torch import Tensor @@ -38,7 +40,13 @@ class NegativeBinomial(Distribution): } support = constraints.nonnegative_integer - def __init__(self, total_count, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: Union[Tensor, float], + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -50,6 +58,7 @@ def __init__(self, total_count, probs=None, logits=None, validate_args=None): ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: + assert logits is not None # helps mypy ( self.total_count, self.logits, diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 86e30ba450f5..626358d14795 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -51,7 +52,12 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return self.stddev.pow(2) - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 7e0bc03c5aba..8edb6da0b8dd 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.distributions import constraints @@ -44,7 +46,12 @@ class OneHotCategorical(Distribution): support = constraints.one_hot has_enumerate_support = True - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: self._categorical = Categorical(probs, logits) batch_shape = self._categorical.batch_shape event_shape = self._categorical.param_shape[-1:] diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 2cc1e298ba25..bbca7e0cba35 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from torch import Tensor from torch.distributions import constraints @@ -31,7 +31,10 @@ class Pareto(TransformedDistribution): arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} def __init__( - self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None + self, + scale: Union[Tensor, float], + alpha: Union[Tensor, float], + validate_args: Optional[bool] = None, ) -> None: self.scale, self.alpha = broadcast_all(scale, alpha) base_dist = Exponential(self.alpha, validate_args=validate_args) diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index c3b4bacc54cb..d3fb4446baf4 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -1,10 +1,12 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Poisson"] @@ -45,7 +47,11 @@ def mode(self) -> Tensor: def variance(self) -> Tensor: return self.rate - def __init__(self, rate, validate_args=None): + def __init__( + self, + rate: Union[Tensor, Number], + validate_args: Optional[bool] = None, + ) -> None: (self.rate,) = broadcast_all(rate) if isinstance(rate, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 4c1549660313..16ad4219627e 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -12,7 +14,7 @@ logits_to_probs, probs_to_logits, ) -from torch.types import _Number, _size +from torch.types import _Number, _size, Number __all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] @@ -41,7 +43,13 @@ class LogitRelaxedBernoulli(Distribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: self.temperature = temperature if (probs is None) == (logits is None): raise ValueError( @@ -51,6 +59,7 @@ def __init__(self, temperature, probs=None, logits=None, validate_args=None): is_scalar = isinstance(probs, _Number) (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits @@ -131,8 +140,15 @@ class RelaxedBernoulli(TransformedDistribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval has_rsample = True - - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist: LogitRelaxedBernoulli + + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: base_dist = LogitRelaxedBernoulli(temperature, probs, logits) super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args) diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 97ae3ed1857b..47314be9e44a 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.distributions import constraints @@ -42,7 +44,13 @@ class ExpRelaxedCategorical(Distribution): ) # The true support is actually a submanifold of this. has_rsample = True - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: self._categorical = Categorical(probs, logits) self.temperature = temperature batch_shape = self._categorical.batch_shape @@ -121,8 +129,15 @@ class RelaxedOneHotCategorical(TransformedDistribution): arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.simplex has_rsample = True - - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist: ExpRelaxedCategorical + + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: base_dist = ExpRelaxedCategorical( temperature, probs, logits, validate_args=validate_args ) diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index e141939b2745..d98554f413c0 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, nan, Tensor @@ -60,7 +61,13 @@ def variance(self) -> Tensor: m[self.df <= 1] = nan return m - def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): + def __init__( + self, + df: Union[Tensor, float], + loc: Union[Tensor, float] = 0.0, + scale: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: self.df, self.loc, self.scale = broadcast_all(df, loc, scale) self._chi2 = Chi2(self.df) batch_shape = self.df.size() diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 02792ce9d309..d5fbff877413 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional, Union import torch from torch import Tensor @@ -49,7 +50,12 @@ class TransformedDistribution(Distribution): arg_constraints: dict[str, constraints.Constraint] = {} - def __init__(self, base_distribution, transforms, validate_args=None): + def __init__( + self, + base_distribution: Distribution, + transforms: Union[Transform, list[Transform]], + validate_args: Optional[bool] = None, + ) -> None: if isinstance(transforms, Transform): self.transforms = [ transforms, diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 8958f1a63c87..a033ce14408b 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -3,11 +3,14 @@ import math import operator import weakref -from typing import Optional +from collections.abc import Sequence +from typing import Optional, Union import torch import torch.nn.functional as F +from torch import Tensor from torch.distributions import constraints +from torch.distributions.distribution import Distribution from torch.distributions.utils import ( _sum_rightmost, broadcast_all, @@ -92,7 +95,7 @@ class Transform: domain: constraints.Constraint codomain: constraints.Constraint - def __init__(self, cache_size=0): + def __init__(self, cache_size: int = 0) -> None: self._cache_size = cache_size self._inv: Optional[weakref.ReferenceType[Transform]] = None if cache_size == 0: @@ -218,7 +221,7 @@ class _InverseTransform(Transform): This class is private; please instead use the ``Transform.inv`` property. """ - def __init__(self, transform: Transform): + def __init__(self, transform: Transform) -> None: super().__init__(cache_size=transform._cache_size) self._inv: Transform = transform # type: ignore[assignment] @@ -285,7 +288,7 @@ class ComposeTransform(Transform): the latest single value is cached. Only 0 and 1 are supported. """ - def __init__(self, parts: list[Transform], cache_size=0): + def __init__(self, parts: list[Transform], cache_size: int = 0) -> None: if cache_size: parts = [part.with_cache(cache_size) for part in parts] super().__init__(cache_size=cache_size) @@ -413,7 +416,12 @@ class IndependentTransform(Transform): dimensions to treat as dependent. """ - def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): + def __init__( + self, + base_transform: Transform, + reinterpreted_batch_ndims: int, + cache_size: int = 0, + ) -> None: super().__init__(cache_size=cache_size) self.base_transform = base_transform.with_cache(cache_size) self.reinterpreted_batch_ndims = reinterpreted_batch_ndims @@ -442,7 +450,7 @@ def bijective(self) -> bool: # type: ignore[override] return self.base_transform.bijective @property - def sign(self) -> int: # type: ignore[override] + def sign(self) -> int: return self.base_transform.sign def _call(self, x): @@ -486,7 +494,12 @@ class ReshapeTransform(Transform): bijective = True - def __init__(self, in_shape, out_shape, cache_size=0): + def __init__( + self, + in_shape: torch.Size, + out_shape: torch.Size, + cache_size: int = 0, + ) -> None: self.in_shape = torch.Size(in_shape) self.out_shape = torch.Size(out_shape) if self.in_shape.numel() != self.out_shape.numel(): @@ -571,7 +584,7 @@ class PowerTransform(Transform): codomain = constraints.positive bijective = True - def __init__(self, exponent, cache_size=0): + def __init__(self, exponent: Tensor, cache_size: int = 0) -> None: super().__init__(cache_size=cache_size) (self.exponent,) = broadcast_all(exponent) @@ -582,7 +595,7 @@ def with_cache(self, cache_size=1): @lazy_property def sign(self) -> int: # type: ignore[override] - return self.exponent.sign() + return self.exponent.sign() # type: ignore[return-value] def __eq__(self, other): if not isinstance(other, PowerTransform): @@ -734,7 +747,13 @@ class AffineTransform(Transform): bijective = True - def __init__(self, loc, scale, event_dim=0, cache_size=0): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + event_dim: int = 0, + cache_size: int = 0, + ) -> None: super().__init__(cache_size=cache_size) self.loc = loc self.scale = scale @@ -771,20 +790,20 @@ def __eq__(self, other): if self.loc != other.loc: return False else: - if not (self.loc == other.loc).all().item(): + if not (self.loc == other.loc).all().item(): # type: ignore[union-attr] return False if isinstance(self.scale, _Number) and isinstance(other.scale, _Number): if self.scale != other.scale: return False else: - if not (self.scale == other.scale).all().item(): + if not (self.scale == other.scale).all().item(): # type: ignore[union-attr] return False return True @property - def sign(self) -> int: + def sign(self) -> Union[Tensor, int]: # type: ignore[override] if isinstance(self.scale, _Number): return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 return self.scale.sign() @@ -1022,7 +1041,7 @@ class PositiveDefiniteTransform(Transform): """ domain = constraints.independent(constraints.real, 2) - codomain = constraints.positive_definite # type: ignore[assignment] + codomain = constraints.positive_definite def __eq__(self, other): return isinstance(other, PositiveDefiniteTransform) @@ -1053,7 +1072,13 @@ class CatTransform(Transform): transforms: list[Transform] - def __init__(self, tseq, dim=0, lengths=None, cache_size=0): + def __init__( + self, + tseq: Sequence[Transform], + dim: int = 0, + lengths: Optional[Sequence[int]] = None, + cache_size: int = 0, + ) -> None: assert all(isinstance(t, Transform) for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] @@ -1157,7 +1182,9 @@ class StackTransform(Transform): transforms: list[Transform] - def __init__(self, tseq, dim=0, cache_size=0): + def __init__( + self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0 + ) -> None: assert all(isinstance(t, Transform) for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] @@ -1237,12 +1264,12 @@ class CumulativeDistributionTransform(Transform): codomain = constraints.unit_interval sign = +1 - def __init__(self, distribution, cache_size=0): + def __init__(self, distribution: Distribution, cache_size: int = 0) -> None: super().__init__(cache_size=cache_size) self.distribution = distribution @property - def domain(self) -> constraints.Constraint: # type: ignore[override] + def domain(self) -> Optional[constraints.Constraint]: # type: ignore[override] return self.distribution.support def _call(self, x): diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index 31007c924de0..37decbaadce5 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -50,7 +52,12 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return (self.high - self.low).pow(2) / 12 - def __init__(self, low, high, validate_args=None): + def __init__( + self, + low: Union[Tensor, float], + high: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.low, self.high = broadcast_all(low, high) if isinstance(low, _Number) and isinstance(high, _Number): diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index f83d75c904ab..b53c4721ffc7 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import Tensor from torch.overrides import is_tensor_like -from torch.types import _Number +from torch.types import _Number, Number euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant @@ -23,7 +23,9 @@ ] -def broadcast_all(*values): +# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added. +# See https://github.com/python/typing/issues/1216#issuecomment-2126153831 +def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: r""" Given a list of values (possibly containing numbers), returns a list where each value is broadcasted based on the following rules: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 9a144fe10817..4f96a23cf55b 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch import torch.jit @@ -126,7 +127,12 @@ class VonMises(Distribution): support = constraints.real has_rsample = False - def __init__(self, loc, concentration, validate_args=None): + def __init__( + self, + loc: Tensor, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.concentration = broadcast_all(loc, concentration) batch_shape = self.loc.shape event_shape = torch.Size() diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index e7b3c5e0cebe..98132472b4ee 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -34,7 +36,12 @@ class Weibull(TransformedDistribution): } support = constraints.positive - def __init__(self, scale, concentration, validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + concentration: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.scale, self.concentration = broadcast_all(scale, concentration) self.concentration_reciprocal = self.concentration.reciprocal() base_dist = Exponential( diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 225aeeb97430..1b5a51ea88f9 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -80,8 +80,8 @@ def __init__( covariance_matrix: Optional[Tensor] = None, precision_matrix: Optional[Tensor] = None, scale_tril: Optional[Tensor] = None, - validate_args=None, - ): + validate_args: Optional[bool] = None, + ) -> None: assert (covariance_matrix is not None) + (scale_tril is not None) + ( precision_matrix is not None ) == 1, ( diff --git a/torch/export/__init__.py b/torch/export/__init__.py index f3cd894185e6..e95ac3f3a1df 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -51,13 +51,14 @@ "unflatten", "FlatArgsAdapter", "UnflattenedModule", + "AdditionalInputs", ] # To make sure export specific custom ops are loaded import torch.export.custom_ops from .decomp_utils import CustomDecompTable -from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection +from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection from .exported_program import ( default_decompositions, ExportedProgram, @@ -522,9 +523,4 @@ def forward(self, x: InputDataClass) -> OutputDataClass: print(ep) """ - - from torch._export.utils import register_dataclass_as_pytree_node - - return register_dataclass_as_pytree_node( - cls, serialized_type_name=serialized_type_name - ) + pytree.register_dataclass(cls, serialized_type_name=serialized_type_name) diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 604f865a2b08..103a4abb0540 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -11,10 +11,11 @@ import torch import torch._logging._internal import torch._logging.structured +import torch.utils._pytree as pytree from torch._export.passes.insert_custom_op_guards import insert_custom_op_guards from torch.export import ExportedProgram from torch.export._trace import _export -from torch.export.dynamic_shapes import refine_dynamic_shapes_from_suggested_fixes +from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim log = logging.getLogger(__name__) @@ -23,7 +24,7 @@ class FailureType(IntEnum): MISSING_FAKE_KERNEL = 1 DATA_DEPENDENT_ERROR = 2 - CONSTRAINT_VIOLATION_ERROR = 3 + GUARD_ADDED = 3 MISMATCHED_FAKE_KERNEL = 4 def __str__(self) -> str: @@ -94,17 +95,19 @@ def print(self, str_to_filename: dict[int, str]) -> str: Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation. """ # noqa: B950 - elif self.failure_type == FailureType.CONSTRAINT_VIOLATION_ERROR: + elif self.failure_type == FailureType.GUARD_ADDED: locals_info = ( prettify_frame_locals(**self.data["frame_locals"]) if self.data["frame_locals"] else "" ) - return f"""Constraint violation error. - The specified input dynamic_shapes spec was found to be incorrect during tracing. + return f"""Guard Added. + A guard was added during tracing, which might've resulted in some incorrect + tracing or constraint violation error. Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}. - This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}: + This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}: {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n Because of this, we have modified the dynamic shapes structure to be the following. You can also use torch.export.Dim.AUTO instead to specify your dynamic shapes, and we will automatically infer the dynamism for you. @@ -216,6 +219,8 @@ def _hash(self, element: tuple[str, dict[str, Any]]) -> int: return hash((key, data["op"], data["reason"])) elif key == "propagate_real_tensors_provenance": return hash((key, json.dumps(data["user_stack"]))) + elif key == "guard_added": + return hash((key, json.dumps(data["user_stack"]))) elif key == "create_unbacked_symbol": return hash((key, json.dumps(data["user_stack"]))) @@ -377,10 +382,16 @@ def draft_export( pre_dispatch=pre_dispatch, preserve_module_call_signature=preserve_module_call_signature, ) - except torch._dynamo.exc.UserError as exc: - new_shapes = refine_dynamic_shapes_from_suggested_fixes( - exc.msg, dynamic_shapes - ) + except torch._dynamo.exc.UserError: + + def convert_dim_to_auto(dim: Any) -> Any: + if isinstance(dim, Dim): + return Dim.AUTO(min=dim.min, max=dim.max) + elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC: + return Dim.AUTO(min=dim.min, max=dim.max) + return dim + + new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes) ep = _export( mod, args, @@ -420,7 +431,7 @@ def draft_export( if new_shapes is None: continue - failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR + failure_type = FailureType.GUARD_ADDED log_contents["new_dynamic_shapes"] = new_shapes elif log_name == "missing_fake_kernel": failure_type = FailureType.MISSING_FAKE_KERNEL diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 830710f44e7d..72269acc2625 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -500,11 +500,12 @@ def _produce_aten_artifact( It does: 1. Applies runtime assertion pass - 2. Populate meta val when missing - 3. Lift constants as placeholders - 4. Replace raw autograd and autocast ops with HOPs - 5. Prettify names for placeholders - 6. Preserve requires_grad value on node meta val + 2. Recompute unbacked_bindings pass + 3. Populate meta val when missing + 4. Lift constants as placeholders + 5. Replace raw autograd and autocast ops with HOPs + 6. Prettify names for placeholders + 7. Preserve requires_grad value on node meta val """ # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. # Overwrite output specs afterwards. @@ -1155,10 +1156,15 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes): kwargs = kwargs if kwargs is not None else {} _, original_in_spec = pytree.tree_flatten((args, kwargs)) - if isinstance(dynamic_shapes, torch.export.ShapesCollection): + if isinstance(dynamic_shapes, torch.export.AdditionalInputs): + verify_additional_inputs = dynamic_shapes.verify dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + else: + verify_additional_inputs = lambda ep: None # noqa: E731 + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) - return args, kwargs, original_in_spec, dynamic_shapes + return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs def _get_module_call_graph( @@ -1971,6 +1977,7 @@ def _export_for_training( kwargs, orig_in_spec, dynamic_shapes, + verify_additional_inputs, ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) original_state_dict = _get_original_state_dict(mod) @@ -2033,6 +2040,7 @@ def _export_for_training( verifiers=[TrainingIRVerifier], ) + verify_additional_inputs(exported_program) return exported_program @@ -2132,6 +2140,7 @@ def _export( kwargs, original_in_spec, dynamic_shapes, + verify_additional_inputs, ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) original_state_dict = _get_original_state_dict(mod) @@ -2205,4 +2214,5 @@ def _export( dtrace_structured("exported_program", payload_fn=lambda: str(exported_program)) + verify_additional_inputs(exported_program) return exported_program diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 0caf82160054..e51c12800ad9 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -105,7 +105,12 @@ def _unlift_inputs_as_getattr( else: with gm.graph.inserting_after(input_node): - getattr_node = gm.graph.get_attr(lifted_node) + # It is fine to ignore this warning because + # it is guaranteed that we will populate this + # attr later. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + getattr_node = gm.graph.get_attr(lifted_node) input_node.replace_all_uses_with(getattr_node) metadata = input_node.meta gm.graph.erase_node(input_node) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 285b0555034b..50682a948aaf 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -34,6 +34,7 @@ "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", + "AdditionalInputs", ] @@ -56,6 +57,9 @@ class _DimHintType(Enum): @dataclasses.dataclass class _DimHint: type: _DimHintType + min: Optional[int] = None + max: Optional[int] = None + _factory: Optional[bool] = True @staticmethod def AUTO(): @@ -69,6 +73,14 @@ def DYNAMIC(): def STATIC(): return _DimHint(_DimHintType.STATIC) + def __call__(self, min=None, max=None) -> "_DimHint": + if not self._factory: + raise TypeError(f"'{type(self)}' object is not callable") + assert min is None or min >= 0, "min must be non-negative" + assert max is None or max >= 0, "max must be non-negative" + assert min is None or max is None or min <= max, "min must be <= max" + return _DimHint(self.type, min=min, max=max, _factory=False) + class Dim: """ @@ -702,6 +714,84 @@ def find_shape(path, t): return dynamic_shapes +class AdditionalInputs: + """ + Infers dynamic_shapes based on additional inputs. + + This is useful particularly for deployment engineers who, on the one hand, may + have access to ample testing or profiling data that can provide a fair sense of + representative inputs for a model, but on the other hand, may not know enough + about the model to guess which input shapes should be dynamic. + + Input shapes that are different than the original are considered dynamic; conversely, + those that are the same as the original are considered static. Moreover, we verify + that the additional inputs are valid for the exported program. This guarantees that + tracing with them instead of the original would have generated the same graph. + + Example:: + + args0, kwargs0 = ... # example inputs for export + + # other representative inputs that the exported program will run on + dynamic_shapes = torch.export.AdditionalInputs() + dynamic_shapes.add(args1, kwargs1) + ... + dynamic_shapes.add(argsN, kwargsN) + + torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._examples = [] + + def add(self, args, kwargs=None): + """ + Additional input :func:`args` and :func:`kwargs`. + """ + + assert type(args) is tuple, f"Representative args {args} must be a tuple" + assert ( + kwargs is None or type(kwargs) is dict + ), f"Representative kwargs {kwargs} must be None or a dict" + self._examples.append((args, kwargs)) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Infers a :func:`dynamic_shapes` pytree structure by merging shapes of the + original input :func:`args` and :func:`kwargs` and of each additional input + args and kwargs. + """ + + dynamic_shapes, *other_dynamic_shapes = [ + _tree_map_with_path( + lambda path, t: tuple(t.shape), _combine_args(m, args, kwargs) + ) + for args, kwargs in [(args, kwargs), *self._examples] + ] + + return tree_map_with_path( + lambda path, dim, *other_dims: ( + dim + if all(other_dim == dim for other_dim in other_dims) + else Dim.DYNAMIC + ), + dynamic_shapes, + *other_dynamic_shapes, + is_leaf=lambda i: type(i) is int, + ) + + def verify(self, ep): + """ + Verifies that an exported program is valid for each additional input. + """ + + epm = ep.module() + for args, kwargs in self._examples: + torch.export._unlift._check_input_constraints_pre_hook( + epm, args, kwargs or {} + ) + + def _warn_on_None_dynamic_shape_dimension(): msg = ( "Using None as a dynamic shape dimension is deprecated. " diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index ee8640c3ade7..bcaf8645f795 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -187,7 +187,7 @@ def _fx_collection_equivalence_fn( @contextmanager -def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): +def _override_composite_implicit_decomp(cia_ops_to_callable): # This function overrides CompositeImplicitAutograd decomp for # functional composite ops that user specified. Ideally we want to not-decompose # ALL composite ops but today's C++ functinalization relies on @@ -195,13 +195,6 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): # Hence we can only do it for functional ops. One caveat is that # there are some composite ops that lie about their schema (claimed to be # functional but not really aka dropout), for these cases, we just decompose. - - # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing - # and their usual decompositions need to be shadowed rather than overridden. - # Thus we will avoid asserting that they are valid to preserve, and will not - # replace their CompositeImplicitAutograd kernels with NotImplemented. - # The only current users of this mode are variants of aten::to that we will - # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. saved_tables = {} patched_ops = set() for op_overload, decomp_callable in cia_ops_to_callable.items(): @@ -219,10 +212,9 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] - if safe: - op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( - decomp_callable - ) + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) # [NOTE] Directly registering fake tensor rule to CIA ops # The problem we are facing here is if your CIA custom rule @@ -278,21 +270,6 @@ def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs): _deregister_op_impl(op) -@contextmanager -def _override_decomp_aten_to_variants(): - # Preserve variants of aten::to understanding that they are mutating/aliasing - # and their CompositeImplicitAutograd kernels will not become NotImplemented. - # We will later replace them with aten._to_copy when functionalizing. - with _override_composite_implicit_decomp( - { - torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia, - torch.ops.aten.to.dtype: _special_op_to_preserve_cia, - }, - safe=False, - ): - yield - - def _split_decomp_table_to_cia_and_python_decomp( decomp_table: dict[torch._ops.OperatorBase, Callable] ) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: @@ -465,15 +442,9 @@ def _is_joint_ir_decomp(ep, joint_loss_index): tx = TracingContext(fake_mode) - with ( - fake_mode - ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( + with fake_mode, _override_composite_implicit_decomp( cia_to_decomp, - ), _enable_graph_inputs_of_type_nn_module( - ep.example_inputs - ), tracing( - tx - ): + ), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx): retracing_args_unwrapped = pytree.tree_unflatten( retracing_args, mod._in_spec ) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 833961170001..85a01ea13ee7 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -589,7 +589,10 @@ def process_forward_inputs(self, *args, **kwargs): return flat_args def forward(self, *args, **kwargs): - flat_args = torch._dynamo.disable(self.process_forward_inputs)(*args, **kwargs) + flat_args = torch._dynamo.disable( + self.process_forward_inputs, + reason="do not trace into preprocessing the inputs", + )(*args, **kwargs) signature = self.module_call_graph[0].signature if is_fx_tracing(): diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 483b7e8b2ea2..525014bf1e80 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -252,13 +252,20 @@ def mod_partition(node: torch.fx.Node): # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) # return add root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + + # The order of placeholders in the const_gm graph should match the order of + # args in the outer module, so we can simply use an index for the + # placeholder mapping + ph_idx = 0 for node in root_const_gm.graph.nodes: if node.op == "output": multiple_outputs = isinstance(node.args[0], tuple) continue if node.op != "placeholder": continue - in_node = next(n for n in call_const_gm_args if n.name == node.target) + assert ph_idx < len(call_const_gm_args) + in_node = call_const_gm_args[ph_idx] + ph_idx += 1 assert in_node.op == "get_attr" with root_const_gm.graph.inserting_before(node): new_node = root_const_gm.graph.get_attr(in_node.target) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 45e3309208e9..4193606d849d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -14,7 +14,6 @@ import traceback import typing import typing_extensions -import warnings import weakref from collections import defaultdict, OrderedDict from collections.abc import Generator, Mapping, Sequence @@ -1125,20 +1124,55 @@ def map_fn(v: Any) -> Optional[_ExtractValType]: return None return extract_val(v.meta["val"]) - # TODO: opt-in mechanism ? - if isinstance( - target, - ( - torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, - torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, - ), - ): + if _should_save_eager_input_vals(target, (args, kwargs)): + # NOTE "eager_input_vals" + # We save the original (args, kwargs) FakeTensor values for nodes + # that have exact stride requirements. This is useful downstream. + # We use this information inside Inductor to ensure that inputs to + # stride-sensitive operators have the correct strides. arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] - node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp) + node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) return node +def _should_save_eager_input_vals( + target: Any, + args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, +) -> bool: + if not callable(target): + return False + if isinstance( + target, + ( + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + ), + ): + return True + if args_kwargs is not None and ( + target is torch.ops.higher_order.auto_functionalized + or target is torch.ops.higher_order.auto_functionalized_v2 + ): + args = args_kwargs[0] + assert isinstance(args[0], torch._ops.OpOverload) + return _should_save_eager_input_vals(args[0], None) + if target is torch.ops.higher_order.with_effects: + # TODO: inductor lowering for with_effects needs to be updated to propagate + # the arg_kwarg_vals + return False + if isinstance(target, torch._ops.HigherOrderOperator): + if pytree.tree_any(_should_save_eager_input_vals, args_kwargs): + raise RuntimeError( + f"NYI: The HOP {target} has an input that is an OpOverload that " + f"needs exact strides. We probably need special logic to " + f"propagate the FakeTensor vals. Please file an issue." + ) + if isinstance(target, torch._ops.OpOverload): + return torch._C.Tag.needs_exact_strides in target.tags + return False + + def _make_temp_remove_mode_context_manager( mode_ty: type[TorchFunctionMode], ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: @@ -1820,11 +1854,12 @@ def call_module( try: return Tracer.call_module(self, m, forward, args, kwargs) except _ModuleNotInstalledAsSubmoduleError: - warnings.warn( - f"Unable to find the path of the module {m}. " + log.debug( + "Unable to find the path of the module %s. " "This might be because the module was not properly registered " "as a submodule, which is not good practice. We will trace " - "through the module without recording stack information." + "through the module without recording stack information.", + str(m), ) return forward(*args, **kwargs) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index fa4443b1b5d5..66f80e8dbc49 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -592,6 +592,18 @@ def guard_size_oblivious(self, file, line): log.warning("Failed to convert to bool: %s", r) raise + def guard_or_false(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + assert self.is_bool() + return guard_or_false(SymBool(self)) + + def guard_or_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + assert self.is_bool() + return guard_or_true(SymBool(self)) + def bool_(self): return self.guard_bool("", 0) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 95ca2cef4fd2..9cb699483628 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1195,20 +1195,30 @@ def guard_or_false(a: BoolLikeType) -> bool: """ Try to guard a, if data dependent error encountered just return false. """ - try: - return bool(guard_bool(a)) - except GuardOnDataDependentSymNode: - return False + if torch.fx.experimental._config.backed_size_oblivious: + return statically_known_true(a) + else: + try: + return bool(guard_bool(a)) + except GuardOnDataDependentSymNode: + return False def guard_or_true(a: BoolLikeType) -> bool: """ Try to guard a, if data dependent error encountered just return true. """ - try: - return bool(guard_bool(a)) - except GuardOnDataDependentSymNode: - return True + if torch.fx.experimental._config.backed_size_oblivious: + result = _static_eval(a) + if result is not None: + return result + else: + return True + else: + try: + return bool(guard_bool(a)) + except GuardOnDataDependentSymNode: + return True def definitely_true(a: BoolLikeType) -> bool: @@ -1253,6 +1263,23 @@ def definitely_false(a: BoolLikeType) -> bool: return not bool(a) +def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]: + if isinstance(x, SymBool): + expr = x.node.expr + shape_env = x.node.shape_env + try: + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + else: + return None + except Exception: + log.debug("Could not simplify %s", expr) + return None + assert isinstance(x, bool) + return x + + def statically_known_true(x: Union[bool, SymBool]) -> bool: """ Returns True if x can be simplified to a constant and is true. @@ -1264,18 +1291,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool: Args: x (bool, SymBool): The expression to try statically evaluating """ - if isinstance(x, SymBool): - expr = x.node.expr - shape_env = x.node.shape_env - try: - simplified = shape_env._maybe_evaluate_static(expr) - if simplified is not None: - return bool(simplified) - except Exception: - log.debug("Could not simplify %s", expr) + result = _static_eval(x) + if result is None: return False - assert isinstance(x, bool) - return x + else: + return result def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]: @@ -3330,8 +3350,6 @@ def _init( # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable self.val_to_var: dict[int, sympy.Symbol] = {} - if specialize_zero_one: - self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One} self.unbacked_symfloat_counter = itertools.count() self.unbacked_symint_counter = itertools.count() # Similar to guards, but these MUST evaluate to true and can @@ -4541,7 +4559,10 @@ def create_symbol( sloc = self._get_sloc() if val in (0, 1) and specialize_zero_one: - r = self.val_to_var[val] + if val == 0: + return sympy.S.Zero + else: + return sympy.S.One elif not duck or val not in self.val_to_var: # If we're not duck shaping, we always create a new symbol # Even if we're duck shaping, if we haven't seen this particular diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 0e483f19c866..75c0eb8081fb 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -222,6 +222,7 @@ def _rename_object(self, obj: Any, name: str): torch.float8_e4m3fnuz: "f8e4m3fnuz", torch.float8_e5m2fnuz: "f8e5m2fnuz", torch.float8_e8m0fnu: "f8e8m0fnu", + torch.float4_e2m1fn_x2: "f4e2m1fnx2", torch.complex32: "c32", torch.complex64: "c64", torch.complex128: "c128", @@ -438,7 +439,7 @@ def add_global(name_hint: str, obj: Any): global_name = namespace.create_name(name_hint, obj) if global_name in globals_: - assert globals_[global_name] is obj + assert globals_[global_name] == obj return global_name globals_[global_name] = obj return global_name @@ -1809,10 +1810,14 @@ def forward(self, x): # DCE below will not behave as expected. self.lint() + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + def has_side_effect(node): if is_impure_node is not None: return is_impure_node(node) - return node.is_impure() + return node.is_impure(impure_random) # Reverse iterate so that when we remove a node, any nodes used as an # input to that node have an updated user count that no longer reflects diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 3910020cfad9..57f65acef9b2 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -372,7 +372,10 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: all_src_lines = linecache.getlines(frame_summary.filename) # constituent substrings of the error message - tb_repr = torch._dynamo.disable(traceback.format_exc)() + tb_repr = torch._dynamo.disable( + traceback.format_exc, + reason="do not trace into traceback.format_exc when generating error message", + )() custom_msg = ( "Call using an FX-traced Module, " f"line {err_lineno} of the traced Module's " diff --git a/torch/fx/node.py b/torch/fx/node.py index 722de170bfd5..59a946deec23 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -375,7 +375,41 @@ def prepend(self, x: "Node") -> None: Args: x (Node): The node to put before this node. Must be a member of the same graph. """ - self._prepend(x) + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + if self == x: + log.debug( + "Trying to prepend a node to itself. This behavior has no effect on the graph." + ) + return + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + # compute x._sort_key + psk = x._prev._sort_key + nsk = x._next._sort_key + if len(psk) > len(nsk): + idx: int + *prefix, idx = psk[: len(nsk) + 1] + x._sort_key = (*prefix, idx + 1) + elif len(psk) < len(nsk): + *prefix, idx = nsk[: len(psk) + 1] + x._sort_key = (*prefix, idx - 1) + else: # same length, increase length by 1 + x._sort_key = (*psk, 0) + + def __gt__(self, other: "Node") -> bool: + return self._sort_key > other._sort_key + + def __lt__(self, other: "Node") -> bool: + return self._sort_key < other._sort_key + + def __ge__(self, other: "Node") -> bool: + return self > other or self == other + + def __le__(self, other: "Node") -> bool: + return self < other or self == other @compatibility(is_backward_compatible=True) def append(self, x: "Node") -> None: @@ -386,7 +420,11 @@ def append(self, x: "Node") -> None: Args: x (Node): The node to put after this node. Must be a member of the same graph. """ - self._next._prepend(x) + self._next.prepend(x) + + def _remove_from_list(self) -> None: + p, n = self._prev, self._next + p._next, n._prev = n, p @property def args(self) -> tuple[Argument, ...]: @@ -676,11 +714,14 @@ def maybe_replace_node(n: Node) -> Node: return [n for n in to_process if n not in skipped] @compatibility(is_backward_compatible=False) - def is_impure(self) -> bool: + def is_impure(self, impure_random: bool = True) -> bool: """ Returns whether this op is impure, i.e. if its op is a placeholder or output, or if a call_function or call_module which is impure. + Args: + impure_random (bool): Whether to treat rand op as impure. + Returns: bool: If the op is impure or not. @@ -694,9 +735,10 @@ def is_impure(self) -> bool: # impure since it mutates inputs return True - if getattr(self.target, "_nondeterministic_seeded", False): - # impure since it mutates RNG state - return True + if impure_random: + if getattr(self.target, "_nondeterministic_seeded", False): + # impure since it mutates RNG state + return True return self.target in _side_effectful_functions diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index ce1814dd7f29..e40cb13d5558 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -116,7 +116,6 @@ def __exit__(self, *args): "_numeric_debug_handle", # TODO deprecated "custom", "partitioner_tag", - "arg_kwarg_vals", ] diff --git a/torch/lib/libshm/socket.h b/torch/lib/libshm/socket.h index e3ff98cbc9fb..6b7207eb70a8 100644 --- a/torch/lib/libshm/socket.h +++ b/torch/lib/libshm/socket.h @@ -17,12 +17,12 @@ class Socket { public: int socket_fd; + Socket(const Socket& other) = delete; protected: Socket() { SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0)); } - Socket(const Socket& other) = delete; Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; }; @@ -122,7 +122,7 @@ class ManagerServerSocket : public Socket { SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str())); } - virtual ~ManagerServerSocket() { + ~ManagerServerSocket() override { unlink(socket_path.c_str()); } diff --git a/torch/library.h b/torch/library.h index ef92bee6c93b..5f6b94439b84 100644 --- a/torch/library.h +++ b/torch/library.h @@ -884,8 +884,48 @@ class TORCH_API Library final { at::OperatorName _parseNameForLib(const char* name_str) const; }; +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) +void initialize_torch_libraries(); +#endif + namespace detail { +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) +// This is an experimental feature to defer TorchLibraryInit cost to run either +// at model load time, or when a client application explicitly calls +// torch::initialize_torch_libraries(). +// +// This is not thread safe, the client is required to ensure that libraries +// containing TORCH_LIBRARY initializers are loaded in a thread safe manner. +extern std::vector torch_library_initializers; +class TorchLibraryInit final { + private: + using InitFn = void(Library&); + Library::Kind kind; + InitFn* init_function; + const char* ns; + std::optional key; + const char* file; + uint32_t line; + std::unique_ptr lib = nullptr; + + public: + TorchLibraryInit( + Library::Kind kind, + InitFn* fn, + const char* ns, + std::optional k, + const char* file, + uint32_t line) : kind(kind), init_function(fn), ns(ns), key(k), file(file), line(line) { + torch_library_initializers.push_back(this); + } + + void initialize() { + lib = std::unique_ptr(new Library(kind, ns, key, file, line)); + init_function(*lib); + } +}; +#else class TorchLibraryInit final { private: using InitFn = void(Library&); @@ -903,6 +943,7 @@ class TorchLibraryInit final { fn(lib_); } }; +#endif } // namespace detail diff --git a/torch/library.py b/torch/library.py index fbe6f3ea1cd3..4caa6a698f66 100644 --- a/torch/library.py +++ b/torch/library.py @@ -443,6 +443,21 @@ def _del_library( op_defs, registration_handles, ): + import torch.fx + + for op_def in op_defs: + name = op_def + overload_name = "" + if "." in op_def: + name, overload_name = op_def.split(".") + if ( + name, + overload_name, + ) in torch.fx.operator_schemas._SCHEMA_TO_SIGNATURE_CACHE: + del torch.fx.operator_schemas._SCHEMA_TO_SIGNATURE_CACHE[ + (name, overload_name) + ] + captured_impls -= op_impls captured_defs -= op_defs for handle in registration_handles: diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 7ec94e8189f7..5b09a2c14c24 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs """Defines utilities for interacting with scaled_dot_product_attention""" import math -from typing import Optional, Union +from typing import Optional import torch @@ -31,14 +31,6 @@ def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float: return 1.0 / math.sqrt(head_dim_size) -_SUPPORTED_HEAD_DIMS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] - - -def _supported_head_dim(n: Union[int, torch.SymInt]) -> bool: - """Returns true if the head dim is supported by FlexAttention""" - return n in _SUPPORTED_HEAD_DIMS - - def _validate_sdpa_input( query: torch.Tensor, key: torch.Tensor, diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index da7acb957d96..36c0a18cdd12 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -283,11 +283,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" if kwargs is None: kwargs = {} - if func != torch.nn.functional.scaled_dot_product_attention: - raise NotImplementedError( - "CausalBias only supports scaled_dot_product_attention" - ) - return cls._dispatch(*args, **kwargs) + if func is torch.nn.functional.scaled_dot_product_attention: + return cls._dispatch(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) def __repr__(self): # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 6bf74aab2029..8bf87c60d411 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,7 +19,7 @@ _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) -from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input +from torch.nn.attention._utils import _validate_sdpa_input from torch.utils._pytree import tree_map_only @@ -1118,16 +1118,6 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): f"Expect query and key/value to have the same embedding dimension " f"but got E={query.size(-1)} and E={key.size(-1)}." ) - return - # TODO this config segfaults with Triton without: - # https://github.com/triton-lang/triton/pull/4540 - if not ( - _supported_head_dim(query.size(-1)) and _supported_head_dim(value.size(-1)) - ): - raise ValueError( - f"NYI: Currently non power of 2 embedding dimension are not supported. " - f"Got E={query.size(-1)} and Ev={value.size(-1)}." - ) def _validate_device(query: Tensor, key: Tensor, value: Tensor): diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 564a516a2477..54a2dec94e18 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -66,10 +66,12 @@ class Threshold(Module): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. + .. image:: ../scripts/activation_images/Threshold.png + Examples:: - >>> m = nn.Threshold(0.1, 20) - >>> input = torch.randn(2) + >>> m = nn.Threshold(0, 0.5) + >>> input = torch.arange(-3, 3) >>> output = m(input) """ @@ -674,6 +676,8 @@ class GLU(Module): dimensions - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + .. image:: ../scripts/activation_images/GLU.png + Examples:: >>> m = nn.GLU() diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index de8b55575fb3..75d5c91756df 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -154,8 +154,8 @@ class NLLLoss(_WeightedLoss): The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} x_{n,y_n}, \quad + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \\ + l_n = - w_{y_n} x_{n,y_n}, \\ w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index eb14df567095..7b7f58b5ff5f 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs from typing import Optional -import numpy as np - import torch import torch.nn.functional as F @@ -213,6 +211,8 @@ def conv_unfold_weight_grad_sample( groups, func, ): + import numpy as np + n = input.shape[0] in_channels = input.shape[1] @@ -318,6 +318,9 @@ def unfold3d( >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape torch.Size([3, 32, 120]) """ + + import numpy as np + if len(tensor.shape) != 5: raise ValueError( f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}" diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index a38203d2314d..b570b20bd02c 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -50,7 +50,7 @@ def export_compat( verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, - opset_version: int | None = None, + opset_version: int | None = _constants.TORCHLIB_OPSET, custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] @@ -105,8 +105,7 @@ def export_compat( dynamic_shapes_with_export_dim, need_axis_mapping = ( _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes) ) - - registry = _registration.ONNXRegistry.from_torchlib() + registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version) if custom_translation_table is not None: for torch_op, onnx_ops in custom_translation_table.items(): # TODO(justinchuby): Support complex inputs with annotations diff --git a/torch/onnx/_internal/exporter/_ir_passes.py b/torch/onnx/_internal/exporter/_ir_passes.py index 804e93acbd6f..8a715e245597 100644 --- a/torch/onnx/_internal/exporter/_ir_passes.py +++ b/torch/onnx/_internal/exporter/_ir_passes.py @@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None: value.shape = ir.Shape(new_shape) -def add_torchlib_common_imports(model: ir.Model) -> None: +def add_torchlib_common_imports( + model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET +) -> None: """Hack to add torchlib common imports to the model.""" try: @@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None: model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) + rank_func.opset_imports[""] = opset_version is_scalar_func = ir.serde.deserialize_function( common_ops.IsScalar.to_function_proto() ) + is_scalar_func.opset_imports[""] = opset_version model.functions[rank_func.identifier()] = rank_func model.functions[is_scalar_func.identifier()] = is_scalar_func except Exception: diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index ac81d2301cc2..fefc8022d7e8 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -42,6 +42,9 @@ class OnnxDecompMeta: signature: The ONNX signature of the function. When None, the signature is inferred. is_custom: Whether the function is a custom function. is_complex: Whether the function is a function that handles complex valued inputs. + opset_introduced: + The ONNX opset version in which the function was introduced. + Its specifies the minimum ONNX opset version required to use the function. device: The device the function is registered to. If None, it is registered to all devices. skip_signature_inference: Whether to skip signature inference for the function. """ @@ -51,6 +54,7 @@ class OnnxDecompMeta: signature: _schemas.OpSignature | None is_custom: bool = False is_complex: bool = False + opset_introduced: int = 18 device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051 skip_signature_inference: bool = False @@ -150,13 +154,14 @@ def opset_version(self) -> int: return self._opset_version @classmethod - def from_torchlib(cls) -> ONNXRegistry: + def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ registry = cls() + registry._opset_version = opset_version for meta in _torchlib_registry.get_torchlib_ops(): registry._register(meta.fx_target, meta) @@ -185,6 +190,7 @@ def from_torchlib(cls) -> ONNXRegistry: logger.exception("Failed to register '%s'. Skipped", qualified_name) continue + registry._cleanup_registry_based_on_opset_version() return registry def _register( @@ -274,5 +280,24 @@ def is_registered(self, target: TorchOp) -> bool: """ return bool(self.get_decomps(target)) + def _cleanup_registry_based_on_opset_version(self) -> None: + """Pick the implementation with the highest opset version valid until the current opset version.""" + cleaned_functions = {} + for target_or_name, decomps in self.functions.items(): + # Filter decompositions to only include those with opset_introduced <= opset_version + decomps = [d for d in decomps if d.opset_introduced <= self.opset_version] + + # Keep only the decomposition with the highest opset_introduced + if decomps: + # Find the maximum opset_introduced + max_opset = max(d.opset_introduced for d in decomps) + + # Keep all decompositions with the maximum opset_introduced + cleaned_functions[target_or_name] = [ + d for d in decomps if d.opset_introduced == max_opset + ] + + self.functions = cleaned_functions + def __repr__(self) -> str: return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index e71bdeb0c68e..039eeb3e2fc2 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -30,6 +30,7 @@ def onnx_impl( *, trace_only: bool = False, complex: bool = False, + opset_introduced: int = 18, no_compile: bool = False, private: bool = False, ) -> Callable[[_T], _T]: @@ -74,6 +75,7 @@ def wrapper( fx_target=t, signature=None, is_complex=complex, + opset_introduced=opset_introduced, skip_signature_inference=no_compile, ) ) diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py index d07768f252ba..bff8860fcb1f 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["core", "hop", "symbolic"] +__all__ = ["core", "hop", "nn", "symbolic"] -from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic +from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py new file mode 100644 index 000000000000..4ca21662d696 --- /dev/null +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -0,0 +1,26 @@ +"""torch.ops.aten operators under the `core` module.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 +# flake8: noqa + +from __future__ import annotations + +import math + +from onnxscript.onnx_opset import opset20 as op20 + +import torch +from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +aten = torch.ops.aten + + +@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) +def aten_gelu_opset20( + self: TReal, + approximate: str = "none", +) -> TReal: + """gelu(Tensor self, *, bool approximate=False) -> Tensor""" + return op20.Gelu(self, approximate=approximate) diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 2f01b1d683bb..a86cb340082f 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -832,9 +832,6 @@ def _fused_adam( device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) device_state_steps = cast(list[Tensor], device_state_steps_) - if device.type == "mps": # type: ignore[union-attr] - assert found_inf is None and grad_scale is None - device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 834e0ed10071..9cd0661cac15 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1982,3 +1982,13 @@ def get_all_device_types() -> list[str]: and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) +if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 2a8fc04265c4..6a3e654d9e71 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -442,11 +442,11 @@ def create_tcp_store( TIMEOUT_OVERRIDE["test_join_kwargs"] = 200 -def create_device(interface=None): +def create_device(interface=None, lazy_init: bool = False): if sys.platform == "win32" or interface is None: - return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1") + return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1", lazy_init=lazy_init) else: - return c10d.ProcessGroupGloo.create_device(interface=interface) + return c10d.ProcessGroupGloo.create_device(interface=interface, lazy_init=lazy_init) def get_timeout(test_id) -> int: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24f651020d75..8e32eaa861aa 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1933,6 +1933,10 @@ def get_val(dtype): if torch.cuda.is_available(): inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + if not dtype.is_signed: + # For unsigned dtypes, negative values are converted. + inputs.append(((S,), -get_val(dtype), {})) + for shape, fill_value, kwargs in inputs: t = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, @@ -2775,7 +2779,7 @@ def error_inputs_ormqr(op_info, device, **kwargs): bool_3 = True bool_4 = True yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError, - error_regex=r"tau.shape\[-1\] must be less than or equal to input.shape\[-1\]") + error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)") def error_inputs_diag(op_info, device, **kwargs): @@ -18386,6 +18390,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), OpInfo('index_fill', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + inplace_variant=torch.Tensor.index_fill_, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18421,6 +18426,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('index_add', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + inplace_variant=torch.Tensor.index_add_, supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18909,12 +18915,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), )), OpInfo('full_like', - dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, + torch.uint16, torch.uint32), supports_out=False, sample_inputs_func=sample_inputs_full_like, supports_autograd=False, - skips=( - )), + ), OpInfo('new_zeros', op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), @@ -19342,6 +19348,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), OpInfo('scatter_add', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + inplace_variant=torch.Tensor.scatter_add_, sample_inputs_func=sample_inputs_scatter_add, error_inputs_func=error_inputs_scatter_and_scatter_add, supports_forward_ad=True, @@ -21506,6 +21513,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): OpInfo( 'scatter_reduce', variant_test_name='sum', + inplace_variant=torch.Tensor.scatter_reduce_, # complex not added to dtypes as complex gradients are not properly handled # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 07e7da55eafc..e114a37b04df 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -4,24 +4,48 @@ checking quantization api and properties of resulting modules. """ -from functorch.experimental import control_flow - import torch -import torch.nn as nn -import torch.nn.functional as F import torch.ao.nn.intrinsic.quantized.dynamic as nniqd import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.dynamic as nnqd -from torch.ao.nn.intrinsic import _FusedModule import torch.distributed as dist -from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM - -from torch.export import export_for_training +import torch.nn as nn +import torch.nn.functional as F +from functorch.experimental import control_flow +from torch.ao.nn.intrinsic import _FusedModule from torch.ao.quantization import ( - QuantType, + convert, default_dynamic_qat_qconfig, + default_dynamic_qconfig, + default_dynamic_quant_observer, default_embedding_qat_qconfig, + default_observer, + default_per_channel_qconfig, + default_qconfig, default_symmetric_qnnpack_qat_qconfig, + default_weight_observer, + DeQuantStub, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qat_qconfig_mapping, + get_default_qconfig, + get_default_qconfig_mapping, + PerChannelMinMaxObserver, + propagate_qconfig_, + QConfig, + QConfigMapping, + quantize, + quantize_dynamic_jit, + quantize_jit, + QuantStub, + QuantType, + QuantWrapper, +) +from torch.ao.quantization.backend_config import get_executorch_backend_config +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, ) from torch.ao.quantization.quantize_pt2e import ( _convert_to_reference_decomposed_fx, @@ -29,83 +53,75 @@ prepare_pt2e, prepare_qat_pt2e, ) -from torch.ao.quantization.backend_config import ( - get_executorch_backend_config, -) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - XNNPACKQuantizer, get_symmetric_quantization_config, + XNNPACKQuantizer, ) -from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \ - default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ - propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \ - QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping -from torch.ao.quantization.quantization_mappings import ( - get_default_dynamic_quant_module_mappings, - get_default_qconfig_propagation_list, - get_default_qat_module_mappings, -) -from torch.testing._internal.common_quantized import ( - override_quantized_engine, -) + +from torch.export import export_for_training from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase try: + from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph + # graph mode quantization based on fx from torch.ao.quantization.quantize_fx import ( - prepare_fx, - prepare_qat_fx, convert_fx, convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, ) - from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph - from torch.fx.graph import Node from torch.fx import GraphModule + from torch.fx.graph import Node + HAS_FX = True except ImportError: HAS_FX = False +import contextlib import copy -import io import functools +import io import os import unittest +from typing import Any, Callable, Optional, Union + import numpy as np -from torch.testing import FileCheck -from typing import Callable, Any, Union, Optional import torch._dynamo as torchdynamo import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer -import contextlib +from torch.testing import FileCheck + class NodeSpec: - ''' Used for checking GraphModule Node - ''' + """Used for checking GraphModule Node""" + def __init__(self, op, target): - ''' + """ op: call_function | call_module target: for call_function, target would be a function for call_module, target would be the type of PyTorch module - ''' + """ self.op = op self.target = target @classmethod def call_function(cls, target): - return NodeSpec('call_function', target) + return NodeSpec("call_function", target) @classmethod def call_method(cls, target): - return NodeSpec('call_method', target) + return NodeSpec("call_method", target) @classmethod def call_module(cls, target): - return NodeSpec('call_module', target) + return NodeSpec("call_module", target) def __hash__(self): return hash((self.op, self.target)) @@ -119,8 +135,12 @@ def __eq__(self, other): def __repr__(self): return repr(self.op) + " " + repr(self.target) + def get_supported_device_types(): - return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu'] + return ( + ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"] + ) + def test_only_eval_fn(model, calib_data): r""" @@ -130,7 +150,10 @@ def test_only_eval_fn(model, calib_data): for inp in calib_data: model(*inp) + _default_loss_fn = torch.nn.CrossEntropyLoss() + + def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): r""" Default train function takes a torch.utils.data.Dataset and train the model @@ -153,9 +176,11 @@ def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): correct += (predicted == target).sum().item() return train_loss, correct, total + class AverageMeter: """Computes and stores the average and current value""" - def __init__(self, name, fmt=':f'): + + def __init__(self, name, fmt=":f"): self.name = name self.fmt = fmt self.reset() @@ -173,7 +198,7 @@ def update(self, val, n=1): self.avg = self.sum / self.count def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format(**self.__dict__) @@ -193,10 +218,11 @@ def accuracy(output, target, topk=(1,)): res.append(correct_k.mul_(100.0 / batch_size)) return res + def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): model.train() for cnt, (image, target) in enumerate(data_loader, start=1): - print('.', end='') + print(".", end="") image, target = image.to(device), target.to(device) output = model(image) loss = criterion(output, target) @@ -208,16 +234,19 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_bat return return + def ddp_setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" # initialize the process group dist.init_process_group("gloo", rank=rank, world_size=world_size) + def ddp_cleanup(): dist.destroy_process_group() + def run_ddp(rank, world_size, prepared): ddp_setup(rank, world_size) prepared.cuda() @@ -232,24 +261,42 @@ def run_ddp(rank, world_size, prepared): def convert_dynamic(module): convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) + def prepare_dynamic(model, qconfig_dict=None): propagate_qconfig_(model, qconfig_dict) + def _make_conv_test_input( - batch_size, in_channels_per_group, input_feature_map_size, - out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale, - W_zero_point, use_bias, use_channelwise, + batch_size, + in_channels_per_group, + input_feature_map_size, + out_channels_per_group, + groups, + kernel_size, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + use_bias, + use_channelwise, ): in_channels = in_channels_per_group * groups out_channels = out_channels_per_group * groups (X_value_min, X_value_max) = (0, 4) X_init = torch.randint( - X_value_min, X_value_max, - (batch_size, in_channels,) + input_feature_map_size) + X_value_min, + X_value_max, + ( + batch_size, + in_channels, + ) + + input_feature_map_size, + ) X = X_scale * (X_init - X_zero_point).float() X_q = torch.quantize_per_tensor( - X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8) + X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8 + ) W_scale = W_scale * out_channels W_zero_point = W_zero_point * out_channels @@ -266,109 +313,132 @@ def _make_conv_test_input( # The operator expects them in the format # (out_channels, in_channels/groups,) + kernel_size W_init = torch.randint( - W_value_min, W_value_max, - (out_channels, in_channels_per_group,) + kernel_size) + W_value_min, + W_value_max, + ( + out_channels, + in_channels_per_group, + ) + + kernel_size, + ) b_init = torch.randint(0, 10, (out_channels,)) if use_channelwise: W_shape = (-1, 1) + (1,) * len(kernel_size) W_scales_tensor = torch.tensor(W_scale, dtype=torch.float) W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float) - W = W_scales_tensor.reshape(*W_shape) * ( - W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + W = ( + W_scales_tensor.reshape(*W_shape) + * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + ) b = X_scale * W_scales_tensor * b_init.float() W_q = torch.quantize_per_channel( - W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0, - dtype=torch.qint8) + W, + W_scales_tensor.double(), + W_zero_points_tensor.long(), + 0, + dtype=torch.qint8, + ) else: W = W_scale[0] * (W_init - W_zero_point[0]).float() b = X_scale * W_scale[0] * b_init.float() W_q = torch.quantize_per_tensor( - W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8) + W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8 + ) return (X, X_q, W, W_q, b if use_bias else None) + def _make_conv_add_extra_input_tensor(scale, zero_point, sizes): (X_value_min, X_value_max) = (0, 4) X_init = torch.randint( X_value_min, X_value_max, - sizes # Infer the size of tensor to do the add + sizes, # Infer the size of tensor to do the add ) X = scale * (X_init - zero_point).float() X_q = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=torch.quint8) + X, scale=scale, zero_point=zero_point, dtype=torch.quint8 + ) return X, X_q + def skipIfNoFBGEMM(fn): - reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.' + reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer." if isinstance(fn, type): - if 'fbgemm' not in torch.backends.quantized.supported_engines: + if "fbgemm" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'fbgemm' not in torch.backends.quantized.supported_engines: + if "fbgemm" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoQNNPACK(fn): - reason = 'Quantized operations require QNNPACK.' + reason = "Quantized operations require QNNPACK." if isinstance(fn, type): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def withQNNPACKBackend(fn): # TODO(future PR): consider combining with skipIfNoQNNPACK, # will require testing of existing callsites - reason = 'Quantized operations require QNNPACK.' + reason = "Quantized operations require QNNPACK." if isinstance(fn, type): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) - with override_quantized_engine('qnnpack'): + with override_quantized_engine("qnnpack"): fn(*args, **kwargs) return wrapper + def skipIfNoONEDNN(fn): - reason = 'Quantized operations require ONEDNN.' + reason = "Quantized operations require ONEDNN." if isinstance(fn, type): - if 'onednn' not in torch.backends.quantized.supported_engines: + if "onednn" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'onednn' not in torch.backends.quantized.supported_engines: + if "onednn" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoONEDNNBF16(fn): - reason = 'Quantized operations require BF16 support.' + reason = "Quantized operations require BF16 support." if isinstance(fn, type): if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): fn.__unittest_skip__ = True @@ -381,24 +451,28 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoX86(fn): - reason = 'Quantized operations require X86.' + reason = "Quantized operations require X86." if isinstance(fn, type): - if 'x86' not in torch.backends.quantized.supported_engines: + if "x86" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'x86' not in torch.backends.quantized.supported_engines: + if "x86" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoDynamoSupport(fn): reason = "dynamo doesn't support." if isinstance(fn, type): @@ -413,8 +487,10 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoInductorSupport(fn): reason = "inductor doesn't support." if isinstance(fn, type): @@ -429,18 +505,23 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + try: import torchvision # noqa: F401 + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + def get_script_module(model, tracing, data): return torch.jit.trace(model, data) if tracing else torch.jit.script(model) + def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): """ Convert lengths to offsets for embedding_bag @@ -464,7 +545,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): max_val = to_quant.amax(dim=1, keepdim=True) min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2 ** n_bit - 1 + max_int = 2**n_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-6) / max_int assert torch.isnan(scales).sum() == 0 @@ -476,7 +557,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): assert torch.isnan(out).sum() == 0 out = out.to(dtype=torch.int32).reshape(w.shape) - if out.device != torch.device('cpu'): + if out.device != torch.device("cpu"): out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) # Scales and zeros for the same q-group should be contiguous, so we can @@ -490,15 +571,15 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): zeros.reshape(zeros.size(0), zeros.size(1), 1), ], 2, - ).transpose(0, 1).contiguous() + ) + .transpose(0, 1) + .contiguous() ) return out, scales_and_zeros -def _group_quantize_tensor_symmetric( - w, n_bit=4, groupsize=32 -): +def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): # W is of shape [K x N] # We transpose W as Quantization is applied on [N x K] w = w.transpose(0, 1).contiguous() @@ -566,26 +647,47 @@ class QuantizationTestCase(TestCase): def setUp(self): super().setUp() self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] - self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)] - self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] - for _ in range(2)] - self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)] - for _ in range(2)] - self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] - for _ in range(2)] - self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - - self.img_data_dict = {1 : self.img_data_1d, - 2 : self.img_data_2d, - 3 : self.img_data_3d} + self.train_data = [ + [ + torch.rand(2, 5, dtype=torch.float), + torch.randint(0, 1, (2,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)] + self.img_data_2d = [ + [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2) + ] + self.img_data_3d = [ + [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2) + ] + self.img_data_1d_train = [ + [ + torch.rand(2, 3, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_2d_train = [ + [ + torch.rand(1, 3, 10, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_3d_train = [ + [ + torch.rand(1, 3, 5, 5, 5, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + + self.img_data_dict = { + 1: self.img_data_1d, + 2: self.img_data_2d, + 3: self.img_data_3d, + } # Quant types that produce statically quantized ops self.static_quant_types = [QuantType.STATIC, QuantType.QAT] @@ -594,75 +696,92 @@ def setUp(self): def checkNoPrepModules(self, module): r"""Checks the module does not contain child - modules for quantization preparation, e.g. - quant, dequant and observer + modules for quantization preparation, e.g. + quant, dequant and observer """ - self.assertFalse(hasattr(module, 'quant')) - self.assertFalse(hasattr(module, 'dequant')) + self.assertFalse(hasattr(module, "quant")) + self.assertFalse(hasattr(module, "dequant")) def checkNoQconfig(self, module): - r"""Checks the module does not contain qconfig - """ - self.assertFalse(hasattr(module, 'qconfig')) + r"""Checks the module does not contain qconfig""" + self.assertFalse(hasattr(module, "qconfig")) for child in module.children(): self.checkNoQconfig(child) def checkHasPrepModules(self, module): r"""Checks the module contains child - modules for quantization preparation, e.g. - quant, dequant and observer + modules for quantization preparation, e.g. + quant, dequant and observer """ - self.assertTrue(hasattr(module, 'module')) - self.assertTrue(hasattr(module, 'quant')) - self.assertTrue(hasattr(module, 'dequant')) + self.assertTrue(hasattr(module, "module")) + self.assertTrue(hasattr(module, "quant")) + self.assertTrue(hasattr(module, "dequant")) - def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None): + def checkObservers( + self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None + ): r"""Checks the module or module's leaf descendants - have observers in preparation for quantization + have observers in preparation for quantization """ if propagate_qconfig_list is None: propagate_qconfig_list = get_default_qconfig_propagation_list() if prepare_custom_config_dict is None: prepare_custom_config_dict = {} - float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) + float_to_observed_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) # check if a module is a leaf module, ignoring activation_post_process attribute def is_leaf_module(module): submodule_name_count = 0 for name, _ in module.named_children(): - if name != 'activation_post_process': + if name != "activation_post_process": submodule_name_count += 1 return submodule_name_count == 0 - if hasattr(module, 'qconfig') and module.qconfig is not None and \ - ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) - and type(module) in propagate_qconfig_list) or - type(module) in float_to_observed_module_class_mapping.keys()) and \ - not isinstance(module, torch.ao.quantization.DeQuantStub): - self.assertTrue(hasattr(module, 'activation_post_process'), - 'module: ' + str(type(module)) + ' do not have observer') + if ( + hasattr(module, "qconfig") + and module.qconfig is not None + and ( + ( + is_leaf_module(module) + and not isinstance(module, torch.nn.Sequential) + and type(module) in propagate_qconfig_list + ) + or type(module) in float_to_observed_module_class_mapping.keys() + ) + and not isinstance(module, torch.ao.quantization.DeQuantStub) + ): + self.assertTrue( + hasattr(module, "activation_post_process"), + "module: " + str(type(module)) + " do not have observer", + ) # we don't need to check observers for child modules of the # qat modules - if type(module) not in get_default_qat_module_mappings().values() and \ - type(module) not in float_to_observed_module_class_mapping.values() and \ - not isinstance(module, _FusedModule): + if ( + type(module) not in get_default_qat_module_mappings().values() + and type(module) not in float_to_observed_module_class_mapping.values() + and not isinstance(module, _FusedModule) + ): for child in module.children(): if type(child) in [nn.Dropout]: continue - self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict) + self.checkObservers( + child, propagate_qconfig_list, prepare_custom_config_dict + ) def checkQuantDequant(self, mod): r"""Checks that mod has nn.Quantize and - nn.DeQuantize submodules inserted + nn.DeQuantize submodules inserted """ self.assertEqual(type(mod.quant), nnq.Quantize) self.assertEqual(type(mod.dequant), nnq.DeQuantize) def checkWrappedQuantizedLinear(self, mod): r"""Checks that mod has been swapped for an nnq.Linear - module, the bias is qint32, and that the module - has Quantize and DeQuantize submodules + module, the bias is qint32, and that the module + has Quantize and DeQuantize submodules """ self.assertEqual(type(mod.module), nnq.Linear) self.checkQuantDequant(mod) @@ -672,14 +791,14 @@ def checkQuantizedLinear(self, mod): def checkDynamicQuantizedLinear(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ self.assertEqual(type(mod), nnqd.Linear) self.assertEqual(mod._packed_params.dtype, dtype) def checkDynamicQuantizedLinearRelu(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ self.assertEqual(type(mod), nniqd.LinearReLU) self.assertEqual(mod._packed_params.dtype, dtype) @@ -721,25 +840,35 @@ def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.LSTM type - module, the bias is float. + module, the bias is float. """ - wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } self.assertEqual(type(mod), reference_module_type) for packed_params in mod._all_weight_values: - self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) def checkLinear(self, mod): self.assertEqual(type(mod), torch.nn.Linear) def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ - wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } self.assertEqual(type(mod), reference_module_type) - if hasattr(mod, '_all_weight_values'): + if hasattr(mod, "_all_weight_values"): for packed_params in mod._all_weight_values: - self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) def checkScriptable(self, orig_mod, calib_data, check_save_load=False): scripted = torch.jit.script(orig_mod) @@ -770,20 +899,29 @@ def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): scripted_output = test_mod(*inp) self.assertEqual(scripted_output, ref_output) - - def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False, - check=True, eval_mode=True, dynamic=False, qconfig=None): + def checkGraphModeOp( + self, + module, + inputs, + quantized_op, + tracing=False, + debug=False, + check=True, + eval_mode=True, + dynamic=False, + qconfig=None, + ): if debug: - print('Testing:', str(module)) - qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} + print("Testing:", str(module)) + qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} if eval_mode: module = module.eval() if dynamic: - qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig} + qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig} model = get_script_module(module, tracing, inputs[0]).eval() if debug: - print('input graph:', model.graph) + print("input graph:", model.graph) models = {} outputs = {} for debug in [True, False]: @@ -796,31 +934,37 @@ def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=Fa # input data staying constant for comparisons inputs_copy = copy.deepcopy(inputs) models[debug] = quantize_jit( - model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False, - debug=debug) + model, + qconfig_dict, + test_only_eval_fn, + [inputs_copy], + inplace=False, + debug=debug, + ) # make sure it runs outputs[debug] = models[debug](*inputs[0]) if debug: - print('debug graph:', models[True].graph) - print('non debug graph:', models[False].graph) + print("debug graph:", models[True].graph) + print("non debug graph:", models[False].graph) if check: # debug and non-debug option should have the same numerics self.assertEqual(outputs[True], outputs[False]) # non debug graph should produce quantized op - FileCheck().check(quantized_op) \ - .run(models[False].graph) + FileCheck().check(quantized_op).run(models[False].graph) return models[False] def checkGraphModuleNodes( - self, graph_module, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None): - """ Check if GraphModule contains the target node + self, + graph_module, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + ): + """Check if GraphModule contains the target node Args: graph_module: the GraphModule instance we want to check expected_node, expected_node_occurrence, expected_node_list: @@ -831,9 +975,9 @@ def checkGraphModuleNodes( modules = dict(graph_module.named_modules(remove_duplicate=False)) for node in graph_module.graph.nodes: n = None - if node.op == 'call_function' or node.op == 'call_method': + if node.op == "call_function" or node.op == "call_method": n = NodeSpec(node.op, node.target) - elif node.op == 'call_module': + elif node.op == "call_module": n = NodeSpec(node.op, type(modules[node.target])) if n is not None: @@ -844,26 +988,34 @@ def checkGraphModuleNodes( nodes_in_graph[n] = 1 if expected_node is not None: - self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) + - ' not found in the graph module') + self.assertTrue( + expected_node in nodes_in_graph, + "node:" + str(expected_node) + " not found in the graph module", + ) if expected_node_occurrence is not None: for expected_node, occurrence in expected_node_occurrence.items(): if occurrence != 0: self.assertTrue( expected_node in nodes_in_graph, - 'Check failed for node:' + str(expected_node) + - ' not found') + "Check failed for node:" + str(expected_node) + " not found", + ) self.assertTrue( nodes_in_graph[expected_node] == occurrence, - 'Check failed for node:' + str(expected_node) + - ' Expected occurrence:' + str(occurrence) + - ' Found occurrence:' + str(nodes_in_graph[expected_node])) + "Check failed for node:" + + str(expected_node) + + " Expected occurrence:" + + str(occurrence) + + " Found occurrence:" + + str(nodes_in_graph[expected_node]), + ) else: self.assertTrue( expected_node not in nodes_in_graph, - 'Check failed for node:' + str(expected_node) + - ' expected no occurrence but found') + "Check failed for node:" + + str(expected_node) + + " expected no occurrence but found", + ) if expected_node_list is not None: cur_index = 0 @@ -874,20 +1026,21 @@ def checkGraphModuleNodes( cur_index += 1 self.assertTrue( cur_index == len(expected_node_list), - "Check failed for graph:" + - self.printGraphModule(graph_module, print_str=False) + - "Expected ordered list:" + - str(expected_node_list)) + "Check failed for graph:" + + self.printGraphModule(graph_module, print_str=False) + + "Expected ordered list:" + + str(expected_node_list), + ) def printGraphModule(self, graph_module, print_str=True): modules = dict(graph_module.named_modules(remove_duplicate=False)) node_infos = [] for n in graph_module.graph.nodes: - node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) - if n.op == 'call_module': - node_info += ' module type: ' + repr(type(modules[n.target])) + node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) + if n.op == "call_module": + node_info += " module type: " + repr(type(modules[n.target])) node_infos.append(node_info) - str_to_print = '\n'.join(node_infos) + str_to_print = "\n".join(node_infos) if print_str: print(str_to_print) return str_to_print @@ -897,7 +1050,9 @@ def printGraphModule(self, graph_module, print_str=True): def assert_types_for_matched_subgraph_pairs( self, matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], - expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]], + expected_types: dict[ + str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] + ], gm_a: GraphModule, gm_b: GraphModule, ) -> None: @@ -917,16 +1072,16 @@ def assert_types_for_matched_subgraph_pairs( def _get_underlying_op_type( node: Node, gm: GraphModule ) -> Union[Callable, str]: - if node.op == 'call_module': + if node.op == "call_module": mod = getattr(gm, node.target) return type(mod) else: - assert node.op in ('call_function', 'call_method') + assert node.op in ("call_function", "call_method") return node.target self.assertTrue( len(matched_subgraph_pairs) == len(expected_types), - f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}' + f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", ) for k, v in expected_types.items(): expected_types_a, expected_types_b = v @@ -938,14 +1093,16 @@ def _get_underlying_op_type( act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) - types_match = (exp_type_start_a is act_type_start_a) and \ - (exp_type_end_a is act_type_end_a) and \ - (exp_type_start_b is act_type_start_b) and \ - (exp_type_end_b is act_type_end_b) + types_match = ( + (exp_type_start_a is act_type_start_a) + and (exp_type_end_a is act_type_end_a) + and (exp_type_start_b is act_type_start_b) + and (exp_type_end_b is act_type_end_b) + ) self.assertTrue( types_match, - f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, ' - f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}' + f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " + f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", ) def assert_ns_compare_dict_valid( @@ -962,48 +1119,53 @@ def assert_ns_compare_dict_valid( for result_type, layer_data in result_type_to_data.items(): self.assertTrue( len(layer_data) == 2, - f"Layer {layer_name} does not have exactly two model results.") + f"Layer {layer_name} does not have exactly two model results.", + ) model_name_0, model_name_1 = layer_data.keys() for res_idx in range(len(layer_data[model_name_0])): layer_data_0 = layer_data[model_name_0][res_idx] layer_data_1 = layer_data[model_name_1][res_idx] self.assertTrue( - layer_data_0['type'] == layer_data_0['type'], - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.") + layer_data_0["type"] == layer_data_0["type"], + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.", + ) self.assertTrue( - len(layer_data_0['values']) == - len(layer_data_1['values']), - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.") + len(layer_data_0["values"]) == len(layer_data_1["values"]), + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", + ) # F.conv1d weight has rank 3, and toq.conv1d unpacked weight # has rank 4. For now, skip the length check for conv1d only. is_weight_functional_conv1d = ( - result_type == NSSingleResultValuesType.WEIGHT.value and - ( - 'conv1d' in layer_data_0['prev_node_target_type'] or - 'conv1d' in layer_data_1['prev_node_target_type'] + result_type == NSSingleResultValuesType.WEIGHT.value + and ( + "conv1d" in layer_data_0["prev_node_target_type"] + or "conv1d" in layer_data_1["prev_node_target_type"] ) ) if not is_weight_functional_conv1d: - for idx in range(len(layer_data_0['values'])): - values_0 = layer_data_0['values'][idx] - values_1 = layer_data_1['values'][idx] + for idx in range(len(layer_data_0["values"])): + values_0 = layer_data_0["values"][idx] + values_1 = layer_data_1["values"][idx] if isinstance(values_0, torch.Tensor): self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) elif isinstance(values_0, list): values_0 = values_0[0] values_1 = values_1[0] self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) else: - assert isinstance(values_0, tuple), \ - f"unhandled type {type(values_0)}" + assert isinstance( + values_0, tuple + ), f"unhandled type {type(values_0)}" assert len(values_0) == 2 assert len(values_0[1]) == 2 assert values_0[0].shape == values_1[0].shape @@ -1011,80 +1173,91 @@ def assert_ns_compare_dict_valid( assert values_0[1][1].shape == values_1[1][1].shape # verify that ref_node_name is valid - ref_node_name_0 = layer_data_0['ref_node_name'] - ref_node_name_1 = layer_data_1['ref_node_name'] - prev_node_name_0 = layer_data_0['prev_node_name'] - prev_node_name_1 = layer_data_1['prev_node_name'] - if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value: + ref_node_name_0 = layer_data_0["ref_node_name"] + ref_node_name_1 = layer_data_1["ref_node_name"] + prev_node_name_0 = layer_data_0["prev_node_name"] + prev_node_name_1 = layer_data_1["prev_node_name"] + if ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_OUTPUT.value + ): self.assertTrue(ref_node_name_0 == prev_node_name_0) self.assertTrue(ref_node_name_1 == prev_node_name_1) - elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value: + elif ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_INPUT.value + ): self.assertTrue(ref_node_name_0 != prev_node_name_0) self.assertTrue(ref_node_name_1 != prev_node_name_1) def checkGraphModeFxOp( - self, - model, - inputs, - quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - is_reference=False, - print_debug_info=False, - custom_qconfig_dict=None, - prepare_expected_node=None, - prepare_expected_node_occurrence=None, - prepare_expected_node_list=None, - prepare_custom_config=None, - backend_config=None): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node - - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurrences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - is_reference: if True, enables reference mode - print_debug_info: if True, prints debug info - custom_qconfig_dict: overrides default qconfig_dict - prepare_expected_node: same as expected_node, but for prepare - prepare_expected_node_occurrence: same as - expected_node_occurrence, but for prepare - prepare_expected_node_list: same as expected_node_list, but - for prepare - - Returns: - A dictionary with the following structure: - { - "prepared": ..., # the prepared model - "quantized": ..., # the quantized non-reference model - "quantized_reference": ..., # the quantized reference model - "result": ..., # the result for either quantized or - # quantized_reference model depending on the - # is_reference argument - } + self, + model, + inputs, + quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + is_reference=False, + print_debug_info=False, + custom_qconfig_dict=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config=None, + backend_config=None, + ): + """Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurrences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + is_reference: if True, enables reference mode + print_debug_info: if True, prints debug info + custom_qconfig_dict: overrides default qconfig_dict + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare + + Returns: + A dictionary with the following structure: + { + "prepared": ..., # the prepared model + "quantized": ..., # the quantized non-reference model + "quantized_reference": ..., # the quantized reference model + "result": ..., # the result for either quantized or + # quantized_reference model depending on the + # is_reference argument + } """ # TODO: make img_data a single example instead of a list if type(inputs) == list: inputs = inputs[0] if quant_type == QuantType.QAT: - qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qat_qconfig_mapping( + torch.backends.quantized.engine + ) model.train() elif quant_type == QuantType.STATIC: - qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qconfig_mapping( + torch.backends.quantized.engine + ) model.eval() else: qconfig = default_dynamic_qconfig @@ -1098,30 +1271,37 @@ def checkGraphModeFxOp( # overwrite qconfig_dict with custom_qconfig_dict if custom_qconfig_dict is not None: - assert type(custom_qconfig_dict) in (QConfigMapping, dict), \ - 'custom_qconfig_dict should be a QConfigMapping or a dict' + assert type(custom_qconfig_dict) in ( + QConfigMapping, + dict, + ), "custom_qconfig_dict should be a QConfigMapping or a dict" if isinstance(custom_qconfig_dict, QConfigMapping): qconfig_mapping = custom_qconfig_dict else: qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) prepared = prepare( - model, qconfig_mapping, + model, + qconfig_mapping, example_inputs=inputs, prepare_custom_config=prepare_custom_config, - backend_config=backend_config) + backend_config=backend_config, + ) if not quant_type == QuantType.DYNAMIC: prepared(*inputs) if print_debug_info: print() - print('quant type:\n', quant_type) - print('original model:\n', model) + print("quant type:\n", quant_type) + print("original model:\n", model) print() - print('prepared model:\n', prepared) + print("prepared model:\n", prepared) self.checkGraphModuleNodes( - prepared, prepare_expected_node, - prepare_expected_node_occurrence, prepare_expected_node_list) + prepared, + prepare_expected_node, + prepare_expected_node_occurrence, + prepare_expected_node_list, + ) prepared_copy = copy.deepcopy(prepared) qgraph = convert_fx(copy.deepcopy(prepared)) @@ -1134,20 +1314,34 @@ def checkGraphModeFxOp( qgraph_to_check = qgraph_reference if is_reference else qgraph if print_debug_info: print() - print('quantized model:\n', qgraph_to_check) + print("quantized model:\n", qgraph_to_check) self.printGraphModule(qgraph_to_check) print() self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) - return {"prepared": prepared_copy, - "quantized": qgraph_copy, - "quantized_reference": qgraph_reference_copy, - "quantized_output": result, - "quantized_reference_output": result_reference} - - - def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, - set_qconfig, is_emb_bag, dtype=torch.quint8): + qgraph_to_check, + expected_node, + expected_node_occurrence, + expected_node_list, + ) + return { + "prepared": prepared_copy, + "quantized": qgraph_copy, + "quantized_reference": qgraph_reference_copy, + "quantized_output": result, + "quantized_reference_output": result_reference, + } + + def checkEmbeddingSerialization( + self, + qemb, + num_embeddings, + embedding_dim, + indices, + offsets, + set_qconfig, + is_emb_bag, + dtype=torch.quint8, + ): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] @@ -1169,33 +1363,49 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic # Check state dict serialization and torch.save APIs if is_emb_bag: - loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum', dtype=dtype) + loaded_qemb = nnq.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + mode="sum", + dtype=dtype, + ) else: - loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) + loaded_qemb = nnq.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype + ) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) - self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight), - embedding_unpack(loaded_qemb._packed_params._packed_weight)) - + self.assertEqual( + embedding_unpack(qemb._packed_params._packed_weight), + embedding_unpack(loaded_qemb._packed_params._packed_weight), + ) # Test JIT serialization self.checkScriptable(qemb, [inputs], check_save_load=True) # Test from_float call if is_emb_bag: - float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + float_embedding = torch.nn.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) else: - float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + float_embedding = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) if set_qconfig: - float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, - qscheme=torch.per_channel_affine_float_qparams, - ch_axis=0) - float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer, - weight=float_qparams_observer) + float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 + ) + float_embedding.qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=float_qparams_observer + ) prepare_dynamic(float_embedding) @@ -1211,6 +1421,7 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic self.assertTrue(expected_name in str(q_embeddingbag)) + class QuantizationLiteTestCase(QuantizationTestCase): def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): # Creates quantized model for testing mobile script modules @@ -1223,9 +1434,7 @@ def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): return model - def _compare_script_and_mobile(self, - model: torch.nn.Module, - input: torch.Tensor): + def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor): # Compares the numerical outputs for script and lite modules qengine = "qnnpack" with override_quantized_engine(qengine): @@ -1236,18 +1445,28 @@ def _compare_script_and_mobile(self, for retry in range(1, max_retry + 1): # retries `max_retry` times; breaks iff succeeds else throws exception try: - buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) + buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter() + ) buffer.seek(0) mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) - torch.testing.assert_close(script_module_result, mobile_module_result) + torch.testing.assert_close( + script_module_result, mobile_module_result + ) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_close(script_module_result, mobile_module_forward_result) - - mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_close(script_module_result, mobile_module_run_method_result) + torch.testing.assert_close( + script_module_result, mobile_module_forward_result + ) + + mobile_module_run_method_result = mobile_module.run_method( + "forward", input + ) + torch.testing.assert_close( + script_module_result, mobile_module_run_method_result + ) except AssertionError as e: if retry == max_retry: raise e @@ -1260,6 +1479,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. """ + _MAP_TO_FX_TRACED_OPS = { torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, @@ -1297,6 +1517,7 @@ def _test_quantizer( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() if is_qat: @@ -1337,6 +1558,7 @@ def _test_quantizer( m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() node_occurrence = {} for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): @@ -1344,7 +1566,8 @@ def _test_quantizer( node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] if training_ir_node_occurrence is not None: node_occurrence = { - ns.call_function(k): v for k, v in training_ir_node_occurrence.items() + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) @@ -1355,10 +1578,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: @@ -1377,14 +1597,18 @@ def forward(self, x): return self.linear(x) quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel) + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() return self._quantize(m, quantizer, example_inputs) + # Below are a series of toy models to use in testing quantization + class SingleLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1397,8 +1621,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class AnnotatedSingleLayerLinearModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) @@ -1410,8 +1635,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerLinearDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) @@ -1423,6 +1649,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearAddModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -1438,38 +1665,41 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class RNNDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRU': + if mod_type == "GRU": self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) - if mod_type == 'LSTM': + if mod_type == "LSTM": self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class RNNCellDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRUCell': + if mod_type == "GRUCell": self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) - if mod_type == 'LSTMCell': + if mod_type == "LSTMCell": self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) - if mod_type == 'RNNReLU': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float) - if mod_type == 'RNNTanh': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float) + if mod_type == "RNNReLU": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) + if mod_type == "RNNTanh": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class LSTMwithHiddenDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) @@ -1478,6 +1708,7 @@ def forward(self, x, hid): x, hid = self.lstm(x, hid) return x, hid + class ConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1490,6 +1721,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvTransposeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1502,6 +1734,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1519,6 +1752,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvTransposeModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1536,6 +1770,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1550,6 +1785,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1569,6 +1805,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvBnReLUModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1585,8 +1822,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnReLUModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) @@ -1606,13 +1844,18 @@ def forward(self, x): def fuse_model(self): # TODO: remove this check and define two fuse_modules function on this module if self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv", "bn", "relu"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv", "bn", "relu"]], inplace=True + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1627,6 +1870,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1641,6 +1885,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearModelWithSubmodule(nn.Module): def __init__(self) -> None: super().__init__() @@ -1655,6 +1900,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.subm.get_example_inputs() + class AnnotatedTwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1670,6 +1916,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ActivationsTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1686,6 +1933,7 @@ def forward(self, x): x = self.dequant(x) return x + class LinearReluModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1716,6 +1964,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearReluAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1734,6 +1983,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearBnLeakyReluModel(torch.nn.Module): def __init__(self, with_bn=True): super().__init__() @@ -1752,6 +2002,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearTanhModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1766,13 +2017,16 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ConvBnAddReluModel(torch.nn.Module): - def __init__(self, - with_bn=True, - with_relu=True, - left_conv=True, - two_conv=True, - use_torch_add=True): + def __init__( + self, + with_bn=True, + with_relu=True, + left_conv=True, + two_conv=True, + use_torch_add=True, + ): super().__init__() self.conv = nn.Conv2d(5, 5, (2, 2)) self.conv2 = nn.Conv2d(5, 5, (2, 2)) @@ -1826,6 +2080,7 @@ def forward(self, x1, x2): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) + # TODO: self.fc should be self.conv class ConvReluModel(torch.nn.Module): def __init__(self) -> None: @@ -1840,6 +2095,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluConvModel(torch.nn.Module): def __init__(self) -> None: @@ -1857,6 +2113,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluAddModel(torch.nn.Module): def __init__(self) -> None: @@ -1876,6 +2133,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class NormalizationTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1897,6 +2155,7 @@ def forward(self, x): x = self.instance_norm3d(x.unsqueeze(-1)) return x + class NestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1910,6 +2169,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedNestedModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1918,7 +2178,7 @@ def __init__(self, qengine): self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) - if qengine == 'fbgemm': + if qengine == "fbgemm": self.sub2.fc1.qconfig = default_per_channel_qconfig else: self.sub2.fc1.qconfig = default_qconfig @@ -1929,6 +2189,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedSubNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1944,6 +2205,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedCustomConfigNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1953,12 +2215,11 @@ def __init__(self) -> None: self.fc3.qconfig = default_qconfig self.sub2.qconfig = default_qconfig - custom_options = { - 'dtype': torch.quint8, - 'qscheme': torch.per_tensor_affine - } - custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options), - weight=default_weight_observer) + custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} + custom_qconfig = QConfig( + activation=default_observer.with_args(**custom_options), + weight=default_weight_observer, + ) self.sub2.fc1.qconfig = custom_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) @@ -1970,6 +2231,7 @@ def forward(self, x): x = self.fc3(x) return x + class QuantSubModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1985,6 +2247,7 @@ def forward(self, x): x = self.fc3(x) return x + class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2004,14 +2267,14 @@ def fuse_modules(self): if idx >= len(named_children) - 1: break if isinstance(named_children[idx + 1][1], torch.nn.ReLU): - fusable_layers.append([current_name, - named_children[idx + 1][0]]) + fusable_layers.append([current_name, named_children[idx + 1][0]]) # TODO: remove this check and define two fuse_modules function on this module if self.training: torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) else: torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) + class FunctionalLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2024,6 +2287,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2036,6 +2300,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class TwoLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2050,6 +2315,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2065,6 +2331,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2078,6 +2345,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear.get_example_inputs() + class FunctionalLinearReluLinearModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2094,6 +2362,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalConv2d(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2105,11 +2374,20 @@ def __init__(self) -> None: self.groups = 1 def forward(self, x): - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class SingleLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2122,6 +2400,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class TwoLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2136,6 +2415,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class FunctionalConvReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2149,6 +2429,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv.get_example_inputs() + class FunctionalConvReluConvModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2165,10 +2446,12 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class SkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self) -> None: super().__init__() self.sub = InnerModule() @@ -2180,10 +2463,12 @@ def forward(self, x): def fuse_modules(self): self.sub.fuse_modules() + class AnnotatedSkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) @@ -2198,9 +2483,10 @@ def forward(self, x): def fuse_modules(self): self.sub.module.fuse_modules() + class QuantStubModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self) -> None: super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") @@ -2213,9 +2499,10 @@ def forward(self, x): x = self.fc(x) return self.dequant(x) + class ManualLinearQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2230,9 +2517,10 @@ def forward(self, x): x = self.fc2(x) return self.dequant(x) + class ManualDropoutQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2247,9 +2535,10 @@ def forward(self, x): x = self.dropout(x) return self.dequant(x) + class ManualLinearDynamicQATModel(torch.nn.Module): - r"""A Module that uses a dynamic QAT by default. - """ + r"""A Module that uses a dynamic QAT by default.""" + def __init__(self, qconfig=None): super().__init__() self.qconfig = qconfig or default_dynamic_qat_qconfig @@ -2261,13 +2550,19 @@ def forward(self, x): x = self.fc2(x) return x + class ManualConvLinearQATModel(torch.nn.Module): r"""A module with manually inserted `QuantStub` and `DeQuantStub` and contains both linear and conv modules """ + def __init__(self, qconfig=None): super().__init__() - self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + self.qconfig = ( + qconfig + if qconfig + else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + ) self.quant = QuantStub() self.dequant = DeQuantStub() self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) @@ -2282,30 +2577,38 @@ def forward(self, x): x = self.fc2(x) return self.dequant(x) + class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. Supported only with qnnpack. """ + def __init__(self) -> None: super().__init__(default_symmetric_qnnpack_qat_qconfig) + class ManualEmbeddingBagLinear(nn.Module): def __init__(self) -> None: super().__init__() - self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') + self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") self.emb.qconfig = default_embedding_qat_qconfig self.quant = QuantStub() self.dequant = DeQuantStub() self.linear = nn.Linear(12, 1).to(dtype=torch.float) self.qconfig = get_default_qat_qconfig("qnnpack") - def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, - per_sample_weights: Optional[torch.Tensor] = None): + def forward( + self, + input: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + per_sample_weights: Optional[torch.Tensor] = None, + ): x = self.emb(input, offsets, per_sample_weights) x = self.quant(x) x = self.linear(x) return self.dequant(x) + class DeFusedEmbeddingBagLinear(nn.Module): r"""A module to simulate QAT embedding bag with a linear layer, this module uses a separate embedding and bagging op, similar @@ -2313,6 +2616,7 @@ class DeFusedEmbeddingBagLinear(nn.Module): https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html """ + def __init__(self) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) @@ -2329,6 +2633,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x = self.linear(x) return self.dequant(x) + class SubModelForFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2350,6 +2655,7 @@ def __init__(self) -> None: def forward(self, x): return self.relu(self.conv(x)) + class ModelForFusion(nn.Module): def __init__(self, qconfig): super().__init__() @@ -2396,14 +2702,14 @@ def forward(self, x): y = self.dequant(y) return x + class ConvBNReLU(nn.Sequential): def __init__(self) -> None: super().__init__( - nn.Conv2d(3, 3, 1, 1, bias=False), - nn.BatchNorm2d(3), - nn.ReLU(inplace=False) + nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) ) + class ModelWithSequentialFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2428,6 +2734,7 @@ def forward(self, x): x = self.dequant(x) return x + class ModelForFusionWithBias(nn.Module): def __init__(self) -> None: super().__init__() @@ -2449,6 +2756,7 @@ def forward(self, x): x = self.dequant(x) return x + class ModelForLinearBNFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2460,6 +2768,7 @@ def __init__(self) -> None: def forward(self, x): return self.bn(self.fc(x)) + class DummyObserver(torch.nn.Module): def calculate_qparams(self): return 1.0, 0 @@ -2543,9 +2852,14 @@ def forward(self, x): def fuse_model(self): # TODO: remove this check and define two fuse_model function on this module if self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + class ModelMultipleOps(torch.nn.Module): def __init__(self) -> None: @@ -2578,6 +2892,7 @@ def forward(self, x): out = self.fc(out) return out + # Model to ensure consistency of fake quant with true quant # Average pooling and mean operations are not modelled # accurately with fake-quant so this model does not @@ -2612,15 +2927,22 @@ def forward(self, x): out = self.fc(out) return out + class EmbeddingBagModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + self.emb = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) def forward(self, indices, offsets, per_sample_weights): return self.emb(indices, offsets, per_sample_weights) + class EmbeddingModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2629,6 +2951,7 @@ def __init__(self) -> None: def forward(self, indices): return self.emb(indices) + class EmbeddingWithStaticLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2647,9 +2970,11 @@ def forward(self, indices, offsets, linear_in): features = torch.cat([fc] + [emb], dim=1) return features -class DenseTopMLP(nn.Module): - def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None: +class DenseTopMLP(nn.Module): + def __init__( + self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out + ) -> None: super().__init__() self.dense_mlp = nn.Sequential( @@ -2671,16 +2996,18 @@ def forward( out = self.top_mlp(features) return out + # thin wrapper around embedding bag, because tracing inside nn.Embedding # bag is not supported at the moment and this is top level class EmbBagWrapper(nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() - self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum') + self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") def forward(self, indices, offsets): return self.emb_bag(indices, offsets) + class SparseNNModel(nn.Module): _NUM_EMBEDDINGS = 10 _EMBEDDING_DIM = 5 @@ -2695,8 +3022,12 @@ def __init__(self) -> None: self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) self.dense_top = DenseTopMLP( - self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN, - self._TOP_OUT_OUT) + self._DENSE_DIM, + self._DENSE_OUTPUT, + self._EMBEDDING_DIM, + self._TOP_OUT_IN, + self._TOP_OUT_OUT, + ) def forward( self, @@ -2704,12 +3035,12 @@ def forward( sparse_offsets: torch.Tensor, dense: torch.Tensor, ) -> torch.Tensor: - sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) out = self.dense_top(sparse_feature, dense) return out + class TestHelperModules: class ControlFlow(torch.nn.Module): def forward( @@ -2719,7 +3050,6 @@ def forward( pred2: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: - def true_nested(y: torch.Tensor) -> torch.Tensor: y = y + y y = torch.mm(y, y) @@ -2736,7 +3066,10 @@ def false_fn(x: torch.Tensor, _) -> torch.Tensor: return x.cos() def map_fn( - x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor + x: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, ) -> torch.Tensor: x = x.cos() y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) @@ -2747,7 +3080,12 @@ def map_fn( return control_flow.map(map_fn, xs, pred1, pred2, y) def example_inputs(self): - return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),) + return ( + torch.ones(2, 2), + torch.tensor([False]), + torch.tensor([False]), + torch.ones(2, 2), + ) class Conv2dPropAnnotaton(torch.nn.Module): def __init__(self) -> None: @@ -3029,16 +3367,20 @@ def forward(self, x): x = self.relu(self.fc(x)) return x + def _generate_qdq_quantized_model( mod, inputs, is_qat=False, is_dynamic=False, quantizer=None ): - def get_default_quantizer(is_qat, is_dynamic, inputs): - has_xpu = any(isinstance(input, torch.Tensor) and input.device.type == "xpu" - for input in inputs) + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) if has_xpu: quantizer = XPUInductorQuantizer() - assert (not is_qat) and (not is_dynamic), "QAT and dynamic quantization is not supported at XPU backend currently" + assert (not is_qat) and ( + not is_dynamic + ), "QAT and dynamic quantization is not supported at XPU backend currently" quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) else: quantizer = X86InductorQuantizer() @@ -3051,12 +3393,11 @@ def get_default_quantizer(is_qat, is_dynamic, inputs): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training( - mod, - inputs, - ).module() + export_model = export_for_training(mod, inputs, strict=True).module() quantizer = ( - quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs) + quantizer + if quantizer + else get_default_quantizer(is_qat, is_dynamic, inputs) ) prepare_model = ( prepare_qat_pt2e(export_model, quantizer) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index afbd569b34ba..01232af5d0d5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2813,7 +2813,10 @@ def _to_number(self, number_like, *, id): elif isinstance(number_like, Enum): return int(number_like) # type: ignore[call-overload] else: - return super()._to_number(number_like, id=id) + number = super()._to_number(number_like, id=id) + if type(number) not in self._TYPE_TO_DTYPE.keys(): + self._inputs_not_supported() + return number class TensorOrArrayPair(TensorLikePair): diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index c0ce944c641d..cbdb601af614 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -552,8 +552,16 @@ def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) expected = tree_map(fwAD.unpack_dual, expected) - expected_primals = tree_map(lambda x: x.primal, expected) - expected_tangents = tree_map(lambda x: x.tangent, expected) + expected_primals = tree_map( + lambda x: x.primal, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + expected_tangents = tree_map( + lambda x: x.tangent, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) # Permutations of arg and kwargs in CCT. for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): @@ -586,7 +594,15 @@ def unwrap(e): return e.elem if isinstance(e, CCT) else e actual = tree_map(fwAD.unpack_dual, actual) - actual_primals = tree_map(lambda x: unwrap(x.primal), actual) - actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual) + actual_primals = tree_map( + lambda x: unwrap(x.primal), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + actual_tangents = tree_map( + lambda x: unwrap(x.tangent), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) assert_equal_fn(actual_primals, expected_primals, equal_nan=True) assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 3f4a24a1ffb1..db9f9e70dee1 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -96,7 +96,7 @@ import torchvision HAS_TORCHVISION = True -except ImportError: +except Exception: # Covering both ImportError and RuntimeError HAS_TORCHVISION = False if sys.platform == "win32": @@ -8310,50 +8310,14 @@ def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self): def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self): self._test_compute_bucket_assignment_by_size(use_logger=True) - def _determine_expected_error_verify_model_across_rank( - self, group_to_use, diff_num_params=False - ): - # When running with NCCL backend, we don't expect an error on rank 0, - # rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When - # running with Gloo or with debug mode wrapper, we expect the error - # to be caught inline. - # All ranks report same error when there is a # of parameter - # mismatch since we use allgather in the impl. - if diff_num_params: - expected_err = "DDP expects same model across all ranks" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - return ctx, expected_err - - is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL - if self.rank == 0: - if ( - dist.get_backend(group_to_use) == dist.Backend.NCCL - and not is_detail_dbg_mode - ): - expected_err = "caught collective operation timeout" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - else: - expected_err = None - ctx = self.assertRaises(RuntimeError) - else: - expected_err = "appears not to match" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - return ctx, expected_err - def _test_verify_model_across_rank(self, use_logger): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=5) ) torch.cuda.set_device(self.rank) - ctx, expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use - ) # Create a valid model. The constructor initializes the logger that we use later. net = EmbeddingNetDifferentParams(0) @@ -8371,7 +8335,8 @@ def _test_verify_model_across_rank(self, use_logger): net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1) # if we pass a logger we can verify that it was logged - with ctx: + caught = 0 + try: if use_logger: _verify_param_shape_across_processes( net.process_group, list(net.parameters()), net.logger @@ -8380,18 +8345,13 @@ def _test_verify_model_across_rank(self, use_logger): _verify_param_shape_across_processes( net.process_group, list(net.parameters()) ) - # Should only be run by rank 0, and blocking_wait catches and - # reports exception. - dist.barrier(group_to_use) + except Exception: + caught = 1 - # We don't check when self.rank != 0 because the logger doesn't log - # the error "Caught collective operation" as that is not thrown in the reducer. - if use_logger and self.rank != 0: - verify_ddp_error_logged(net, expected_err) - - # Perform gloo-based barrier to ensure one rank doesn't exit test - # early which causes failure with Barrier.sync. - dist.barrier(group_gloo) + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_but_pass_in_sandcastle_if( @@ -8409,20 +8369,19 @@ def test_verify_model_across_rank_with_logger(self): def test_verify_model_across_rank_without_logger(self): self._test_verify_model_across_rank(use_logger=False) - def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo): - with ctx: + def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo): + caught = 0 + try: net = torch.nn.parallel.DistributedDataParallel( net.to(self.rank), device_ids=[self.rank], process_group=ddp_group ) - # Should only be run by rank 0, and blocking_wait catches and - # reports exception. - dist.barrier(ddp_group) - - # can't use verify_ddp_error_logged here because net was never properly constructed + except Exception: + caught = 1 - # Perform gloo-based barrier to ensure one rank doesn't exit test - # early which causes failure with Barrier.sync. - dist.barrier(group_gloo) + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_but_pass_in_sandcastle_if( @@ -8433,21 +8392,15 @@ def test_ddp_model_diff_shape_across_ranks(self): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use - ) # Creates network with different sized embedding table on different # ranks. This should throw an error during DDP init. net = EmbeddingNetDifferentParams(self.rank) self._run_test_ddp_model_with_diff_params( - ctx, net, group_to_use, group_gloo + net, group_to_use, group_gloo ) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @@ -8459,16 +8412,10 @@ def test_ddp_model_diff_num_params_across_ranks(self): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use, diff_num_params=True - ) # Creates network with diff # of param across ranks, reducer should # recognize this and throw appropriate error. @@ -8477,7 +8424,6 @@ def test_ddp_model_diff_num_params_across_ranks(self): ) self._run_test_ddp_model_with_diff_params( - ctx, net, group_to_use, group_gloo, diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 1501a3bfcb36..4461a62bbe57 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -210,6 +210,12 @@ def maybe_skip_size_asserts(op): else: return contextlib.nullcontext() +def get_func_call() -> str: + return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call(" + +def get_kernel_launch() -> str: + return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run(" + def clone_preserve_strides_offset(x, device=None): if not isinstance(x, torch.Tensor): return x diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index 26be0b5255ef..822e664270db 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -2327,13 +2327,6 @@ def make_input(): torch_opinfo_name="linalg.vector_norm", supports_out=True, op_db=op_db, - skips=( - # FIXME: sum reduces all dimensions when dim=[] - DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), - DecorateInfo( - unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" - ), - ), ), PythonRefInfo( "_refs.linalg.matrix_norm", diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 433a518feb15..608a6f14389b 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -117,6 +117,32 @@ def add_kernel_autotuned( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=[], + ) + @triton.jit + def sub_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x - y + tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2), diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index b8d869c1c802..028c21a84bc4 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -23,7 +23,15 @@ from optree import PyTreeSpec as TreeSpec # direct import for type annotations import torch.utils._pytree as python_pytree -from torch.utils._pytree import KeyEntry as KeyEntry +from torch.utils._pytree import ( + is_namedtuple as is_namedtuple, + is_namedtuple_class as is_namedtuple_class, + is_namedtuple_instance as is_namedtuple_instance, + is_structseq as is_structseq, + is_structseq_class as is_structseq_class, + is_structseq_instance as is_structseq_instance, + KeyEntry as KeyEntry, +) __all__ = [ @@ -39,6 +47,7 @@ "keystr", "key_get", "register_pytree_node", + "tree_is_leaf", "tree_flatten", "tree_flatten_with_path", "tree_unflatten", @@ -58,6 +67,12 @@ "treespec_dumps", "treespec_loads", "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", ] diff --git a/torch/utils/_device.py b/torch/utils/_device.py index d7903fc3b465..e16505791b9d 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -24,7 +24,6 @@ def _device_constructors(): torch.fft.fftfreq, torch.fft.rfftfreq, torch.full, - torch.fill, torch.hamming_window, torch.hann_window, torch.kaiser_window, @@ -33,7 +32,6 @@ def _device_constructors(): torch.nested.nested_tensor, # This function doesn't actually take a device argument # torch.normal, - torch.ones, torch.rand, torch.randn, torch.randint, @@ -47,14 +45,12 @@ def _device_constructors(): torch.sparse_bsc_tensor, torch.tril_indices, torch.triu_indices, - torch.vander, torch.zeros, torch.asarray, # weird ones torch.tensor, torch.as_tensor, - torch.scalar_tensor, - torch.asarray, + torch.scalar_tensor } # NB: This is directly called from C++ in torch/csrc/Device.cpp diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 857ea1aab080..9b5d472321e5 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -31,14 +31,17 @@ Any, Callable, cast, + ClassVar, + Final, Generic, + NoReturn, Optional, overload, Protocol, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple +from typing_extensions import deprecated, NamedTuple, Self __all__ = [ @@ -54,6 +57,7 @@ "keystr", "key_get", "register_pytree_node", + "tree_is_leaf", "tree_flatten", "tree_flatten_with_path", "tree_unflatten", @@ -73,6 +77,12 @@ "treespec_dumps", "treespec_loads", "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", ] @@ -195,6 +205,10 @@ def register_pytree_node( ) -> None: """Register a container-like type as pytree node. + Note: + :func:`register_dataclass` is a simpler way of registering a container-like + type as a pytree node. + Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -255,14 +269,34 @@ def register_pytree_node( _cxx_pytree_pending_imports.append((args, kwargs)) -def register_dataclass(cls: type[Any]) -> None: - """Registers a ``dataclasses.dataclass`` type as a pytree node. +def register_dataclass( + cls: type[Any], + *, + field_names: Optional[list[str]] = None, + drop_field_names: Optional[list[str]] = None, + serialized_type_name: Optional[str] = None, +) -> None: + """ + Registers a type that has the semantics of a ``dataclasses.dataclass`` type + as a pytree node. This is a simpler API than :func:`register_pytree_node` for registering - a dataclass. + a dataclass or a custom class with the semantics of a dataclass. Args: - cls: the dataclass type to register + cls: The python type to register. The class must have the semantics of a + dataclass; in particular, it must be constructed by passing the fields + in. + field_names (Optional[List[str]]): A list of field names that correspond + to the **non-constant data** in this class. This list must contain + all the fields that are used to initialize the class. This argument + is optional if ``cls`` is a dataclass, in which case the fields will + be taken from ``dataclasses.fields()``. + drop_field_names (Optional[List[str]]): A list of field names that + should not be included in the pytree. + serialized_type_name: A keyword argument used to specify the fully + qualified name used when serializing the tree spec. This is only + needed for serializing the treespec in torch.export. Example: @@ -283,11 +317,67 @@ def register_dataclass(cls: type[Any]) -> None: >>> assert torch.allclose(point.y, torch.tensor(2)) """ - import torch.export + drop_field_names = drop_field_names or [] + + if not dataclasses.is_dataclass(cls): + if field_names is None: + raise ValueError( + "field_names must be specified with a list of all fields used to " + f"initialize {cls}, as it is not a dataclass." + ) + elif field_names is None: + field_names = [f.name for f in dataclasses.fields(cls) if f.init] + else: + dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init} + dataclass_init_fields.difference_update(drop_field_names) + + if dataclass_init_fields != set(field_names): + error_msg = "field_names does not include all dataclass fields.\n" + + if missing := dataclass_init_fields - set(field_names): + error_msg += ( + f"Missing fields in `field_names`: {missing}. If you want " + "to include these fields in the pytree, please add them " + "to `field_names`, otherwise please add them to " + "`drop_field_names`.\n" + ) + + if unexpected := set(field_names) - dataclass_init_fields: + error_msg += ( + f"Unexpected fields in `field_names`: {unexpected}. " + "Please remove these fields, or add them to `drop_field_names`.\n" + ) + + raise ValueError(error_msg) + + def _flatten_fn(obj: Any) -> tuple[list[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for name in field_names: + val = getattr(obj, name) + if val is not None: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] - # Eventually we should move the export code here. It is not specific to export, - # aside from the serialization pieces. - torch.export.register_dataclass(cls) + def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + _private_register_pytree_node( + cls, + _flatten_fn, + _unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=_flatten_fn_with_keys, + ) CONSTANT_NODES: set[type] = set() @@ -573,6 +663,90 @@ def get(self, obj: Any) -> Any: return getattr(obj, self.name) +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of namedtuple or a subclass of namedtuple.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_namedtuple_class(cls) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_class(cls: type) -> bool: + """Return whether the class is a subclass of namedtuple.""" + return ( + isinstance(cls, type) + and issubclass(cls, tuple) + and isinstance(getattr(cls, "_fields", None), tuple) + and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined] + and callable(getattr(cls, "_make", None)) + and callable(getattr(cls, "_asdict", None)) + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_instance(obj: object) -> bool: + """Return whether the object is an instance of namedtuple.""" + return is_namedtuple_class(type(obj)) + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +class structseq(tuple[_T_co, ...]): + """A generic type stub for CPython's ``PyStructSequence`` type.""" + + __slots__: ClassVar[tuple[()]] = () + + n_fields: Final[int] # type: ignore[misc] + n_sequence_fields: Final[int] # type: ignore[misc] + n_unnamed_fields: Final[int] # type: ignore[misc] + + def __init_subclass__(cls) -> NoReturn: + """Prohibit subclassing.""" + raise TypeError("type 'structseq' is not an acceptable base type") + + def __new__( + cls: type[Self], + sequence: Iterable[_T_co], + dict: dict[str, Any] = ..., + ) -> Self: + raise NotImplementedError + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_structseq_class(cls) + + +# Set if the type allows subclassing (see CPython's Include/object.h) +Py_TPFLAGS_BASETYPE: int = 1 << 10 + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_class(cls: type) -> bool: + """Return whether the class is a class of PyStructSequence.""" + return ( + isinstance(cls, type) + # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)` + and cls.__bases__ == (tuple,) + # Check PyStructSequence members + and isinstance(getattr(cls, "n_fields", None), int) + and isinstance(getattr(cls, "n_sequence_fields", None), int) + and isinstance(getattr(cls, "n_unnamed_fields", None), int) + # Check the type does not allow subclassing + and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_instance(obj: object) -> bool: + """Return whether the object is an instance of PyStructSequence.""" + return is_structseq_class(type(obj)) + + def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: return list(d), None @@ -807,37 +981,72 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: ) -STANDARD_DICT_TYPES: frozenset[type] = frozenset( - {dict, OrderedDict, defaultdict}, -) +STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) BUILTIN_TYPES: frozenset[type] = frozenset( - {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] + { + tuple, + list, + dict, + namedtuple, # type: ignore[arg-type] + OrderedDict, + defaultdict, + deque, + }, ) -# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +@deprecated( + "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. " + "Please use torch.utils._pytree.is_namedtuple_instance instead.", + category=FutureWarning, +) def _is_namedtuple_instance(tree: Any) -> bool: - typ = type(tree) - bases = typ.__bases__ - if len(bases) != 1 or bases[0] != tuple: - return False - fields = getattr(typ, "_fields", None) - if not isinstance(fields, tuple): - return False - return all(type(entry) == str for entry in fields) + return is_namedtuple_instance(tree) def _get_node_type(tree: Any) -> Any: - if _is_namedtuple_instance(tree): + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): return namedtuple - return type(tree) + return node_type # A leaf is defined as anything that is not a Node. +def tree_is_leaf( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + """Check if a pytree is a leaf. + + >>> tree_is_leaf(1) + True + >>> tree_is_leaf(None) + True + >>> tree_is_leaf([1, 2, 3]) + False + >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) + True + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + False + """ + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + +@deprecated( + "torch.utils._pytree._is_leaf is private and will be removed in a future release. " + "Please use torch.utils._pytree.tree_is_leaf instead.", + category=FutureWarning, +) def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: - return (is_leaf is not None and is_leaf(tree)) or _get_node_type( - tree - ) not in SUPPORTED_NODES + return tree_is_leaf(tree, is_leaf=is_leaf) # A TreeSpec represents the structure of a pytree. It holds: @@ -1040,7 +1249,7 @@ def tree_flatten( """ def helper(node: PyTree, leaves: list[Any]) -> TreeSpec: - if _is_leaf(node, is_leaf=is_leaf): + if tree_is_leaf(node, is_leaf=is_leaf): leaves.append(node) return _LEAF_SPEC @@ -1074,7 +1283,7 @@ def tree_iter( is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> Iterable[Any]: """Get an iterator over the leaves of a pytree.""" - if _is_leaf(tree, is_leaf=is_leaf): + if tree_is_leaf(tree, is_leaf=is_leaf): yield tree else: node_type = _get_node_type(tree) @@ -1520,7 +1729,7 @@ def _broadcast_to_and_flatten( ) -> Optional[list[Any]]: assert isinstance(treespec, TreeSpec) - if _is_leaf(tree, is_leaf=is_leaf): + if tree_is_leaf(tree, is_leaf=is_leaf): return [tree] * treespec.num_leaves if treespec.is_leaf(): return None diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 33b4e6e0652d..60e6b37f1340 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -264,6 +264,10 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.log2({self._print(expr.args[0])})" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -351,6 +355,10 @@ def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT # use std::pow, that operates on floats def _print_PowByNatural(self, expr: sympy.Expr) -> str: + # Implement the special-case of 2**x for now + base, exp = expr.args + if base == 2: + return f"(1 << ({self._print(exp)}))" raise NotImplementedError( f"_print_PowByNatural not implemented for {type(self)}" ) @@ -465,6 +473,9 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: return f"std::sqrt({self._print(expr.args[0])})" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + return f"std::log2({self._print(expr.args[0])})" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 # TODO: dispatch to llrint depending on index type diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 4d4e115f67b2..1ba4891ebb1c 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -382,7 +382,10 @@ def check_compiler_ok_for_platform(compiler: str) -> bool: # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag env = os.environ.copy() env['LC_ALL'] = 'C' # Don't localize output - version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + try: + version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + except subprocess.CalledProcessError: + version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) if IS_LINUX: # Check for 'gcc' or 'g++' for sccache wrapper pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) @@ -445,13 +448,17 @@ def get_compiler_abi_compatibility_and_version(compiler) -> tuple[bool, TorchVer warnings.warn(f'Error checking compiler version for {compiler}: {error}') return (False, TorchVersion('0.0.0')) - if tuple(map(int, version)) >= minimum_required_version: - return (True, TorchVersion('.'.join(version))) + # convert alpha-numeric string to numeric string + # amdclang++ returns str like 0.0.0git, others return 0.0.0 + numeric_version = [re.sub(r'\D', '', v) for v in version] + + if tuple(map(int, numeric_version)) >= minimum_required_version: + return (True, TorchVersion('.'.join(numeric_version))) - compiler = f'{compiler} {".".join(version)}' + compiler = f'{compiler} {".".join(numeric_version)}' warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) - return (False, TorchVersion('.'.join(version))) + return (False, TorchVersion('.'.join(numeric_version))) def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None: @@ -798,6 +805,7 @@ def unix_wrap_ninja_compile(sources, if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink']) + cuda_dlink_post_cflags = [shlex.quote(f) for f in cuda_dlink_post_cflags] else: cuda_dlink_post_cflags = None @@ -2139,7 +2147,9 @@ def _jit_compile(name, def _get_hipcc_path(): if IS_WINDOWS: - return _join_rocm_home('bin', 'hipcc.bat') + # mypy thinks ROCM_VERSION is None but it will never be None here + hipcc_exe = 'hipcc.exe' if ROCM_VERSION >= (6, 4) else 'hipcc.bat' # type: ignore[operator] + return _join_rocm_home('bin', hipcc_exe) else: return _join_rocm_home('bin', 'hipcc') diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 66a371085b39..15a71c7d7f94 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ functions to be run in multiprocessing. E.g., the data loading worker loop is in `./_utils/worker.py`. """ +from __future__ import annotations import functools import itertools @@ -14,8 +15,8 @@ import queue import threading import warnings -from collections.abc import Iterable -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self import torch import torch.distributed as dist @@ -37,6 +38,9 @@ ) +if TYPE_CHECKING: + from collections.abc import Iterable + __all__ = [ "DataLoader", "get_worker_info", @@ -233,7 +237,7 @@ class DataLoader(Generic[_T_co]): sampler: Union[Sampler, Iterable] pin_memory_device: str prefetch_factor: Optional[int] - _iterator: Optional["_BaseDataLoaderIter"] + _iterator: Optional[_BaseDataLoaderIter] __initialized = False def __init__( @@ -256,7 +260,7 @@ def __init__( persistent_workers: bool = False, pin_memory_device: str = "", in_order: bool = True, - ): + ) -> None: torch._C._log_api_usage_once("python.data_loader") if num_workers < 0: @@ -416,7 +420,7 @@ def __init__( torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] - def _get_iterator(self) -> "_BaseDataLoaderIter": + def _get_iterator(self) -> _BaseDataLoaderIter: if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: @@ -475,9 +479,7 @@ def __setattr__(self, attr, val): super().__setattr__(attr, val) - # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up - # since '_BaseDataLoaderIter' references 'DataLoader'. - def __iter__(self) -> "_BaseDataLoaderIter": + def __iter__(self) -> _BaseDataLoaderIter: # When using a single worker the returned iterator should be # created everytime to avoid resetting its state # However, in the case of a multiple workers iterator @@ -704,7 +706,7 @@ def __init__(self, loader: DataLoader) -> None: self._num_yielded = 0 self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" - def __iter__(self) -> "_BaseDataLoaderIter": + def __iter__(self) -> Self: return self def _reset(self, loader, first_iter=False): diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index a2a6cf1b1afc..567ccdf1ee7a 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -10,16 +10,55 @@ inductor_fallback_ops = { "aten._adaptive_avg_pool2d_backward.default", "aten._adaptive_avg_pool2d.default", - "aten._adaptive_avg_pool3d.default", "aten._adaptive_avg_pool3d_backward.default", + "aten._adaptive_avg_pool3d.default", + "aten._addmm_activation.default", + "aten._cdist_backward.default", + "aten._cdist_forward.default", + "aten._cudnn_rnn.default", + "aten._dyn_quant_matmul_4bit.default", + "aten._dyn_quant_pack_4bit_weight.default", + "aten._efficient_attention_backward.default", + "aten._efficient_attention_forward.default", + "aten._efficientzerotensor.default", + "aten._embedding_bag_dense_backward.default", + "aten._embedding_bag_forward_only.default", + "aten._embedding_bag_per_sample_weights_backward.default", + "aten._embedding_bag.default", + "aten._fft_c2c.default", + "aten._fft_r2c.default", + "aten._flash_attention_backward.default", + "aten._flash_attention_forward.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._histogramdd_from_bin_cts.default", + "aten._int_mm.out", + "aten._pdist_backward.default", + "aten._pdist_forward.default", + "aten._scaled_dot_product_cudnn_attention_backward.default", + "aten._scaled_dot_product_cudnn_attention.default", + "aten._scaled_dot_product_efficient_attention_backward.default", + "aten._scaled_dot_product_efficient_attention.default", + "aten._scaled_dot_product_flash_attention_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu.default", + "aten._scaled_dot_product_flash_attention.default", + "aten._scaled_dot_product_fused_attention_overrideable_backward.default", + "aten._scaled_dot_product_fused_attention_overrideable.default", + "aten._scaled_mm.default", + "aten._scaled_mm.out", + "aten._segment_reduce_backward.default", + "aten._thnn_fused_lstm_cell.default", + "aten._to_sparse.default", + "aten._trilinear.default", + "aten._weight_int8pack_mm.default", "aten.adaptive_max_pool2d_backward.default", "aten.adaptive_max_pool2d.default", - "aten.adaptive_max_pool3d.default", "aten.adaptive_max_pool3d_backward.default", + "aten.adaptive_max_pool3d.default", "aten.add.Scalar", "aten.add.Tensor", "aten.addbmm.default", - "aten._addmm_activation.default", "aten.addmm.out", "aten.addmv.default", "aten.angle.default", @@ -33,57 +72,37 @@ "aten.bmm.out", "aten.bucketize.Tensor", "aten.cat.default", - "aten._cdist_backward.default", - "aten._cdist_forward.default", "aten.cholesky_inverse.default", "aten.cholesky_solve.default", "aten.convolution_backward.default", - "aten._cudnn_rnn.default", "aten.convolution.default", "aten.cummax.default", "aten.cummin.default", "aten.cumprod.default", "aten.cumsum.default", - "aten._dyn_quant_matmul_4bit.default", - "aten._dyn_quant_pack_4bit_weight.default", - "aten._efficient_attention_backward.default", - "aten._efficient_attention_forward.default", - "aten._efficientzerotensor.default", - "aten._embedding_bag.default", - "aten._embedding_bag_dense_backward.default", - "aten._embedding_bag_forward_only.default", - "aten._embedding_bag_per_sample_weights_backward.default", "aten.exponential.default", - "aten._fft_c2c.default", - "aten._fft_r2c.default", - "aten._flash_attention_backward.default", - "aten._flash_attention_forward.default", "aten.fractional_max_pool2d_backward.default", "aten.fractional_max_pool2d.default", - "aten.fractional_max_pool3d.default", "aten.fractional_max_pool3d_backward.default", - "aten._fused_moving_avg_obs_fq_helper.default", - "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.fractional_max_pool3d.default", "aten.gcd.default", "aten.geqrf.default", "aten.grid_sampler_2d_backward.default", "aten.histc.default", "aten.histogram.bin_ct", - "aten._histogramdd_from_bin_cts.default", "aten.index_put.default", "aten.index_reduce.default", "aten.index.Tensor", - "aten._int_mm.out", "aten.kthvalue.default", "aten.logcumsumexp.default", "aten.lu_unpack.default", - "aten.masked_select.default", - "aten.masked_scatter.default", "aten.masked_scatter_backward.default", + "aten.masked_scatter.default", + "aten.masked_select.default", "aten.max_pool2d_with_indices_backward.default", "aten.max_pool2d_with_indices.default", - "aten.max_pool3d_with_indices.default", "aten.max_pool3d_with_indices_backward.default", + "aten.max_pool3d_with_indices.default", "aten.max_unpool2d.default", "aten.max_unpool3d.default", "aten.median.default", @@ -93,11 +112,9 @@ "aten.mul.Tensor", "aten.nanmedian.default", "aten.native_dropout.default", - "aten.normal_functional.default", "aten.nonzero.default", + "aten.normal_functional.default", "aten.ormqr.default", - "aten._pdist_backward.default", - "aten._pdist_forward.default", "aten.polar.default", "aten.pow.Scalar", "aten.pow.Tensor_Scalar", @@ -106,8 +123,8 @@ "aten.rand.generator", "aten.randint.default", "aten.randint.generator", - "aten.randint.low", "aten.randint.low_out", + "aten.randint.low", "aten.randn.default", "aten.randn.generator", "aten.randperm.default", @@ -117,36 +134,20 @@ "aten.reshape.default", "aten.resize_.default", "aten.resize_as_.default", - "aten._scaled_dot_product_efficient_attention_backward.default", - "aten._scaled_dot_product_efficient_attention.default", - "aten._scaled_dot_product_flash_attention_backward.default", - "aten._scaled_dot_product_flash_attention.default", - "aten._scaled_dot_product_cudnn_attention_backward.default", - "aten._scaled_dot_product_cudnn_attention.default", - "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", - "aten._scaled_dot_product_flash_attention_for_cpu.default", - "aten._scaled_dot_product_fused_attention_overrideable_backward.default", - "aten._scaled_dot_product_fused_attention_overrideable.default", - "aten._scaled_mm.default", - "aten._scaled_mm.out", "aten.scatter_reduce.two_out", "aten.scatter.src_out", "aten.scatter.value_out", "aten.searchsorted.Scalar", "aten.searchsorted.Tensor", - "aten._segment_reduce_backward.default", "aten.segment_reduce.default", "aten.set_.source_Tensor", "aten.slice.Tensor", "aten.soft_margin_loss_backward.default", "aten.sort.default", "aten.sort.stable", - "aten._thnn_fused_lstm_cell.default", - "aten.topk.default", - "aten._to_sparse.default", "aten.to_sparse.default", + "aten.topk.default", "aten.triangular_solve.default", - "aten._trilinear.default", "aten.uniform.default", "aten.upsample_bicubic2d_backward.default", "aten.upsample_linear1d_backward.default", @@ -154,5 +155,4 @@ "aten.view_as_complex.default", "aten.view_as_real.default", "aten.view.dtype", - "aten._weight_int8pack_mm.default", }