diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt new file mode 100644 index 000000000000..d13e9d756c95 --- /dev/null +++ b/.ci/docker/aotriton_version.txt @@ -0,0 +1,5 @@ +0.6b +manylinux_2_17 +rocm6 +04b5df8c8123f90cba3ede7e971e6fbc6040d506 +3db6ecbc915893ff967abd6e1b43bd5f54949868873be60dc802086c3863e648 diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index fa4dbf2b0165..537b0b9d2ba7 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,9 +91,9 @@ _UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b # configuration, so we hardcode everything here rather than do it # from scratch case "$image" in - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -105,9 +105,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -119,9 +119,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -134,9 +134,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -149,9 +149,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -164,9 +164,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks) + pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=9 PROTOBUF=yes @@ -179,9 +179,9 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9) CUDA_VERSION=11.8.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -193,9 +193,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -207,9 +207,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9) CUDA_VERSION=12.1.1 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -221,9 +221,9 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9) + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9) CUDA_VERSION=12.4.0 - CUDNN_VERSION=8 + CUDNN_VERSION=9 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes @@ -330,10 
+330,10 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) ANACONDA_PYTHON_VERSION=3.8 CUDA_VERSION=11.8 - CUDNN_VERSION=8 + CUDNN_VERSION=9 CLANG_VERSION=12 PROTOBUF=yes DB=yes @@ -380,7 +380,7 @@ case "$image" in ANACONDA_PYTHON_VERSION=3.9 CONDA_CMAKE=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter) + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter) ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CONDA_CMAKE=yes @@ -447,7 +447,7 @@ tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') #when using cudnn version 8 install it separately from cuda if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}" - if [[ ${CUDNN_VERSION} == 8 ]]; then + if [[ ${CUDNN_VERSION} == 9 ]]; then IMAGE_NAME="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}" fi fi @@ -499,7 +499,7 @@ docker build \ "$@" \ . -# NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, +# NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to replace the # "$UBUNTU_VERSION" == "18.04-rc" diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index 6cb82a1f770c..bfac9ddd8590 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -113,6 +113,13 @@ COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +# Install AOTriton (Early fail) +COPY ./aotriton_version.txt aotriton_version.txt +COPY ./common/common_utils.sh common_utils.sh +COPY ./common/install_aotriton.sh install_aotriton.sh +RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"] +ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton + # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt index 2df035af1fdd..15f681977a12 100644 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ b/.ci/docker/ci_commit_pins/triton-rocm.txt @@ -1 +1 @@ -bbe6246e37d8aa791c67daaf9d9d61b26c9ccfdc +01cbe5045a6898c9a925f01435c8277b2fe6afcc diff --git a/.ci/docker/common/install_aotriton.sh b/.ci/docker/common/install_aotriton.sh new file mode 100755 index 000000000000..da3fe468d3e8 --- /dev/null +++ b/.ci/docker/common/install_aotriton.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +set -ex + +source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" + +TARBALL='aotriton.tar.bz2' +# This read command always returns with exit code 1 +read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true +ARCH=$(uname -m) +AOTRITON_INSTALL_PREFIX="$1" +AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}.tar.bz2" + +cd "${AOTRITON_INSTALL_PREFIX}" +# Must use -L to follow redirects +curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}" +ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1) +if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then + echo -n "Error: The 
SHA256 of downloaded tarball is ${ACTUAL_SHA256}," + echo " which does not match the expected value ${SHA256}." + exit 1 +fi +tar xf "${TARBALL}" && rm -rf "${TARBALL}" diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index ebaa17878ade..fd58ad8a60b8 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -3,7 +3,7 @@ set -ex install_ubuntu() { - # NVIDIA dockers for RC releases use tag names like `11.0-cudnn8-devel-ubuntu18.04-rc`, + # NVIDIA dockers for RC releases use tag names like `11.0-cudnn9-devel-ubuntu18.04-rc`, # for this case we will set UBUNTU_VERSION to `18.04-rc` so that the Dockerfile could # find the correct image. As a result, here we have to check for # "$UBUNTU_VERSION" == "18.04"* diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 3afd2f28841f..60f4561d420c 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -1,23 +1,18 @@ #!/bin/bash -if [[ ${CUDNN_VERSION} == 8 ]]; then +if [[ -n "${CUDNN_VERSION}" ]]; then # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement mkdir tmp_cudnn pushd tmp_cudnn - if [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.7.29_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "12.1" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.9.2.26_cuda12-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz - elif [[ ${CUDA_VERSION:0:4} == "11.8" ]]; then - CUDNN_NAME="cudnn-linux-x86_64-8.7.0.84_cuda11-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.7.0/local_installers/11.8/${CUDNN_NAME}.tar.xz + if [[ ${CUDA_VERSION:0:2} == "12" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda12-archive" + elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else print "Unsupported CUDA version ${CUDA_VERSION}" exit 1 fi - + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz tar xf ${CUDNN_NAME}.tar.xz cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/ cp -a ${CUDNN_NAME}/lib/* /usr/local/cuda/lib64/ diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index a1a5fde7d2f5..a91c798fcdf2 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -30,10 +30,10 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.17.0 -pip_install onnx==1.15.0 +pip_install onnxruntime==1.18 +pip_install onnx==1.16.0 # pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240315 --no-deps +pip_install onnxscript==0.1.0.dev20240523 --no-deps # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. 
By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index f96ee5e3b107..3b2bbea0097a 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -139,7 +139,7 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm ARG CUDNN_VERSION ARG CUDA_VERSION COPY ./common/install_cudnn.sh install_cudnn.sh -RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi +RUN if [ -n "${CUDNN_VERSION}" ]; then bash install_cudnn.sh; fi RUN rm install_cudnn.sh # Install CUSPARSELT @@ -152,7 +152,7 @@ RUN rm install_cusparselt.sh RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi RUN if [ -h /usr/local/cuda-12.1/cuda-12.1 ]; then rm /usr/local/cuda-12.1/cuda-12.1; fi -RUN if [ -h /usr/local/cuda-12.1/cuda-12.4 ]; then rm /usr/local/cuda-12.1/cuda-12.4; fi +RUN if [ -h /usr/local/cuda-12.4/cuda-12.4 ]; then rm /usr/local/cuda-12.4/cuda-12.4; fi USER jenkins CMD ["bash"] diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index cc43d9ec2414..ee9ede8ba611 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -105,6 +105,13 @@ COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +# Install AOTriton +COPY ./aotriton_version.txt aotriton_version.txt +COPY ./common/common_utils.sh common_utils.sh +COPY ./common/install_aotriton.sh install_aotriton.sh +RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"] +ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton + # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 190f99204e9c..d8eb45ee1d95 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -264,6 +264,18 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +# temp workarounds for https://github.com/pytorch/pytorch/issues/126692, remove when fixed +if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then + pushd test + CUDA_VERSION=$(python -c "import torch; print(torch.version.cuda)") + if [ "$CUDA_VERSION" == "12.4" ]; then + ISCUDA124="cu124" + else + ISCUDA124="" + fi + popd +fi + test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty @@ -356,7 +368,7 @@ test_inductor_cpp_wrapper_abi_compatible() { echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1" # cpu stack allocation causes segfault and needs more investigation - python test/run_test.py --include inductor/test_cpu_cpp_wrapper + PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper python test/run_test.py --include inductor/test_cuda_cpp_wrapper TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ @@ -364,7 +376,7 @@ test_inductor_cpp_wrapper_abi_compatible() { --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \ - --expected 
"benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_timm_training.csv" } # "Global" flags for inductor benchmarking controlled by TEST_CONFIG @@ -526,10 +538,10 @@ test_single_dynamo_benchmark() { --output "$TEST_REPORTS_DIR/${name}_${suite}.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/${TEST_CONFIG}_${name}.csv" python benchmarks/dynamo/check_graph_breaks.py \ --actual "$TEST_REPORTS_DIR/${name}_$suite.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/${TEST_CONFIG}_${name}.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/${TEST_CONFIG}_${name}.csv" fi } @@ -553,7 +565,11 @@ test_dynamo_benchmark() { test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@" else if [[ "${TEST_CONFIG}" == *cpu_inductor* ]]; then - test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@" + if [[ "${TEST_CONFIG}" == *freezing* ]]; then + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 --freezing "$@" + else + test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@" + fi elif [[ "${TEST_CONFIG}" == *aot_inductor* ]]; then test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@" else @@ -572,9 +588,11 @@ test_inductor_torchbench_smoketest_perf() { --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_torchbench_inference.csv" python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \ --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \ @@ -589,7 +607,13 @@ test_inductor_torchbench_smoketest_perf() { # https://github.com/pytorch/pytorch/actions/runs/7158691360/job/19491437314, # and thus we lower its threshold to reduce flakiness. If this continues to be a problem, # we switch to use some other model. 
- python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_inference_smoketest.csv" -t 4.9 + # Use 4.7 for cuda 12.4, change back to 4.9 after fixing https://github.com/pytorch/pytorch/issues/126692 + if [ "$CUDA_VERSION" == "12.4" ]; then + THRESHOLD=4.7 + else + THRESHOLD=4.9 + fi + python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_inference_smoketest.csv" -t $THRESHOLD # Check memory compression ratio for a few models for test in hf_Albert timm_vision_transformer; do @@ -608,7 +632,7 @@ test_inductor_torchbench_smoketest_perf() { --only $test --output "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_warm_start_smoketest_$test.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv" + --expected "benchmarks/dynamo/ci_expected_accuracy/${ISCUDA124}/inductor_huggingface_training.csv" done } diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 287423641d77..a45a2c9754ba 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -76,8 +76,8 @@ TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) # Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then - # Only linux Python < 3.12 are supported wheels for triton - TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.12'" + # Only linux Python < 3.13 are supported wheels for triton + TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'" TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) diff --git a/.clang-tidy b/.clang-tidy index fef154d4b0c1..1f7521ce7600 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -62,4 +62,6 @@ readability-string-compare, ' HeaderFilterRegex: '^(aten/|c10/|torch/).*$' WarningsAsErrors: '*' +CheckOptions: + misc-header-include-cycle.IgnoredFilesList: 'format.h;ivalue.h;custom_class.h;Dict.h;List.h' ... 
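The TRITON_CONSTRAINT change in .circleci/scripts/binary_populate_env.sh above works because TRITON_REQUIREMENT is a PEP 508 requirement string whose environment marker pip evaluates on the target machine; raising the bound from python_version < '3.12' to < '3.13' is what allows Python 3.12 wheel builds to pull in the pinned Triton. Below is a minimal illustrative sketch (not part of the patch) of how such a marker-gated requirement behaves, using the packaging library that pip relies on and a placeholder "3.0.0" version string instead of the value read from .ci/docker/triton_version.txt:

from packaging.requirements import Requirement

# Placeholder pin; the real value comes from .ci/docker/triton_version.txt.
TRITON_VERSION = "3.0.0"
TRITON_CONSTRAINT = (
    "platform_system == 'Linux' and platform_machine == 'x86_64' "
    "and python_version < '3.13'"
)
TRITON_REQUIREMENT = f"triton=={TRITON_VERSION}; {TRITON_CONSTRAINT}"

req = Requirement(TRITON_REQUIREMENT)
print(req.name, str(req.specifier))  # -> triton ==3.0.0
# pip installs the pinned wheel only where the marker evaluates to True
# (e.g. CPython 3.12 on x86_64 Linux); on other platforms it is skipped.
print(req.marker.evaluate())
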
diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 569facc32cdf..f41a70ada6af 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -1,9 +1,12 @@ self-hosted-runner: labels: + # GitHub hosted x86 Linux runners - linux.20_04.4x - linux.20_04.16x - - linux.large + # Repo-specific LF hosted ARC runners - linux.large.arc + # Organization-wide AWS Linux Runners + - linux.large - linux.2xlarge - linux.4xlarge - linux.12xlarge @@ -13,18 +16,34 @@ self-hosted-runner: - linux.8xlarge.nvidia.gpu - linux.16xlarge.nvidia.gpu - linux.g5.4xlarge.nvidia.gpu + # Organization-wide AWS Linux Runners on Linux Foundation account + - lf.linux.large + - lf.linux.2xlarge + - lf.linux.4xlarge + - lf.linux.12xlarge + - lf.linux.24xlarge + - lf.linux.arm64.2xlarge + - lf.linux.4xlarge.nvidia.gpu + - lf.linux.8xlarge.nvidia.gpu + - lf.linux.16xlarge.nvidia.gpu + - lf.linux.g5.4xlarge.nvidia.gpu + # Repo-specific IBM hosted S390x runner - linux.s390x + # Organization wide AWS Windows runners - windows.4xlarge.nonephemeral - windows.8xlarge.nvidia.gpu - windows.8xlarge.nvidia.gpu.nonephemeral - windows.g5.4xlarge.nvidia.gpu - - bm-runner + # Organization-wide AMD hosted MI300 runners - linux.rocm.gpu + # Repo-specific Apple hosted runners + - macos-m1-ultra + - macos-m2-14 + # Org-wide AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) - macos-m1-stable - macos-m1-13 - macos-m1-14 - - macos-12-xl - - macos-12 - - macos12.3-m1 + # GitHub-hosted macOS runners - macos-latest-xlarge - macos-13-xlarge + - macos-14-xlarge diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 98cd949f9713..a8141b25ecdd 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0 +b829e936f7cc61b48149f5f957a451a38bf2a178 diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml new file mode 100644 index 000000000000..eb7288ea56bf --- /dev/null +++ b/.github/lf-canary-scale-config.yml @@ -0,0 +1,154 @@ +# Defines runner types that will be provisioned by LF Self-hosted +# runners for pytorch/pytorch-canary and their labels. +# +# Runners listed here will be available as self hosted runners. +# Configuration is directly pulled from the main branch.
+# +# Default values: +# +# runner_types: +# runner_label: # label to specify in the Github Actions workflow +# instance_type: m4.large +# os: linux +# max_available: 20 +# disk_size: 50 +# is_ephemeral: true + +runner_types: + lf.c.linux.12xlarge: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.c.linux.24xl.spr-metal: + disk_size: 200 + instance_type: c7i.metal-24xl + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.16xlarge.spr: + disk_size: 200 + instance_type: c7i.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.12xlarge.ephemeral: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: true + max_available: 300 + os: linux + lf.c.linux.16xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.24xlarge: + disk_size: 150 + instance_type: c5.24xlarge + is_ephemeral: false + max_available: 250 + os: linux + lf.c.linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: false + max_available: 3120 + os: linux + lf.c.linux.4xlarge: + disk_size: 150 + instance_type: c5.4xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.c.linux.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.4xlarge + is_ephemeral: false + max_available: 520 + os: linux + lf.c.linux.8xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.8xlarge + is_ephemeral: false + max_available: 400 + os: linux + lf.c.linux.g4dn.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.12xlarge + is_ephemeral: false + max_available: 50 + os: linux + lf.c.linux.g4dn.metal.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.metal + is_ephemeral: false + max_available: 30 + os: linux + lf.c.linux.g5.48xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.48xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.c.linux.g5.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.12xlarge + is_ephemeral: false + max_available: 150 + os: linux + lf.c.linux.g5.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 1200 + os: linux + lf.c.linux.large: + disk_size: 15 + instance_type: c5.large + is_ephemeral: false + os: linux + lf.c.linux.arm64.2xlarge: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: false + max_available: 200 + os: linux + lf.c.linux.arm64.m7g.2xlarge: + disk_size: 256 + instance_type: m7g.2xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.c.windows.4xlarge: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: true + max_available: 420 + os: windows + lf.c.windows.4xlarge.nonephemeral: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: false + max_available: 420 + os: windows + lf.c.windows.8xlarge.nvidia.gpu: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: true + max_available: 150 + os: windows + lf.c.windows.8xlarge.nvidia.gpu.nonephemeral: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: false + max_available: 150 + os: windows + lf.c.windows.g5.4xlarge.nvidia.gpu: + disk_size: 256 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 250 + os: windows diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml new file mode 100644 index 000000000000..7977d7c15c2f --- /dev/null +++ b/.github/lf-scale-config.yml @@ -0,0 +1,154 @@ +# Defines runner types that will be provisioned by by LF Self-hosted +# runners for pytorch/pytorch and their 
labels. +# +# Runners listed here will be available as self hosted runners. +# Configuration is directly pulled from the main branch. +# +# Default values: +# +# runner_types: +# runner_label: # label to specify in the Github Actions workflow +# instance_type: m4.large +# os: linux +# max_available: 20 +# disk_size: 50 +# is_ephemeral: true + +runner_types: + lf.linux.12xlarge: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.linux.24xl.spr-metal: + disk_size: 200 + instance_type: c7i.metal-24xl + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.16xlarge.spr: + disk_size: 200 + instance_type: c7i.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.12xlarge.ephemeral: + disk_size: 200 + instance_type: c5.12xlarge + is_ephemeral: true + max_available: 300 + os: linux + lf.linux.16xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.16xlarge + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.24xlarge: + disk_size: 150 + instance_type: c5.24xlarge + is_ephemeral: false + max_available: 250 + os: linux + lf.linux.2xlarge: + disk_size: 150 + instance_type: c5.2xlarge + is_ephemeral: false + max_available: 3120 + os: linux + lf.linux.4xlarge: + disk_size: 150 + instance_type: c5.4xlarge + is_ephemeral: false + max_available: 1000 + os: linux + lf.linux.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.4xlarge + is_ephemeral: false + max_available: 520 + os: linux + lf.linux.8xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g3.8xlarge + is_ephemeral: false + max_available: 400 + os: linux + lf.linux.g4dn.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.12xlarge + is_ephemeral: false + max_available: 50 + os: linux + lf.linux.g4dn.metal.nvidia.gpu: + disk_size: 150 + instance_type: g4dn.metal + is_ephemeral: false + max_available: 30 + os: linux + lf.linux.g5.48xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.48xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.linux.g5.12xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.12xlarge + is_ephemeral: false + max_available: 150 + os: linux + lf.linux.g5.4xlarge.nvidia.gpu: + disk_size: 150 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 1200 + os: linux + lf.linux.large: + disk_size: 15 + instance_type: c5.large + is_ephemeral: false + os: linux + lf.linux.arm64.2xlarge: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: false + max_available: 200 + os: linux + lf.linux.arm64.m7g.2xlarge: + disk_size: 256 + instance_type: m7g.2xlarge + is_ephemeral: false + max_available: 20 + os: linux + lf.windows.4xlarge: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: true + max_available: 420 + os: windows + lf.windows.4xlarge.nonephemeral: + disk_size: 256 + instance_type: c5d.4xlarge + is_ephemeral: false + max_available: 420 + os: windows + lf.windows.8xlarge.nvidia.gpu: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: true + max_available: 150 + os: windows + lf.windows.8xlarge.nvidia.gpu.nonephemeral: + disk_size: 256 + instance_type: p3.2xlarge + is_ephemeral: false + max_available: 150 + os: windows + lf.windows.g5.4xlarge.nvidia.gpu: + disk_size: 256 + instance_type: g5.4xlarge + is_ephemeral: false + max_available: 250 + os: windows diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index db0ec3c51aa7..d69fff16f305 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -245,6 +245,7 @@ - torch/xpu/** - test/xpu/** 
- third_party/xpu.txt + - .ci/docker/ci_commit_pins/triton-xpu.txt approved_by: - EikanWang - jgong5 diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index d54346f81650..0d624788fc61 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -8,6 +8,7 @@ ciflow_push_tags: - ciflow/inductor - ciflow/inductor-perf-compare - ciflow/inductor-micro-benchmark +- ciflow/inductor-cu124 - ciflow/linux-aarch64 - ciflow/mps - ciflow/nightly diff --git a/.github/requirements/conda-env-Linux-X64.txt b/.github/requirements/conda-env-Linux-X64.txt index 78534c21e911..e0b7177e39c4 100644 --- a/.github/requirements/conda-env-Linux-X64.txt +++ b/.github/requirements/conda-env-Linux-X64.txt @@ -4,6 +4,5 @@ mkl-include=2022.1.0 ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 -requests=2.31.0 setuptools=68.2.2 typing-extensions=4.9.0 diff --git a/.github/requirements/conda-env-iOS.txt b/.github/requirements/conda-env-iOS.txt index a88a16dba4df..fe67c6cbc312 100644 --- a/.github/requirements/conda-env-iOS.txt +++ b/.github/requirements/conda-env-iOS.txt @@ -3,6 +3,5 @@ cmake=3.22.1 ninja=1.10.2 numpy=1.23.3 pyyaml=6.0 -requests=2.31.0 setuptools=68.2.2 typing-extensions=4.9.0 diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index 21b86fefa1a8..c2676ae09ea7 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -2,6 +2,7 @@ import os import re from datetime import datetime +from functools import lru_cache from pathlib import Path from typing import Any, Callable, Dict, List, Set @@ -187,6 +188,17 @@ def get_recent_prs() -> Dict[str, Any]: return prs_by_branch_base +@lru_cache(maxsize=1) +def get_open_prs() -> List[Dict[str, Any]]: + return paginate_graphql( + GRAPHQL_OPEN_PRS, + {"owner": "pytorch", "repo": "pytorch"}, + lambda data: False, + lambda res: res["data"]["repository"]["pullRequests"]["nodes"], + lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"], + ) + + def get_branches_with_magic_label_or_open_pr() -> Set[str]: pr_infos: List[Dict[str, Any]] = paginate_graphql( GRAPHQL_NO_DELETE_BRANCH_LABEL, @@ -196,15 +208,7 @@ def get_branches_with_magic_label_or_open_pr() -> Set[str]: lambda res: res["data"]["repository"]["label"]["pullRequests"]["pageInfo"], ) - pr_infos.extend( - paginate_graphql( - GRAPHQL_OPEN_PRS, - {"owner": "pytorch", "repo": "pytorch"}, - lambda data: False, - lambda res: res["data"]["repository"]["pullRequests"]["nodes"], - lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"], - ) - ) + pr_infos.extend(get_open_prs()) # Get the most recent PR for each branch base (group gh together) branch_bases = set() @@ -270,5 +274,41 @@ def delete_branches() -> None: delete_branch(git_repo, branch) +def delete_old_ciflow_tags() -> None: + # Deletes ciflow tags if they are associated with a closed PR or a specific + # commit. Lightweight tags don't have information about the date they were + # created, so we can't check how old they are. The script just assumes that + # ciflow tags should be deleted regardless of creation date. 
+ git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True) + + def delete_tag(tag: str) -> None: + print(f"Deleting tag {tag}") + ESTIMATED_TOKENS[0] += 1 + delete_branch(git_repo, f"refs/tags/{tag}") + + tags = git_repo._run_git("tag").splitlines() + open_pr_numbers = [x["number"] for x in get_open_prs()] + + for tag in tags: + try: + if ESTIMATED_TOKENS[0] > 400: + print("Estimated tokens exceeded, exiting") + break + if not tag.startswith("ciflow/"): + continue + re_match_pr = re.match(r"^ciflow\/.*\/(\d{5,6})$", tag) + re_match_sha = re.match(r"^ciflow\/.*\/([0-9a-f]{40})$", tag) + if re_match_pr: + pr_number = int(re_match_pr.group(1)) + if pr_number in open_pr_numbers: + continue + delete_tag(tag) + elif re_match_sha: + delete_tag(tag) + except Exception as e: + print(f"Failed to check tag {tag}: {e}") + + if __name__ == "__main__": delete_branches() + delete_old_ciflow_tags() diff --git a/.github/scripts/docathon-label-sync.py b/.github/scripts/docathon-label-sync.py new file mode 100644 index 000000000000..a10c3c3f886c --- /dev/null +++ b/.github/scripts/docathon-label-sync.py @@ -0,0 +1,52 @@ +import os +import re +import sys + +from github import Github + + +def main() -> None: + token = os.environ.get("GITHUB_TOKEN") + + repo_owner = "pytorch" + repo_name = "pytorch" + pull_request_number = int(sys.argv[1]) + + g = Github(token) + repo = g.get_repo(f"{repo_owner}/{repo_name}") + pull_request = repo.get_pull(pull_request_number) + pull_request_body = pull_request.body + # PR without description + if pull_request_body is None: + return + + # get issue number from the PR body + if not re.search(r"#\d{1,6}", pull_request_body): + print("The pull request does not mention an issue.") + return + issue_number = int(re.findall(r"#(\d{1,6})", pull_request_body)[0]) + issue = repo.get_issue(issue_number) + issue_labels = issue.labels + docathon_label_present = any( + label.name == "docathon-h1-2024" for label in issue_labels + ) + + # if the issue has a docathon label, add all labels from the issue to the PR. 
+ if not docathon_label_present: + print("The 'docathon-h1-2024' label is not present in the issue.") + return + pull_request_labels = pull_request.get_labels() + pull_request_label_names = [label.name for label in pull_request_labels] + issue_label_names = [label.name for label in issue_labels] + labels_to_add = [ + label for label in issue_label_names if label not in pull_request_label_names + ] + if not labels_to_add: + print("The pull request already has the same labels.") + return + pull_request.add_to_labels(*labels_to_add) + print("Labels added to the pull request!") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index b192475f72b1..920ca65fbf52 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -19,7 +19,7 @@ CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1", "12.4": "12.4.0"} -CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8", "12.4": "8"} +CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.1": "9", "12.4": "9"} ROCM_ARCHES = ["6.0", "6.1"] @@ -42,7 +42,7 @@ "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -55,7 +55,7 @@ "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950 "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | " @@ -68,7 +68,7 @@ "nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " - "nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | 
" @@ -347,6 +347,10 @@ def generate_wheels_matrix( for python_version in python_versions: for arch_version in arches: gpu_arch_type = arch_type(arch_version) + # Disable py3.12 builds for ROCm because of triton dependency + # on llnl-hatchet, which doesn't have py3.12 wheels available + if gpu_arch_type == "rocm" and python_version == "3.12": + continue gpu_arch_version = ( "" if arch_version == "cpu" diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 54884e3a1261..fcac02bb8fe8 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -60,7 +60,7 @@ class BinaryBuildWorkflow: branches: str = "nightly" # Mainly for macos cross_compile_arm64: bool = False - macos_runner: str = "macos-12-xl" + macos_runner: str = "macos-14-xlarge" def __post_init__(self) -> None: if self.abi_version: @@ -285,7 +285,7 @@ class OperatingSystem: libtorch_variants=["shared-with-deps"], ), cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH}, isolated_workflow=True, @@ -298,7 +298,7 @@ class OperatingSystem: OperatingSystem.MACOS_ARM64 ), cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, isolated_workflow=True, @@ -308,7 +308,7 @@ class OperatingSystem: os=OperatingSystem.MACOS_ARM64, package_type="conda", cross_compile_arm64=False, - macos_runner="macos-13-xlarge", + macos_runner="macos-14-xlarge", build_configs=generate_binary_build_matrix.generate_conda_matrix( OperatingSystem.MACOS_ARM64 ), diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index 50a04ef487a0..ae3c203cf70f 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -29,6 +29,7 @@ python3 -m tools.pyi.gen_pyi \ --native-functions-path aten/src/ATen/native/native_functions.yaml \ --tags-path aten/src/ATen/native/tags.yaml \ --deprecated-functions-path "tools/autograd/deprecated.yaml" +python3 torch/utils/data/datapipes/gen_pyi.py RC=0 # Run lintrunner on all files diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index 2641fd30f348..ec3e69b706f8 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -773,13 +773,13 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: # than the one on the base commit. 
This should still count as broken trunk "pr_num": 104214, "related_failure_count": 0, - "unrelated_failure_count": 1, + "flaky_or_broken_trunk": 1, }, { # This PR had one broken trunk failure and it used ghstack "pr_num": 105145, "related_failure_count": 0, - "unrelated_failure_count": 1, + "flaky_or_broken_trunk": 1, }, { # The failure on the merge base was retried successfully and @@ -788,20 +788,20 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: # be used to detect broken trunk "pr_num": 107160, "related_failure_count": 0, - "unrelated_failure_count": 4, + "flaky_or_broken_trunk": 1, }, { # This PR used Dr.CI broken trunk classification "pr_num": 111253, "related_failure_count": 1, - "unrelated_failure_count": 2, + "flaky_or_broken_trunk": 1, }, ] for case in test_cases: pr_num = case["pr_num"] related_failure_count = case["related_failure_count"] - unrelated_failure_count = case["unrelated_failure_count"] + flaky_or_broken_trunk = case["flaky_or_broken_trunk"] pr = GitHubPR("pytorch", "pytorch", pr_num) checks = pr.get_checkrun_conclusions() @@ -823,7 +823,7 @@ def test_get_classifications_broken_trunk(self, *args: Any) -> None: ) self.assertTrue(len(pending) == 0) self.assertTrue( - len(failed) == unrelated_failure_count + related_failure_count + len(failed) == flaky_or_broken_trunk + related_failure_count ) def test_ignore_current(self, *args: Any) -> None: diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 95311d2d9b83..6a6d080a9b3a 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -2027,10 +2027,8 @@ def categorize_checks( pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] - # ok_failed_checks is used with ok_failed_checks_threshold while ignorable_failed_checks - # is used to keep track of all ignorable failures when saving the merge record on Rockset - ok_failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = [] - ignorable_failed_checks: Dict[str, List[Any]] = defaultdict(list) + # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on Rockset + failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list) # If required_checks is not set or empty, consider all names are relevant relevant_checknames = [ @@ -2058,36 +2056,38 @@ def categorize_checks( continue elif not is_passing_status(check_runs[checkname].status): target = ( - ignorable_failed_checks[classification] + failed_checks_categorization[classification] if classification in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE") else failed_checks ) target.append((checkname, url, job_id)) - if classification in ("BROKEN_TRUNK", "FLAKY", "UNSTABLE"): - ok_failed_checks.append((checkname, url, job_id)) + flaky_or_broken_trunk = ( + failed_checks_categorization["BROKEN_TRUNK"] + + failed_checks_categorization["FLAKY"] + ) - if ok_failed_checks: + if flaky_or_broken_trunk: warn( - f"The following {len(ok_failed_checks)} checks failed but were likely due flakiness or broken trunk: " - + ", ".join([x[0] for x in ok_failed_checks]) + f"The following {len(flaky_or_broken_trunk)} checks failed but were likely due flakiness or broken trunk: " + + ", ".join([x[0] for x in flaky_or_broken_trunk]) + ( f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail" if ok_failed_checks_threshold is not None - and len(ok_failed_checks) > ok_failed_checks_threshold + and 
len(flaky_or_broken_trunk) > ok_failed_checks_threshold else "" ) ) if ( ok_failed_checks_threshold is not None - and len(ok_failed_checks) > ok_failed_checks_threshold + and len(flaky_or_broken_trunk) > ok_failed_checks_threshold ): - failed_checks = failed_checks + ok_failed_checks + failed_checks = failed_checks + flaky_or_broken_trunk - # The list of ignorable_failed_checks is returned so that it can be saved into the Rockset merge record - return (pending_checks, failed_checks, ignorable_failed_checks) + # The list of failed_checks_categorization is returned so that it can be saved into the Rockset merge record + return (pending_checks, failed_checks, failed_checks_categorization) def merge( diff --git a/.github/workflows/build-ios-binaries.yml b/.github/workflows/build-ios-binaries.yml index 3f3be84f48bd..32598f07a5c0 100644 --- a/.github/workflows/build-ios-binaries.yml +++ b/.github/workflows/build-ios-binaries.yml @@ -49,7 +49,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "SIMULATOR", ios_arch: "arm64", use_lite_interpreter: ${{ inputs.use_lite_interpreter || 1 }}, @@ -60,7 +60,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "OS", ios_arch: "arm64", use_lite_interpreter: ${{ inputs.use_lite_interpreter || 1 }}, diff --git a/.github/workflows/delete_old_branches.yml b/.github/workflows/delete_old_branches.yml index 04a0521419a8..eabb98e32065 100644 --- a/.github/workflows/delete_old_branches.yml +++ b/.github/workflows/delete_old_branches.yml @@ -29,7 +29,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.11' architecture: x64 check-latest: false diff --git a/.github/workflows/docathon-sync-label.yml b/.github/workflows/docathon-sync-label.yml new file mode 100644 index 000000000000..7cb1f608722d --- /dev/null +++ b/.github/workflows/docathon-sync-label.yml @@ -0,0 +1,30 @@ +name: Docathon Labels Sync + +on: + pull_request_target: + types: [opened, synchronize, edited] + branches: [main] + +jobs: + check-labels: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - name: Check out the repo + uses: actions/checkout@v2 + with: + fetch-depth: 1 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Install dependencies + run: | + pip install requests==2.32.3 + pip install PyGithub==2.3.0 + - name: Run Python script + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: python ./.github/scripts/docathon-label-sync.py ${{ github.event.pull_request.number }} diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 0eec1556bb96..f732dab42050 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -38,19 +38,19 @@ jobs: matrix: runner: [linux.12xlarge] docker-image-name: [ - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9, - pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks, - pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks, + 
pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9, + pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, pytorch-linux-focal-py3.8-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, pytorch-linux-jammy-py3.8-gcc11, pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, @@ -58,7 +58,7 @@ jobs: pytorch-linux-jammy-py3-clang15-asan, pytorch-linux-focal-py3-clang10-onnx, pytorch-linux-focal-linter, - pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter, pytorch-linux-jammy-py3-clang12-executorch ] include: diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 9f5221a88f9c..351497bee753 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -149,3 +149,10 @@ jobs: - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() + + validate: + needs: build + uses: pytorch/builder/.github/workflows/validate-docker-images.yml@main + with: + channel: nightly + ref: main diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 726dbf40f985..a1a7e6fd9537 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_8-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' 
| nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-aarch64-test: # Testing @@ -162,7 +162,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-aarch64-test: # Testing @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' 
and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-aarch64-test: # Testing @@ -378,7 +378,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 
'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-aarch64-test: # Testing @@ -486,7 +486,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-aarch64-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 6e7edae7b613..053877b1c90e 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -48,7 +48,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system 
== 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -88,7 +88,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -128,7 +128,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml 
b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 8ad43b4c3660..9d59728bbbbb 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -174,7 +174,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda11_8-test: # Testing @@ -237,7 +237,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 
'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_1-test: # Testing @@ -300,7 +300,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cuda12_4-test: # Testing @@ -690,7 +690,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda11_8-test: # Testing @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_1-test: # Testing @@ -816,7 +816,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 
'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing @@ -1206,7 +1206,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda11_8-test: # Testing @@ -1269,7 +1269,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_1-test: # Testing @@ -1332,7 +1332,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_4-test: # Testing @@ -1722,7 +1722,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda11_8-test: # Testing @@ -1785,7 +1785,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_1-test: # Testing @@ -1848,7 +1848,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system 
== 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_4-test: # Testing @@ -2238,7 +2238,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda11_8-test: # Testing @@ -2301,7 +2301,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_1-test: # Testing @@ -2364,7 +2364,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | 
nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_4-test: # Testing @@ -2410,209 +2410,3 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-rocm6_0-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_0 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm6_0-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_12-rocm6_0-build - runs-on: linux.rocm.gpu - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 - name: Download Build Artifacts - with: - name: manywheel-py3_12-rocm6_0 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: malfet/checkout@silent-checkout - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - quiet-checkout: true - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: malfet/checkout@silent-checkout - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - quiet-checkout: true - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: pytorch/manylinux-builder:rocm6.0-main - - 
name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm6_0-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-rocm6_0-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.0 - GPU_ARCH_VERSION: 6.0 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_0 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-rocm6_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_1 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm6_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_12-rocm6_1-build - runs-on: linux.rocm.gpu - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 - name: Download Build Artifacts - with: - name: manywheel-py3_12-rocm6_1 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: malfet/checkout@silent-checkout - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - quiet-checkout: true - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: malfet/checkout@silent-checkout - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - quiet-checkout: true - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: pytorch/manylinux-builder:rocm6.1-main - - name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm6_1-upload: # 
Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-rocm6_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.1 - GPU_ARCH_VERSION: 6.1 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 4f0569c253f2..db0748463da5 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -54,7 +54,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_8-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_8-cpu-s390x-test: # Testing @@ -117,7 +117,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" 
build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-s390x-test: # Testing @@ -180,7 +180,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-s390x-test: # Testing @@ -243,7 +243,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-s390x-test: # Testing @@ -306,7 +306,7 @@ 
jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-s390x-test: # Testing diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index a8cbdb7cd6fe..52ccb92a1935 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: conda-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -152,7 +152,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -270,7 +270,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -388,7 +388,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cpu-build: if: 
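The runs-on bumps in the macOS arm64 binary workflows here (conda, libtorch, and wheel nightlies) move every cpu build job from the macos-13-xlarge to the macos-14-xlarge GitHub-hosted runner label. Both labels refer to Apple Silicon (arm64) runners, so this is a macOS 13 to macOS 14 image upgrade rather than an architecture change. A small illustrative check, not taken from the PyTorch build scripts, of the host such a job expects:

# Illustrative only, not part of the PyTorch build scripts: a sanity check a
# macOS arm64 binary job could run on a macos-14-xlarge runner before building.
import platform
import sys

assert platform.system() == "Darwin", "expected a macOS runner"
assert platform.machine() == "arm64", "expected an Apple Silicon (arm64) runner"
print(f"macOS {platform.mac_ver()[0]} on {platform.machine()}, Python {sys.version.split()[0]}")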
${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -506,7 +506,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index 0ed7ba10a07d..7e2e345aefbc 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: libtorch-cpu-shared-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 167161de3645..b4910d46ed5e 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: wheel-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -153,7 +153,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -165,7 +165,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -272,7 +272,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -284,7 +284,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: 
nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -391,7 +391,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -403,7 +403,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; 
platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -510,7 +510,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-13-xlarge + runs-on: macos-14-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -522,7 +522,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; 
platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index d64c221e7895..d06f99bd9a5a 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -290,7 +290,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 
'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -536,7 +536,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 
'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -782,7 +782,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.8" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1027,7 +1027,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1271,7 +1271,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' 
| nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1517,7 +1517,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1763,7 +1763,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 
'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2008,7 +2008,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2252,7 +2252,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 
DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2498,7 +2498,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' 
and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2744,7 +2744,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2989,7 +2989,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; 
platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3233,7 +3233,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3479,7 +3479,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3725,7 +3725,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 
'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3970,7 +3970,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine 
== 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4214,7 +4214,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4460,7 +4460,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and 
platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4706,7 +4706,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - 
name: Display EC2 information shell: bash diff --git a/.github/workflows/inductor-cu124.yml b/.github/workflows/inductor-cu124.yml new file mode 100644 index 000000000000..d7ab5665bed6 --- /dev/null +++ b/.github/workflows/inductor-cu124.yml @@ -0,0 +1,108 @@ +name: inductor-cu124 + +on: + push: + tags: + - ciflow/inductor-cu124/* + workflow_dispatch: + schedule: + # Run every 4 hours during the week and every 12 hours on the weekend + - cron: 45 0,4,8,12,16,20 * * 1-5 + - cron: 45 4,12 * * 0,6 + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: + # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test: + name: cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build + with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ 
secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.0' + test-matrix: | + { include: [ + { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, + ]} + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_10-gcc9-inductor-test-gcp: + name: cuda12.4-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + secrets: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-build: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_12-gcc9-inductor-test: + name: cuda12.4-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build + with: + build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 4fe0ddf50ef2..431545ea6d0d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -21,7 +21,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index e485a8bfce1b..a5e4ad1781aa 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -18,7 +18,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index e77c915749f3..2f129c52fe13 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -71,7 +71,7 @@ jobs: uses: 
./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 6f8c06ed030b..2fe649cebb5e 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index cb5122e631bb..2030ff5aee3b 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -44,7 +44,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -86,7 +86,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -112,7 +112,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ @@ -129,28 +129,18 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }} linux-focal-cuda12_4-py3_10-gcc9-inductor-build: + # Should be synced with the one in inductor-periodic.yml but this only runs inductor_timm name: cuda12.4-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: 
"dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -160,47 +150,13 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: linux-focal-cuda12_4-py3_10-gcc9-inductor-build with: + sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-test build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-focal-cuda12_4-py3_10-gcc9-inductor-build-gcp: - name: cuda12.4-py3.10-gcc9-sm80 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks - cuda-arch-list: '8.0' - test-matrix: | - { include: [ - { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, - ]} - secrets: - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-build: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3.12-gcc9-inductor-benchmarks - cuda-arch-list: '8.6' - test-matrix: | - { include: [ - { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_12-gcc9-inductor-test: - name: cuda12.4-py3.12-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-cuda12_4-py3_12-gcc9-inductor-build - with: - build-environment: linux-focal-cuda12.4-py3.12-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_12-gcc9-inductor-build.outputs.test-matrix }} - linux-jammy-cpu-py3_8-gcc11-inductor-build: name: linux-jammy-cpu-py3.8-gcc11-inductor uses: ./.github/workflows/_linux-build.yml @@ -214,6 +170,11 @@ jobs: { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: 
"cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f1b6611d00e0..e0e4d3c20cd8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 @@ -36,7 +36,7 @@ jobs: with: timeout: 120 runner: linux.2xlarge - docker-image: pytorch-linux-jammy-cuda11.8-cudnn8-py3.9-linter + docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index da98d01550a4..06521f20c49e 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -13,29 +13,31 @@ concurrency: permissions: read-all jobs: - macos-13-py3-arm64-build: - name: macos-13-py3-arm64 + macos-py3-arm64-build: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-build.yml with: sync-tag: macos-py3-arm64-build - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 runner-type: macos-m1-stable build-generates-artifacts: true # To match the one pre-installed in the m1 runners python-version: 3.9.12 + # The runner macos-m2-14 is not a typo, it's a custom runner that is different + # than our AWS macos-m1-14 runners test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-stable" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-14" }, ]} macos-py3-arm64-mps-test: name: macos-py3-arm64-mps uses: ./.github/workflows/_mac-test-mps.yml - needs: macos-13-py3-arm64-build + needs: macos-py3-arm64-build with: sync-tag: macos-py3-arm64-mps-test - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 - test-matrix: ${{ needs.macos-13-py3-arm64-build.outputs.test-matrix }} + test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 716a72cc6d23..bae31f44d742 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -37,6 +37,59 @@ jobs: permissions: id-token: write contents: read + linux-focal-cuda12_1-py3_10-gcc9-build: + name: linux-focal-cuda12.1-py3.10-gcc9 + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 + test-matrix: | + { include: [ + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, 
runner: "linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + ]} + linux-focal-cuda12_1-py3_10-gcc9-test: + name: linux-focal-cuda12.1-py3.10-gcc9 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_1-py3_10-gcc9-build + - target-determination + with: + build-environment: linux-focal-cuda12.1-py3.10-gcc9 + docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} + + linux-focal-cuda12_4-py3_10-gcc9-build: + name: linux-focal-cuda12.4-py3.10-gcc9 + uses: ./.github/workflows/_linux-build-label.yml + with: + build-environment: linux-focal-cuda12.4-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda12_4-py3_10-gcc9-test: + name: linux-focal-cuda12.4-py3.10-gcc9 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda12_4-py3_10-gcc9-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-focal-cuda12.4-py3.10-gcc9 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} parallelnative-linux-jammy-py3_8-gcc11-build: name: parallelnative-linux-jammy-py3.8-gcc11 @@ -67,7 +120,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.9-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -89,7 +142,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 build-with-debug: true test-matrix: | { include: [ @@ -151,7 +204,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "SIMULATOR", ios_arch: "arm64", use_lite_interpreter: 1, @@ -162,7 +215,7 @@ jobs: { config: "default", shard: 1, num_shards: 1, - runner: "macos-13-xlarge", + runner: "macos-14-xlarge", ios_platform: "OS", ios_arch: "arm64", use_lite_interpreter: 1, diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2b81e998bde5..b435f1fe0791 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -237,7 +237,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda11.8-py3.10-gcc9 - docker-image-name: 
pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, @@ -262,7 +262,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, @@ -285,34 +285,6 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} - linux-focal-cuda12_4-py3_10-gcc9-build: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-build-label.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-test: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-build - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - linux-jammy-py3-clang12-mobile-build: name: linux-jammy-py3-clang12-mobile-build uses: ./.github/workflows/_linux-build-label.yml @@ -325,12 +297,12 @@ jobs: { config: "default", shard: 1, num_shards: 1 }, ]} - linux-jammy-cuda-11_8-cudnn8-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: + name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 uses: ./.github/workflows/_linux-build-label.yml with: - build-environment: linux-jammy-cuda11.8-cudnn8-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12 + build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -389,7 +361,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ @@ -401,7 +373,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ @@ -413,7 
+385,7 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ @@ -475,7 +447,7 @@ jobs: uses: ./.github/workflows/_linux-build-label.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -497,33 +469,6 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-focal-cuda12_4-py3_10-gcc9-sm86-build: - name: linux-focal-cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-build-label.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - cuda-arch-list: 8.6 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-sm86-test: - name: linux-focal-cuda12.4-py3.10-gcc9-sm86 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-sm86-build - - target-determination - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-jammy-py3-clang12-executorch-build: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build-label.yml diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 31db7af8fc55..50f74b01f08c 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -41,7 +41,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ @@ -70,7 +70,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 0ce1bae6a413..e8bf91c8d9ee 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -26,7 +26,7 @@ jobs: id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: 
pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 working-directory: pytorch - name: Use following to pull public copy of the image diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 73befe34c078..ac5814966899 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -16,7 +16,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 9da73c8addb7..6897d4b1fa6d 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -34,36 +34,39 @@ jobs: id-token: write contents: read - linux-focal-cuda12_1-py3_10-gcc9-build: - name: linux-focal-cuda12.1-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml + linux-focal-cuda12_4-py3_10-gcc9-sm86-build: + name: linux-focal-cuda12.4-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build-label.yml with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 + cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} - linux-focal-cuda12_1-py3_10-gcc9-test: - name: linux-focal-cuda12.1-py3.10-gcc9 + linux-focal-cuda12_4-py3_10-gcc9-sm86-test: + name: linux-focal-cuda12.4-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-cuda12_1-py3_10-gcc9-build + - linux-focal-cuda12_4-py3_10-gcc9-sm86-build - target-determination with: - build-environment: linux-focal-cuda12.1-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-build.outputs.test-matrix }} + build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }} libtorch-linux-focal-cuda12_1-py3_7-gcc9-debug-build: name: libtorch-linux-focal-cuda12.1-py3.7-gcc9-debug uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -77,42 +80,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops 
- docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} - linux-focal-cuda12_4-py3_10-gcc9-build: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 - test-matrix: | - { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda12_4-py3_10-gcc9-test: - name: linux-focal-cuda12.4-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda12_4-py3_10-gcc9-build - - target-determination - with: - build-environment: linux-focal-cuda12.4-py3.10-gcc9 - docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - libtorch-linux-focal-cuda12_4-py3_7-gcc9-debug-build: name: libtorch-linux-focal-cuda12.4-py3.7-gcc9-debug uses: ./.github/workflows/_linux-build.yml with: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 build-generates-artifacts: false runner: linux.4xlarge test-matrix: | @@ -126,7 +105,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops - docker-image-name: pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9 + docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -143,12 +122,12 @@ jobs: { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, ]} - macos-13-py3-arm64-build: - name: macos-13-py3-arm64 + macos-py3-arm64-build: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-build.yml with: sync-tag: macos-py3-arm64-build - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 runner-type: macos-m1-stable build-generates-artifacts: true # To match the one pre-installed in the m1 runners @@ -163,31 +142,30 @@ jobs: macos-py3-arm64-mps-test: name: macos-py3-arm64-mps uses: ./.github/workflows/_mac-test-mps.yml - needs: macos-13-py3-arm64-build - if: needs.macos-13-py3-arm64-build.outputs.build-outcome == 'success' + needs: macos-py3-arm64-build + if: needs.macos-py3-arm64-build.outputs.build-outcome == 'success' with: sync-tag: macos-py3-arm64-mps-test - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 test-matrix: | { include: [ - { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-stable" }, + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, - ]} - macos-13-py3-arm64-test: - name: macos-13-py3-arm64 + macos-py3-arm64-test: + name: macos-py3-arm64 uses: ./.github/workflows/_mac-test.yml needs: - - macos-13-py3-arm64-build + - macos-py3-arm64-build - target-determination with: - build-environment: macos-13-py3-arm64 + build-environment: macos-py3-arm64 # Same as the build job python-version: 3.9.12 - test-matrix: ${{ 
needs.macos-13-py3-arm64-build.outputs.test-matrix }} + test-matrix: ${{ needs.macos-py3-arm64-build.outputs.test-matrix }} win-vs2019-cpu-py3-build: name: win-vs2019-cpu-py3 diff --git a/.github/workflows/unstable.yml b/.github/workflows/unstable.yml index ac1d49d1cce5..a2c4a45bd8b5 100644 --- a/.github/workflows/unstable.yml +++ b/.github/workflows/unstable.yml @@ -32,174 +32,3 @@ jobs: echo echo "Once the jobs are deemed stable enough (% red signal < 5% and TTS < 3h)," echo " they can graduate and move back to pull or trunk." - - # - # Experimental ARC jobs - # - llm-td: - name: before-test - uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read - - target-determination: - name: before-test - uses: ./.github/workflows/target_determination.yml - needs: llm-td - permissions: - id-token: write - contents: read - - linux-jammy-py3_8-gcc11-build: - name: linux-jammy-py3.8-gcc11 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "docs_test", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "distributed", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-jammy-py3_8-gcc11-test: - name: linux-jammy-py3.8-gcc11 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-jammy-py3_8-gcc11-build - - target-determination - with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} - - linux-jammy-py3_8-gcc11-no-ops: - name: linux-jammy-py3.8-gcc11-no-ops - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11-no-ops - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - - linux-jammy-py3_8-gcc11-pch: - name: linux-jammy-py3.8-gcc11-pch - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.8-gcc11-pch - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - - linux-focal-py3_8-clang10-onnx-build: - name: linux-focal-py3.8-clang10-onnx - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image-name: pytorch-linux-focal-py3-clang10-onnx - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_8-clang10-onnx-test: - name: linux-focal-py3.8-clang10-onnx - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_8-clang10-onnx-build - - target-determination - 
with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.test-matrix }} - - linux-jammy-py3_10-clang15-asan-build: - name: linux-jammy-py3.10-clang15-asan - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-jammy-py3.10-clang15-asan - docker-image-name: pytorch-linux-jammy-py3-clang15-asan - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.4xlarge" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.4xlarge" }, - ]} - sync-tag: asan-build-arc - - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_8-clang10-build - - target-determination - with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} - - linux-focal-py3_11-clang10-build: - name: linux-focal-py3.11-clang10 - uses: ./.github/workflows/_linux-build-rg.yml - with: - build-environment: linux-focal-py3.11-clang10 - docker-image-name: pytorch-linux-focal-py3.11-clang10 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" }, - ]} - - linux-focal-py3_11-clang10-test: - name: linux-focal-py3.11-clang10 - uses: ./.github/workflows/_linux-test-rg.yml - needs: - - linux-focal-py3_11-clang10-build - - target-determination - with: 
- build-environment: linux-focal-py3.11-clang10 - docker-image: ${{ needs.linux-focal-py3_11-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_11-clang10-build.outputs.test-matrix }} - - # - # End of Experimental ARC jobs - # diff --git a/.gitmodules b/.gitmodules index 476f11fd945c..c031c2fd5ad3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -18,10 +18,6 @@ ignore = dirty path = third_party/protobuf url = https://github.com/protocolbuffers/protobuf.git -[submodule "third_party/ios-cmake"] - ignore = dirty - path = third_party/ios-cmake - url = https://github.com/Yangqing/ios-cmake.git [submodule "third_party/NNPACK"] ignore = dirty path = third_party/NNPACK diff --git a/.lintrunner.toml b/.lintrunner.toml index 2dc4305d9ab9..92a7fc0b1d8e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -235,7 +235,6 @@ exclude_patterns = [ 'torch/csrc/jit/serialization/import_legacy.cpp', 'torch/csrc/jit/serialization/export.cpp', 'torch/csrc/lazy/**/*', - 'torch/csrc/onnx/init.cpp', 'torch/csrc/mps/**/*', ] init_command = [ @@ -1000,7 +999,6 @@ command = [ ] exclude_patterns = [ 'tools/gen_vulkan_spv.py', - 'torch/__init__.py', # Skip this file to format because it's part of the public API # We don't care too much about files in this directory, don't enforce # formatting on them 'caffe2/**/*.py', @@ -1073,7 +1071,6 @@ exclude_patterns = [ 'test/test_jit_disabled.py', 'test/test_jit_fuser.py', 'test/test_jit_fuser_legacy.py', - 'test/test_jit_fuser_te.py', 'test/test_jit_legacy.py', 'test/test_jit_llga_fuser.py', 'test/test_jit_profiling.py', @@ -1101,7 +1098,6 @@ exclude_patterns = [ 'test/test_namedtuple_return_api.py', 'test/test_native_functions.py', 'test/test_native_mha.py', - 'test/test_nestedtensor.py', 'test/test_nn.py', 'test/test_out_dtype_op.py', 'test/test_overrides.py', @@ -1116,9 +1112,6 @@ exclude_patterns = [ 'test/test_segment_reductions.py', 'test/test_serialization.py', 'test/test_set_default_mobile_cpu_allocator.py', - 'test/test_shape_ops.py', - 'test/test_show_pickle.py', - 'test/test_sort_and_select.py', 'test/test_sparse.py', 'test/test_sparse_csr.py', 'test/test_sparse_semi_structured.py', @@ -1537,28 +1530,6 @@ exclude_patterns = [ 'torch/distributed/optim/post_localSGD_optimizer.py', 'torch/distributed/optim/utils.py', 'torch/distributed/optim/zero_redundancy_optimizer.py', - 'torch/distributed/pipeline/__init__.py', - 'torch/distributed/pipeline/sync/__init__.py', - 'torch/distributed/pipeline/sync/_balance/__init__.py', - 'torch/distributed/pipeline/sync/_balance/blockpartition.py', - 'torch/distributed/pipeline/sync/_balance/profile.py', - 'torch/distributed/pipeline/sync/batchnorm.py', - 'torch/distributed/pipeline/sync/checkpoint.py', - 'torch/distributed/pipeline/sync/copy.py', - 'torch/distributed/pipeline/sync/dependency.py', - 'torch/distributed/pipeline/sync/microbatch.py', - 'torch/distributed/pipeline/sync/phony.py', - 'torch/distributed/pipeline/sync/pipe.py', - 'torch/distributed/pipeline/sync/pipeline.py', - 'torch/distributed/pipeline/sync/skip/__init__.py', - 'torch/distributed/pipeline/sync/skip/layout.py', - 'torch/distributed/pipeline/sync/skip/namespace.py', - 'torch/distributed/pipeline/sync/skip/portal.py', - 'torch/distributed/pipeline/sync/skip/skippable.py', - 'torch/distributed/pipeline/sync/skip/tracker.py', - 'torch/distributed/pipeline/sync/stream.py', - 'torch/distributed/pipeline/sync/utils.py', - 'torch/distributed/pipeline/sync/worker.py', 'torch/distributed/remote_device.py', 
'torch/distributed/rendezvous.py', 'torch/distributed/rpc/__init__.py', @@ -1583,7 +1554,6 @@ exclude_patterns = [ 'torch/distributed/tensor/parallel/input_reshard.py', 'torch/distributed/tensor/parallel/multihead_attention_tp.py', 'torch/distributed/tensor/parallel/style.py', - 'torch/distributed/utils.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/functional.py', @@ -1675,18 +1645,6 @@ exclude_patterns = [ 'torch/hub.py', 'torch/library.py', 'torch/linalg/__init__.py', - # UFMT causes import cycle on masked - 'torch/masked/__init__.py', - 'torch/masked/_docs.py', - 'torch/masked/_ops.py', - 'torch/masked/maskedtensor/__init__.py', - 'torch/masked/maskedtensor/_ops_refs.py', - 'torch/masked/maskedtensor/binary.py', - 'torch/masked/maskedtensor/core.py', - 'torch/masked/maskedtensor/creation.py', - 'torch/masked/maskedtensor/passthrough.py', - 'torch/masked/maskedtensor/reductions.py', - 'torch/masked/maskedtensor/unary.py', 'torch/monitor/__init__.py', 'torch/nested/__init__.py', 'torch/nn/__init__.py', @@ -1865,8 +1823,6 @@ exclude_patterns = [ 'torch/testing/_internal/distributed/nn/__init__.py', 'torch/testing/_internal/distributed/nn/api/__init__.py', 'torch/testing/_internal/distributed/nn/api/remote_module_test.py', - 'torch/testing/_internal/distributed/pipe_with_ddp_test.py', - 'torch/testing/_internal/distributed/pipeline/__init__.py', 'torch/testing/_internal/distributed/rpc/__init__.py', 'torch/testing/_internal/distributed/rpc/dist_autograd_test.py', 'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py', @@ -2121,7 +2077,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.4.5', + 'ruff==0.4.8', ] is_formatter = true diff --git a/BUILD.bazel b/BUILD.bazel index 6d01ff42305c..f2f3be210e93 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -455,22 +455,14 @@ filegroup( name = "caffe2_core_srcs", srcs = [ "caffe2/core/common.cc", - "caffe2/core/types.cc", ], ) filegroup( name = "caffe2_perfkernels_srcs", srcs = [ - "caffe2/perfkernels/adagrad.cc", "caffe2/perfkernels/embedding_lookup.cc", "caffe2/perfkernels/embedding_lookup_idx.cc", - "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc", - "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc", - "caffe2/perfkernels/fused_nbit_rowwise_conversion.cc", - "caffe2/perfkernels/lstm_unit_cpu_common.cc", - "caffe2/perfkernels/math_cpu_base.cc", - "caffe2/perfkernels/typed_axpy.cc", ], ) @@ -488,10 +480,6 @@ filegroup( filegroup( name = "caffe2_utils_srcs", srcs = [ - "caffe2/utils/bench_utils.cc", - "caffe2/utils/cpuid.cc", - "caffe2/utils/murmur_hash3.cc", - "caffe2/utils/proto_utils.cc", "caffe2/utils/proto_wrap.cc", "caffe2/utils/string_utils.cc", "caffe2/utils/threadpool/ThreadPool.cc", @@ -510,12 +498,9 @@ cc_library( name = "caffe2_for_aten_headers", hdrs = [ "caffe2/core/common.h", - "caffe2/core/logging.h", - "caffe2/core/types.h", "caffe2/perfkernels/common.h", "caffe2/perfkernels/embedding_lookup.h", "caffe2/perfkernels/embedding_lookup_idx.h", - "caffe2/utils/cpuid.h", "caffe2/utils/fixed_divisor.h", ] + glob([ "caffe2/utils/threadpool/*.h", @@ -525,7 +510,6 @@ cc_library( deps = [ ":caffe2_core_macros", "//c10", - "//caffe2/proto:caffe2_pb", ], ) @@ -544,15 +528,12 @@ cc_library( ], ) + if_cuda(glob([ "caffe2/**/*.cuh", - "caffe2/image/*.h", ])), copts = CAFFE2_COPTS, visibility = ["//visibility:public"], deps = [ ":caffe2_core_macros", ":caffe2_for_aten_headers", - "//caffe2/proto:caffe2_pb", - "//caffe2/proto:cc_proto", ], ) @@ 
-573,8 +554,6 @@ cc_library( ":caffe2_perfkernels_avx", ":caffe2_perfkernels_avx2", ":caffe2_perfkernels_avx512", - "//caffe2/proto:caffe2_pb", - "//caffe2/proto:cc_proto", "//third_party/miniz-2.1.0:miniz", "@com_google_protobuf//:protobuf", "@eigen", @@ -782,8 +761,8 @@ cc_library( deps = [ ":caffe2", ":torch_headers", - "//caffe2/proto:torch_cc_proto", "@kineto", + "@cpp-httplib", ] + if_cuda([ "@cuda//:nvToolsExt", "@cutlass", diff --git a/CMakeLists.txt b/CMakeLists.txt index 335f5750648c..c4cd4b2c2a98 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,8 +242,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) option(USE_CUDA "Use CUDA" ON) -cmake_dependent_option(USE_XPU "Use XPU. Only available on Linux." ON "LINUX" - OFF) +option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) @@ -540,6 +539,8 @@ option(BUILD_EXECUTORCH "Master flag to build Executorch" ON) if(LINUX) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--no-as-needed") + set(CMAKE_SHARED_LINKER_FLAGS + "${CMAKE_SHARED_LINKER_FLAGS} $ENV{LDFLAGS}") endif() if(MSVC) @@ -864,12 +865,13 @@ cmake_dependent_option( # Suspect users building from source will need this add_definitions(-DFLASHATTENTION_DISABLE_ALIBI) -# CAVEAT: Again, do not check USE_ROCM here Flash Attention2 will error while -# building for sm52 while Mem Eff Attention won't +# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem +# Eff Attention won't cmake_dependent_option( USE_MEM_EFF_ATTENTION "Enable memory-efficient attention for scaled dot product attention.\ - Will be disabled if not supported by the platform" ON "USE_CUDA" OFF) + Will be disabled if not supported by the platform" ON + "USE_CUDA OR USE_ROCM" OFF) if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") @@ -892,6 +894,14 @@ endif() if(USE_SLEEF_FOR_ARM_VEC256) string(APPEND CMAKE_CXX_FLAGS " -DAT_BUILD_ARM_VEC256_WITH_SLEEF") + add_definitions(-DAT_BUILD_ARM_VEC256_WITH_SLEEF) +endif() + +# Enable sleef on macOS with Apple silicon by default +if((${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") AND ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64")) + message(STATUS "Running on macOS with Apple silicon") + string(APPEND CMAKE_CXX_FLAGS " -DAT_BUILD_ARM_VEC256_WITH_SLEEF") + add_definitions(-DAT_BUILD_ARM_VEC256_WITH_SLEEF) endif() if(USE_XNNPACK) diff --git a/README.md b/README.md index 2a469af7b166..aa4638f9ece6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![PyTorch Logo](https://github.com/pytorch/pytorch/blob/main/docs/source/_static/img/pytorch-logo-dark.png) +![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) -------------------------------------------------------------------------------- @@ -98,7 +98,7 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. 
-![Dynamic graph](https://github.com/pytorch/pytorch/blob/main/docs/source/_static/img/dynamic_graph.gif) +![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) ### Python First @@ -189,7 +189,7 @@ Other potentially useful environment variables may be found in `setup.py`. ##### Intel GPU Support If you want to compile with Intel GPU support, follow these - [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html) instructions. -- Intel GPU is currently supported only for Linux systems. +- Intel GPU is supported for Linux and Windows. If you want to disable Intel GPU support, export the environment variable `USE_XPU=0`. Other potentially useful environment variables may be found in `setup.py`. @@ -213,6 +213,7 @@ conda install -c pytorch magma-cuda121 # or the magma-cuda* that matches your C # (optional) If using torch.compile with inductor/triton, install the matching version of triton # Run from the pytorch directory after cloning +# For Intel GPU support, please explicitly `export USE_XPU=1` before running command. make triton ``` diff --git a/RELEASE.md b/RELEASE.md index ff8e99883e4e..3c9d68f9a6cd 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -37,6 +37,7 @@ - [TL;DR](#tldr) - [Accelerator Software](#accelerator-software) - [Special support cases](#special-support-cases) + - [Operating Systems](#operating-systems) - [Submitting Tutorials](#submitting-tutorials) - [Special Topics](#special-topics) - [Updating submodules for a release](#updating-submodules-for-a-release) @@ -426,6 +427,15 @@ the size restrictions for publishing on PyPI so the default version that is publ These special support cases will be handled on a case by case basis and support may be continued if current PyTorch maintainers feel as though there may still be a need to support these particular versions of software. +## Operating Systems +Supported OS flavors are summarized in the table below: +| Operating System family | Architectrue | Notes | +| --- | --- | --- | +| Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. | +| MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). | +| MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 | +| Windows | x86_64 | Buils are compatible with Windows-10 or newer. | + # Submitting Tutorials Tutorials in support of a release feature must be submitted to the [pytorch/tutorials](https://github.com/pytorch/tutorials) repo at least two weeks before the release date to allow for editorial and technical review. There is no cherry-pick process for tutorials. All tutorials will be merged around the release day and published at [pytorch.org/tutorials](https://pytorch.org/tutorials/). diff --git a/SECURITY.md b/SECURITY.md index e8e0249fc896..119a2b7615ac 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -5,6 +5,7 @@ - [Untrusted models](#untrusted-models) - [Untrusted inputs](#untrusted-inputs) - [Data privacy](#data-privacy) + - [Using distributed features](#using-distributed-features) ## Reporting Security Issues @@ -39,7 +40,7 @@ Important Note: The trustworthiness of a model is not binary. 
You must always de ### Untrusted inputs during training and prediction -If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permisisons strictly required, and keep your libraries updated with the lates security patches. +If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permissions strictly required, and keep your libraries updated with the latest security patches. If applicable, prepare your model against bad inputs and prompt injections. Some recommendations: - Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using fuzzing for prompt injection). @@ -54,3 +55,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some **Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: - Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) - If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). + +### Using distributed features + +PyTorch can be used for distributed computing, and as such there is a `torch.distributed` package. PyTorch Distributed features are intended for internal communication only. They are not built for use in untrusted environments or networks. + +For performance reasons, none of the PyTorch Distributed primitives (including c10d, RPC, and TCPStore) include any authorization protocol and will send messages unencrypted. They accept connections from anywhere, and execute the workload sent without performing any checks. Therefore, if you run a PyTorch Distributed program on your network, anybody with access to the network can execute arbitrary code with the privileges of the user running PyTorch. diff --git a/WORKSPACE b/WORKSPACE index 5b4f2f2e3375..4169e0dbce1d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -168,6 +168,12 @@ new_local_repository( path = "third_party/opentelemetry-cpp", ) +new_local_repository( + name = "cpp-httplib", + build_file = "//third_party:cpp-httplib.BUILD", + path = "third_party/cpp-httplib", +) + new_local_repository( name = "tensorpipe", build_file = "//third_party:tensorpipe.BUILD", diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9fa7a1f2305b..0087dd95d96e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -386,6 +386,7 @@ if(UNIX AND NOT APPLE) endif(UNIX AND NOT APPLE) if(UNIX) + include(CheckFunctionExists) set(CMAKE_EXTRA_INCLUDE_FILES "sys/mman.h") CHECK_FUNCTION_EXISTS(mmap HAVE_MMAP) if(HAVE_MMAP) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a922bcd5922f..bb6b0611b743 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -364,7 +364,7 @@ class TORCH_API Context { bool enabled_flashSDP = true; bool enabled_mem_efficientSDP = true; bool enabled_mathSDP = true; - bool enabled_cudnnSDP = false; + bool enabled_cudnnSDP = true; #ifdef USE_ROCM bool benchmark_cudnn = true; #else @@ -385,8 +385,11 @@ class TORCH_API Context { ? 
at::LinalgBackend::Cusolver : at::LinalgBackend::Default; at::BlasBackend blas_preferred_backend = - (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || - c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) +#ifdef USE_ROCM + (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false) +#else + (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true) +#endif ? at::BlasBackend::Cublaslt : at::BlasBackend::Cublas; #ifdef C10_MOBILE diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 3d2350d26101..6fb966f66713 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -143,7 +143,7 @@ static Device getATenDevice(const DLDevice& ctx, void* data) { return at::detail::getXPUHooks().getDeviceFromPtr(data); default: TORCH_CHECK( - false, "Unsupported device_type: " + c10::to_string(ctx.device_type)); + false, "Unsupported device_type: ", std::to_string(ctx.device_type)); } } @@ -167,7 +167,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kUInt bits " + c10::to_string(dtype.bits)); + false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLInt: @@ -186,7 +186,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kInt bits " + c10::to_string(dtype.bits)); + false, "Unsupported kInt bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLFloat: @@ -202,7 +202,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLBfloat: @@ -212,7 +212,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLComplex: @@ -228,7 +228,7 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kFloat bits " + c10::to_string(dtype.bits)); + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; case DLDataTypeCode::kDLBool: @@ -238,11 +238,11 @@ ScalarType toScalarType(const DLDataType& dtype) { break; default: TORCH_CHECK( - false, "Unsupported kDLBool bits " + c10::to_string(dtype.bits)); + false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); } break; default: - TORCH_CHECK(false, "Unsupported code " + c10::to_string(dtype.code)); + TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); } return stype; } @@ -298,9 +298,7 @@ Tensor fromDLPack(DLManagedTensor* src) { return fromDLPack(src, std::move(deleter)); } -Tensor fromDLPack( - DLManagedTensor* src, - std::function deleter) { +Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data); ScalarType stype = toScalarType(src->dl_tensor.dtype); if (!src->dl_tensor.strides) { diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 03cfca36e722..66973031c431 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -462,7 +462,7 @@ inline Tensor _sum_to( reduce_dims.push_back(i); } for (int64_t i = leading_dims; i < static_cast(sizes.size()); ++i) { - if (shape[i - leading_dims] == 1 && + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - 
leading_dims], 1)) && TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) { reduce_dims.push_back(i); } diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 7689934e4113..2e6792d5ca69 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -19,7 +19,13 @@ MemOverlap has_internal_overlap(TensorImpl* t) { auto strides = t->sym_strides(); auto sizes = t->sym_sizes(); for (const auto i : c10::irange(strides.size())) { - if (strides[i] == 0 && sizes[i] > 1) { + // NB: The size oblivious test is written very carefully here. When + // unbacked SymInts are involved, we should try to conservatively report + // if memory overlap /could/ happen under some setting of unbacked + // SymInts. Thus, if I have u0 size, we should assume that this has > 1 + // elements (first expression), but if I have a u0 stride, I should NOT + // assume that it is not zero (second expression) + if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) { return MemOverlap::Yes; } } diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index b2ef33ffc058..eb36c0e02fa4 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -197,7 +197,7 @@ TORCH_API std::ostream& operator<<( const std::vector& tensor_indices); namespace impl { -static inline Tensor applySlice( +inline Tensor applySlice( const Tensor& self, int64_t dim, c10::SymInt start, @@ -218,8 +218,8 @@ static inline Tensor applySlice( ? (*self_sizes)[dim] : self.sym_size(dim); if (!disable_slice_optimization && - TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && length == stop && - step == 1) { + TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && + TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) { return self; } } @@ -227,7 +227,7 @@ static inline Tensor applySlice( dim, std::move(start), std::move(stop), std::move(step)); } -static inline Tensor applySelect( +inline Tensor applySelect( const Tensor& self, int64_t dim, SymInt index, @@ -266,9 +266,7 @@ static inline Tensor applySelect( return self.select_symint(dim, std::move(index)); } -static inline Tensor boolToIndexingTensorCPUOrCUDA( - const Tensor& self, - bool value) { +inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { // booleans add a dimension of size 1. true indexes this dimension as if 0:, // false as empty. if (value) { @@ -278,7 +276,7 @@ static inline Tensor boolToIndexingTensorCPUOrCUDA( } } -static inline Tensor boolToIndexingTensorNonNativeDeviceType( +inline Tensor boolToIndexingTensorNonNativeDeviceType( const Tensor& self, bool value) { // booleans add a dimension of size 1. 
true indexes this dimension as if 0:, @@ -290,7 +288,7 @@ static inline Tensor boolToIndexingTensorNonNativeDeviceType( } } -static inline Tensor boolToIndexingTensor( +inline Tensor boolToIndexingTensor( const Tensor& self, bool value, const at::Device& self_device) { @@ -301,13 +299,13 @@ static inline Tensor boolToIndexingTensor( } } -static inline Tensor scalarToTensorNonNativeDeviceType( +inline Tensor scalarToTensorNonNativeDeviceType( const Scalar& v, const TensorOptions& options) { return at::scalar_tensor(v, options); } -static inline void recordTensorIndex( +inline void recordTensorIndex( const Tensor& tensor, std::vector& outIndices, int64_t* dim_ptr) { @@ -317,7 +315,7 @@ static inline void recordTensorIndex( (*dim_ptr)++; }; -static inline c10::List<::std::optional> typeConvertIndices( +inline c10::List<::std::optional> typeConvertIndices( const Tensor& /*self*/, std::vector&& indices) { c10::List<::std::optional> converted_inds; @@ -338,7 +336,7 @@ static inline c10::List<::std::optional> typeConvertIndices( // construct a `std::vector` container to be consumed by the C++ // `count_specified_dimensions` function, which adds 100s of nanoseconds // overhead and is undesirable. -static inline int64_t count_specified_dimensions( +inline int64_t count_specified_dimensions( const ArrayRef& indices) { // Count the number of indexed dimensions (everything but ellipsis and None) int64_t count = 0; @@ -372,7 +370,7 @@ static inline int64_t count_specified_dimensions( // // The rest of the functions are in `at::indexing::impl` namespace, signifying // that they shouldn't be used from Python indexing implementation. -static inline Tensor scalarToTensor( +inline Tensor scalarToTensor( const Scalar& v, const TensorOptions& options, const at::Device& self_device) { @@ -387,7 +385,7 @@ static inline Tensor scalarToTensor( // To match numpy semantics: // As a special case for backwards compatibility, // strip away unit dimensions from the left of 'src' -static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { +inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { size_t first_non1_src = sizes.size(); for (const auto i : c10::irange(sizes.size())) { // Unbacked SymInt has different behavior, but this is sound because @@ -402,7 +400,7 @@ static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { return sizes.slice(first_non1_src); } -static inline void copy_to(const Tensor& dst, const Tensor& src) { +inline void copy_to(const Tensor& dst, const Tensor& src) { if (dst.sym_sizes().equals(src.sym_sizes())) { // A shortcut to avoid generating hard-coded constant sizes during tracing. 
// This is not a perfect solution: when src & dst have different shapes, @@ -421,7 +419,7 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) { // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor // indexing functions from Python ] -static inline Tensor handleDimInMultiDimIndexing( +inline Tensor handleDimInMultiDimIndexing( const Tensor& prev_dim_result, const Tensor& original_tensor, const TensorIndex& index, @@ -509,7 +507,7 @@ static inline Tensor handleDimInMultiDimIndexing( namespace impl { // This mirrors `applySlicing` in // torch/csrc/autograd/python_variable_indexing.cpp -static inline Tensor applySlicing( +inline Tensor applySlicing( const Tensor& self, const ArrayRef& indices, std::vector& outIndices, @@ -550,13 +548,13 @@ static inline Tensor applySlicing( } } // namespace impl -static inline Tensor dispatch_index( +inline Tensor dispatch_index( const Tensor& self, std::vector&& indices) { return self.index(impl::typeConvertIndices(self, std::move(indices))); } -static inline Tensor dispatch_index_put_( +inline Tensor dispatch_index_put_( Tensor& self, std::vector&& indices, const Tensor& value) { @@ -598,7 +596,7 @@ static inline Tensor dispatch_index_put_( // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting // `disable_slice_optimization` when calling C++ tensor indexing functions from // Python ] -static inline Tensor get_item( +inline Tensor get_item( const Tensor& self, const ArrayRef& indices, bool disable_slice_optimization = false) { @@ -664,7 +662,7 @@ static inline Tensor get_item( // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++ // tensor indexing functions from Python ] -static inline void set_item( +inline void set_item( const Tensor& self, const ArrayRef& indices, const Tensor& value, diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index c4a68a33e306..ecc90ace61e6 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -22,7 +22,6 @@ #endif #include -#include #include #include @@ -1398,7 +1397,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) { break; } default: - TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", c10::to_string((int)setup_type)); + TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type)); } //coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here) if (ndim() > 1){ diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index f0c73cde2dda..10fb72796fc6 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -68,7 +68,7 @@ thread_local std::array at::kBFloat16, // XLA / TPU at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Metal - at::kBFloat16, // XPU + at::kHalf, // XPU at::ScalarType::Undefined, // MPS at::ScalarType::Undefined, // Meta (tensors with no data) at::kBFloat16, // HPU / HABANA diff --git a/aten/src/ATen/code_template.h b/aten/src/ATen/code_template.h index 393e322e6fe6..ebf113e9d226 100644 --- a/aten/src/ATen/code_template.h +++ b/aten/src/ATen/code_template.h @@ -31,7 +31,7 @@ struct TemplateEnv { // Add a number 'v' to the map at key 'k' template void d(const std::string& k, const T& v) { - strings_[k] = c10::to_string(v); + strings_[k] = std::to_string(v); lists_.erase(k); } diff --git 
a/aten/src/ATen/core/Generator.h b/aten/src/ATen/core/Generator.h index 6b76db5d0686..297b805f407b 100644 --- a/aten/src/ATen/core/Generator.h +++ b/aten/src/ATen/core/Generator.h @@ -150,7 +150,7 @@ Generator make_generator(Args&&... args) { * the backend generator type (CPU/CUDAGeneratorImpl etc.) */ template -static inline T * check_generator(std::optional gen) { +inline T * check_generator(std::optional gen) { TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt"); TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed"); TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'"); @@ -164,7 +164,7 @@ static inline T * check_generator(std::optional gen) { * the backend generator type (CPU/CUDAGeneratorImpl etc.) */ template -static inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { +inline T* get_generator_or_default(const std::optional& gen, const Generator& default_gen) { return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); } @@ -177,7 +177,7 @@ namespace detail { * - The new state tensor must be a torch.ByteTensor * - Data of the new state tensor must be contiguous */ -static inline void check_rng_state(const c10::TensorImpl& new_state) { +inline void check_rng_state(const c10::TensorImpl& new_state) { TORCH_CHECK_TYPE( new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte, "RNG state must be a torch.ByteTensor" diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index 53560b9666ae..7f65551fbe70 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -478,8 +478,6 @@ namespace impl { // (maybe except for some internal prim ops). using GenericList = List; -const IValue* ptr_to_first_element(const GenericList& list); - } } diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index 64760b5f782b..0d223122599c 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -350,11 +350,4 @@ void List::unsafeSetElementType(TypePtr t) { impl_->elementType = std::move(t); } -namespace impl { - -inline const IValue* ptr_to_first_element(const GenericList& list) { - return &list.impl_->list[0]; -} - -} } diff --git a/aten/src/ATen/core/MetaFallbackKernel.cpp b/aten/src/ATen/core/MetaFallbackKernel.cpp index 2a7c34b17076..e87f641f9eb1 100644 --- a/aten/src/ATen/core/MetaFallbackKernel.cpp +++ b/aten/src/ATen/core/MetaFallbackKernel.cpp @@ -16,8 +16,8 @@ static void metaFallback( "fake impl or Meta kernel registered. You may have run into this message " "while using an operator with PT2 compilation APIs (torch.compile/torch.export); " "in order to use this operator with those APIs you'll need to add a fake impl. " - "Please see the following doc for next steps: " - "https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit"); + "Please see the following for next steps: " + "https://pytorch.org/docs/main/notes/custom_operators.html"); } TORCH_LIBRARY_IMPL(_, Meta, m) { diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 0188e546179b..7218ee56689c 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -953,7 +953,7 @@ TensorBase make_tensor_base(Args&&... 
args) { } // namespace detail -static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { +inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { return legacyExtractDispatchKey(t.key_set()); } diff --git a/aten/src/ATen/core/boxing/KernelFunction_test.cpp b/aten/src/ATen/core/boxing/KernelFunction_test.cpp index a0f990e87aaf..cf45c709c58d 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_test.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction_test.cpp @@ -275,16 +275,6 @@ void expectOutOfPlaceMultiBoxedCallingWorks(const KernelFunction& func) { EXPECT_TRUE(stack[1].toTensor().is_same(t2)); } -void expectBoxedCallingFailsWith(const KernelFunction& func, const char* errorMessage) { - called_with_args = c10::nullopt; - vector stack {3, 4}; - OperatorHandle dummy = makeDummyOperatorHandle(); - - expectThrows([&] { - func.callBoxed(dummy, CPU_TEST_SET, &stack); - }, errorMessage); -} - // // unboxed calling tests: // diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp index 0b0df2af1ca1..f6dc3ee356a0 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_legacy_test.cpp @@ -40,10 +40,6 @@ int64_t incrementKernel(const Tensor& tensor, int64_t input) { return input + 1; } -int64_t decrementKernel(const Tensor& tensor, int64_t input) { - return input - 1; -} - void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoDispatchBelowAutograd mode; @@ -55,17 +51,6 @@ void expectCallsIncrement(DispatchKey dispatch_key) { EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(dispatch_key), 5); - EXPECT_EQ(1, result.size()); - EXPECT_EQ(4, result[0].toInt()); -} - TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); expectCallsIncrement(DispatchKey::CPU); diff --git a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp index 5662c0982bfb..2d6f7346eec2 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_function_test.cpp @@ -662,18 +662,6 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) { EXPECT_EQ("123", result); } -void expectCannotCallConcatBoxed(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - expectThrows( - [&] {callOp(*op, dummyTensor(dispatch_key), "1", "2", 3);}, - "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call()." 
- ); -} - TEST(OperatorRegistrationTestFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(DispatchKey::CPU)); expectCallsConcatUnboxed(DispatchKey::CPU); diff --git a/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp index 39dceafab006..8db6abad6c33 100644 --- a/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/impl/kernel_lambda_legacy_test.cpp @@ -731,8 +731,7 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenFallbackKernelWithout } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; + bool called = false; std::optional called_arg2 = c10::nullopt; std::optional called_arg3 = c10::nullopt; std::optional called_arg4 = c10::nullopt; @@ -771,8 +770,7 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInp } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; + bool called = false; std::optional called_arg2 = c10::nullopt; std::optional called_arg3 = c10::nullopt; std::optional called_arg4 = c10::nullopt; @@ -814,12 +812,6 @@ TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInp } TEST(OperatorRegistrationTestLegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool called; - std::optional called_arg2 = c10::nullopt; - std::optional called_arg3 = c10::nullopt; - std::optional called_arg4 = c10::nullopt; - auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> (Tensor?, int?, str?)", [] (Tensor arg1, const std::optional& arg2, std::optional arg3, std::optional arg4) { diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp index 3ced237702aa..345f5b11cba8 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor_test.cpp @@ -51,17 +51,6 @@ void expectCallsIncrement(DispatchKey dispatch_key) { EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(DispatchKey dispatch_key) { - at::AutoDispatchBelowAutograd mode; - - // assert that schema and cpu kernel are present - auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); - ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(dispatch_key), 5); - EXPECT_EQ(1, result.size()); - EXPECT_EQ(4, result[0].toInt()); -} - TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPU)); expectCallsIncrement(DispatchKey::CPU); diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 46c291bada30..4a345facaa94 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -21,7 +21,7 @@ namespace impl { // on TLS. // // NB: If there is no valid dispatch key, this will return Undefined -static inline DispatchKeySet computeDispatchKeySet( +inline DispatchKeySet computeDispatchKeySet( DispatchKeySet ks, // The key mask lets us eliminate (by zero entries) keys which should not // be considered for dispatch. 
There are two cases when we use this: diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 707269de902e..5d12aa3d35d7 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -307,7 +307,6 @@ void stackBasedKernel(const OperatorHandle&, c10::Stack* stack) { } TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndNoneCanInferSchema_thenFails) { - bool called_kernel = false; expectThrows([&] { auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel<&stackBasedKernel>(c10::DispatchKey::CPU) diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h index 5dc89da6c562..6372a3ccb556 100644 --- a/aten/src/ATen/core/stack.h +++ b/aten/src/ATen/core/stack.h @@ -66,51 +66,51 @@ class Operation { // treat the last N elements of the stack as a list, looking up // element i -static inline IValue& peek(Stack& stack, size_t i, size_t N) { +inline IValue& peek(Stack& stack, size_t i, size_t N) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) return *(stack.end() - N + i); } -static inline IValue& peek(Stack* stack, size_t i, size_t N) { +inline IValue& peek(Stack* stack, size_t i, size_t N) { return peek(*stack, i, N); } -static inline const IValue& peek(const Stack& stack, size_t i, size_t N) { +inline const IValue& peek(const Stack& stack, size_t i, size_t N) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) return *(stack.end() - N + i); } -static inline const IValue& peek(const Stack* stack, size_t i, size_t N) { +inline const IValue& peek(const Stack* stack, size_t i, size_t N) { return peek(*stack, i, N); } // treat the last N elements of the stack as a list, looking up the // slice starting at index i and having length len -static inline at::ArrayRef peekSlice( +inline at::ArrayRef peekSlice( const Stack& stack, size_t i, size_t len, size_t N) { return at::ArrayRef(stack).slice(stack.size() - N + i, len); } -static inline at::ArrayRef last(const Stack& stack, size_t N) { +inline at::ArrayRef last(const Stack& stack, size_t N) { return peekSlice(stack, 0, N, N); } -static inline at::ArrayRef last(const Stack* stack, size_t N) { +inline at::ArrayRef last(const Stack* stack, size_t N) { return last(*stack, N); } -static inline void drop(Stack& stack, size_t n) { +inline void drop(Stack& stack, size_t n) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions) stack.erase(stack.end() - n, stack.end()); } -static inline void drop(Stack* stack, size_t n) { +inline void drop(Stack* stack, size_t n) { drop(*stack, n); } -static inline IValue pop(Stack& stack) { +inline IValue pop(Stack& stack) { auto r = std::move(stack.back()); stack.pop_back(); return r; } -static inline IValue pop(Stack* stack) { +inline IValue pop(Stack* stack) { return pop(*stack); } -static inline std::vector pop(Stack& stack, size_t n) { +inline std::vector pop(Stack& stack, size_t n) { std::vector result; result.reserve(n); for (const auto i : c10::irange(n)) { @@ -127,7 +127,7 @@ static inline std::vector pop(Stack& stack, size_t n) { // b = pop(stack).toTensor(); // a = pop(stack).toInt(); template -static inline void pop(Stack& stack, Types&... args) { +inline void pop(Stack& stack, Types&... args) { size_t i = 0; constexpr size_t N = sizeof...(args); (void)std::initializer_list{ @@ -135,15 +135,15 @@ static inline void pop(Stack& stack, Types&... 
args) { drop(stack, N); } template -static inline void pop(Stack* stack, Types&... args) { +inline void pop(Stack* stack, Types&... args) { pop(*stack, args...); } template -static inline void push_one(Stack& stack, Type&& arg) { +inline void push_one(Stack& stack, Type&& arg) { stack.emplace_back(std::forward(arg)); } -static inline void push_one(Stack& stack, c10::TensorOptions options) { +inline void push_one(Stack& stack, c10::TensorOptions options) { stack.emplace_back(c10::typeMetaToScalarType(options.dtype())); stack.emplace_back(options.layout()); stack.emplace_back(options.device()); @@ -151,15 +151,15 @@ static inline void push_one(Stack& stack, c10::TensorOptions options) { } template -static inline void push(Stack& stack, Types&&... args) { +inline void push(Stack& stack, Types&&... args) { (void)std::initializer_list{(push_one(stack, std::forward(args)), 0)...}; } template -static inline void push(Stack* stack, Types&&... args) { +inline void push(Stack* stack, Types&&... args) { return push(*stack, std::forward(args)...); } template -static inline void push_list_elements(Stack& stack, const c10::List& elements) { +inline void push_list_elements(Stack& stack, const c10::List& elements) { for (T elem : elements) { stack.push_back(std::move(elem)); } diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index ddb9b34eceb9..fbf861dcabcf 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -4,8 +4,23 @@ #endif namespace at::cpu { +bool is_cpu_support_avx2() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx2(); +#else + return false; +#endif +} + +bool is_cpu_support_avx512() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq(); +#else + return false; +#endif +} -bool is_cpu_support_vnni() { +bool is_cpu_support_avx512_vnni() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni(); #else diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index ece13c70bce3..0ad6f8e893ca 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -4,7 +4,10 @@ namespace at::cpu { +TORCH_API bool is_cpu_support_avx2(); +TORCH_API bool is_cpu_support_avx512(); + // Detect if CPU support Vector Neural Network Instruction. 
-TORCH_API bool is_cpu_support_vnni(); +TORCH_API bool is_cpu_support_avx512_vnni(); } // namespace at::cpu diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index d418dc53af38..2c6cef95f79f 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -48,7 +48,7 @@ class PointerModeGuard { template inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented"); } template <> @@ -66,7 +66,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); template inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented"); } template <> @@ -154,7 +154,7 @@ void scaled_gemm( template inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented"); } template <> @@ -172,7 +172,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); template inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented"); } template <> @@ -195,7 +195,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); template inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented"); } template <> @@ -215,10 +215,7 @@ TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::co template inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::blas::trsmBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented"); } template <> @@ -238,7 +235,7 @@ TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_A template inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented"); } template <> @@ -262,7 +259,7 @@ void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); template inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented"); } template <> @@ -280,7 +277,7 @@ void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); template inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) { - AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented"); } template <> @@ -295,8 +292,7 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); template void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented"); } template<> TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)); @@ -313,10 +309,7 @@ 
TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPE template void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::blas::geqrfBatched: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented"); } template <> TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)); @@ -334,7 +327,7 @@ TORCH_CUDA_CU_API void geqrfBatched>( template void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name()); + TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented"); } template<> TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float)); @@ -350,7 +343,7 @@ TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES template void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype),"at::cuda::blas::gelsBatched: not implemented"); } template<> diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h index 8615bcdae911..1696bb3a0f44 100644 --- a/aten/src/ATen/cuda/CUDADataType.h +++ b/aten/src/ATen/cuda/CUDADataType.h @@ -9,7 +9,8 @@ namespace at::cuda { template cudaDataType getCudaDataType() { - TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.") + static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); + return {}; } template<> inline cudaDataType getCudaDataType() { diff --git a/aten/src/ATen/cuda/Sleep.cu b/aten/src/ATen/cuda/Sleep.cu index 4fe857e65c26..586520e25327 100644 --- a/aten/src/ATen/cuda/Sleep.cu +++ b/aten/src/ATen/cuda/Sleep.cu @@ -1,3 +1,4 @@ +#include #include #include @@ -32,4 +33,37 @@ void sleep(int64_t cycles) { C10_CUDA_KERNEL_LAUNCH_CHECK(); } +#ifdef USE_ROCM +__global__ void flush_icache_kernel() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +#endif + +void flush_icache() { +#ifdef USE_ROCM + dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 60); + dim3 block(64); + flush_icache_kernel<<>>(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#endif +} + } // namespace at::cuda diff --git a/aten/src/ATen/cuda/Sleep.h b/aten/src/ATen/cuda/Sleep.h index d31bf68ccafb..ef5e83a832f7 100644 --- a/aten/src/ATen/cuda/Sleep.h +++ b/aten/src/ATen/cuda/Sleep.h @@ -7,4 +7,7 @@ namespace at::cuda { // enqueues a kernel that spins for the specified number of cycles TORCH_CUDA_CU_API void sleep(int64_t cycles); +// flushes instruction cache for ROCm; no-op for CUDA +TORCH_CUDA_CU_API void flush_icache(); + } // namespace at::cuda diff --git a/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh index c9eeeadd542d..231cd167cacb 100644 --- a/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +++ b/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh @@ -34,8 +34,8 @@ struct PhiloxCudaState { int64_t* ptr; }; - Payload seed_; - Payload offset_; + Payload seed_{}; + Payload offset_{}; uint32_t offset_intragraph_ = 0; bool captured_ = false; }; diff --git 
a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index a1d7d0dc2163..64a482bc2781 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -66,7 +66,7 @@ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t siz return false; } else { - TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); } return true; @@ -76,30 +76,55 @@ static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t siz template struct GemmParams : OpParams { + GemmParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; } - GemmParams* DeepCopy() const { + GemmParams* DeepCopy(bool duplicate_inputs) const { GemmParams* copy = new GemmParams; *copy = *this; c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = m * n * sizeof(T); + size_t c_size = ldc * n * sizeof(T); copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(GemmParams *other) { auto c_dtype = c10::CppTypeToScalarType::value; - return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; } char transa; @@ -115,15 +140,98 @@ struct GemmParams : OpParams { at::opmath_type beta; T* c; int64_t ldc; +private: + bool duplicate_inputs_; +}; + +template +struct GemmAndBiasParams : OpParams { + std::string Signature() const override { + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? 
n : k); + } + return size; + } + + GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { + GemmAndBiasParams* copy = new GemmAndBiasParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = ldc * n * sizeof(T); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmAndBiasParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + T* c; + int64_t ldc; + const T* bias; + at::cuda::blas::GEMMAndBiasActivationEpilogue activation; +private: + bool duplicate_inputs_; }; template struct GemmStridedBatchedParams : OpParams { + GemmStridedBatchedParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + return val; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * stride_c * batch; + if (duplicate_inputs) { + size += sizeof(T) * stride_a * batch; + size += sizeof(T) * stride_b * batch; + } + return size; } - GemmStridedBatchedParams* DeepCopy() const { + GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const { GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; *copy = *this; c10::DeviceIndex device = 0; @@ -132,12 +240,23 @@ struct GemmStridedBatchedParams : OpParams { copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * stride_a * batch; + size_t b_size = sizeof(T) * stride_b * batch; + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { @@ -162,33 +281,60 @@ struct GemmStridedBatchedParams : OpParams { int64_t ldc; int64_t stride_c; int64_t batch; +private: + bool duplicate_inputs_; }; template struct ScaledGemmParams 
: OpParams { + ScaledGemmParams() { + duplicate_inputs_ = false; + } + std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); + return val; } - ScaledGemmParams* DeepCopy() const { + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; + } + + ScaledGemmParams* DeepCopy(bool duplicate_inputs) const { ScaledGemmParams* copy = new ScaledGemmParams; *copy = *this; c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = m * n * sizeof(T); + size_t c_size = ldc * n * sizeof(T); copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size); AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size); + copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size); + copy->duplicate_inputs_ = true; + } return copy; } // only call on object returned by DeepCopy void Delete() { c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } } TuningStatus NumericalCheck(ScaledGemmParams *other) { - return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? 
OK : FAIL; } char transa; @@ -212,6 +358,8 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype; void* amax_ptr; bool use_fast_accum; +private: + bool duplicate_inputs_; }; } // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index b26c2415af7b..ab1525bef652 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -25,35 +25,35 @@ namespace at::cuda::tunable { template -constexpr hipblasDatatype_t HipBlasDataTypeFor(); +constexpr hipblasDatatype_t HipDataTypeFor(); template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_32F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_32F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_16F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_16B; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16BF; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { - return HIPBLAS_R_64F; +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_64F; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { +constexpr hipblasDatatype_t HipDataTypeFor() { return HIP_R_8F_E4M3_FNUZ; } template <> -constexpr hipblasDatatype_t HipBlasDataTypeFor() { +constexpr hipblasDatatype_t HipDataTypeFor() { return HIP_R_8F_E5M2_FNUZ; } @@ -62,6 +62,11 @@ int GetBatchFromParams(const GemmParams* params) { return 1; } +template +int GetBatchFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetBatchFromParams(const GemmStridedBatchedParams* params) { return params->batch; @@ -77,6 +82,11 @@ int GetStrideAFromParams(const GemmParams* params) { return 1; } +template +int GetStrideAFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideAFromParams(const GemmStridedBatchedParams* params) { return params->stride_a; @@ -92,6 +102,11 @@ int GetStrideBFromParams(const GemmParams* params) { return 1; } +template +int GetStrideBFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideBFromParams(const GemmStridedBatchedParams* params) { return params->stride_b; @@ -107,6 +122,11 @@ int GetStrideCFromParams(const GemmParams* params) { return 1; } +template +int GetStrideCFromParams(const GemmAndBiasParams* params) { + return 1; +} + template int GetStrideCFromParams(const GemmStridedBatchedParams* params) { return params->stride_c; @@ -122,6 +142,11 @@ float GetAlphaFromParams(const GemmParams* params) { return params->alpha; } +template +float GetAlphaFromParams(const GemmAndBiasParams* params) { + return params->alpha; +} + template float GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; @@ -137,6 +162,11 @@ float GetBetaFromParams(const GemmParams* params) { return params->beta; } +template +float GetBetaFromParams(const GemmAndBiasParams* params) { + return 0.0; +} + template float GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; @@ -152,6 +182,11 @@ const void* GetAScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -167,6 +202,11 @@ const void* 
GetBScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -182,6 +222,11 @@ const void* GetDScalePointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + template const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -197,6 +242,11 @@ const void* GetBiasPointerFromParams(const GemmParams* params) { return nullptr; } +template +const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { + return params->bias; +} + template const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -212,6 +262,11 @@ hipDataType GetBiasTypeFromParams(const GemmParams* params) { return HIP_R_32F; } +template +hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { + return HipDataTypeFor(); +} + template hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { return HIP_R_32F; @@ -222,6 +277,26 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); } +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { + return params->activation; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + static hipblasOperation_t _hipblasOpFromChar(char op) { switch (op) { case 'n': @@ -263,19 +338,19 @@ static size_t GetHipblasltWorkspaceSize() { // 256MB is max workspace size allowed for hipblaslt // hipblaslt-bench uses 32MB // recommendation from hipblaslt author was 76MB - size_t workspace_size = 2*128*1024*1024; // default 256MB + size_t workspace_size = 32*1024; // going with 32MB if (env) { try { workspace_size = std::stoi(env); } catch(std::invalid_argument const& e) { TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", - " using default workspace size of ", workspace_size, " bytes."); + " using default workspace size of ", workspace_size, " KiB."); } catch(std::out_of_range const& e) { TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", - " using default workspace size of ", workspace_size, " bytes."); + " using default workspace size of ", workspace_size, " KiB."); } } - return workspace_size; + return workspace_size * 1024; } template @@ -327,9 +402,9 @@ class HipblasltGemmOp : public Callable { TuningStatus Call(const ParamsT* params) override { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); - auto a_datatype = HipBlasDataTypeFor(); - auto b_datatype = HipBlasDataTypeFor(); - auto in_out_datatype = HipBlasDataTypeFor(); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); auto 
opa = _hipblasOpFromChar(params->transa); auto opb = _hipblasOpFromChar(params->transb); @@ -385,13 +460,22 @@ class HipblasltGemmOp : public Callable { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } - const void* bias_ptr = GetBiasPointerFromParams(params); - auto bias_datatype = GetBiasTypeFromParams(params); - if (bias_ptr) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); + auto activation = GetActivationFromParams(params); + if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); + } + else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); + } + else { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); } } @@ -413,12 +497,10 @@ class HipblasltGemmOp : public Callable { if (status == HIPBLAS_STATUS_SUCCESS) { if (ret_workspace_size >= workspace_size) { - //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large"); return FAIL; } } else { - //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported"); return FAIL; } @@ -462,9 +544,9 @@ template (); - auto b_datatype = HipBlasDataTypeFor(); - auto in_out_datatype = HipBlasDataTypeFor(); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); std::vector heuristic_result; hipblasLtHandle_t handle; @@ -507,6 +589,11 @@ auto GetHipBlasLtGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } +template +auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + template auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md index 364e6975c6c6..e17ff71f3004 100644 --- a/aten/src/ATen/cuda/tunable/README.md +++ b/aten/src/ATen/cuda/tunable/README.md @@ -2,67 +2,30 @@ This directory implements a TunableOp interface. -Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For -example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's -rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one -know which implementation is the fastest and should be chosen? That's what TunableOp provides. - -The behavior of TunableOp is currently easily manipulated through environment variables, though you could use the C++ -interface of at::cuda::tunable::getTuningContext(). A Python interface to the TuningContext does not yet exist. - -Currently only a TunableGemm for ROCm is implemented. Any call to at::cuda::blas::gemm() can optionally use the -TunableGemm. 
Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the -fastest available implementation. - -## Environment Variables - -#### PYTORCH_TUNABLEOP_ENABLED -Default is 0. Set to 1 to enable. -This is the big on/off switch for all TunableOp implementations. - -#### PYTORCH_TUNABLEOP_TUNING -Default is 1. Set to 0 to disable. -When enabled, if a tuned entry isn't found, run the tuning step and record the entry. - -#### PYTORCH_TUNABLEOP_VERBOSE -Default is 0. Set to 1 to enable. -This will produce a lot of diagnostic messages but may be useful to see if TunableOp is being used at all. -Otherwise, TunableOp is completely silent unless there is a warning or error during its use. - -#### PYTORCH_TUNABLEOP_FILENAME -Default is 'tunableop_results.csv'. If you provide a filename, the TuningContext will attempt to read it the first time -the context is used. If tuning is enabled and new tunings are discovered, it will also write out to this same filename -with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can be used, for -example, to build up a tunings file across many workloads by reusing the same file. Unsetting this variable is not -recommended but can be done, in which case the tuning results will not be saved. - -#### PYTORCH_TUNABLEOP_NUMERICAL_CHECK -Default is 1. Set to 0 to disable. Compare the results of each possible solution against the default solution and reject -those with low accuracy. - -#### PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED -Default is 1. Set to 0 to disable hipblaslt being considered during tuning. - -### Tuning Iterations -By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations can -be run within 30ms, whichever is smaller. Its average execution will be calculated. The fastest solution is chosen. In -addition, a set of warm up iterations can optionally be run prior to the timed iterations. The following environment -variables can be used to set either the maximum number of iterations to attempt or the maximum amount of time allowed in -milliseconds, or both, in which case the smaller of the two values used. - -#### PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS -Default is 30. - -#### PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS -Default is 100. - -#### PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS -Default is 0, meaning it is not used. - -#### PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS -Default is 1. - -## File Output +Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For +example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's +rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one +know which implementation is the fastest and should be chosen? That's what TunableOp provides. + +## Enabling TunableOp and Tuning Separately +The TunableOp feature is enabled separately from enabling the tuning phase itself. Enabling TunableOp means that PyTorch +will replace any standard operators with their Tunable implementations. Any call to a TunableOp first checks whether it +has already been tuned for the given operator inputs. If so, it will immediately call the tuned operation; no further +tuning will take place even when the tuning setting is enabled. 
Instead, if no tuning result is found and tuning is +enabled, the TunableOp will benchmark every registered implementation of that operator for the given set of inputs and +select the fastest. + +## File Input and Output +The first time any TunableOp is invoked, the internal database of tuned operations will be prepared by attempting to +read the results from the given file. The default filename is 'tunableop_results.csv'. To support tuning when multiple +GPUs are used across multiple processes, the GPU device ordinal is automatically inserted into the filename to avoid +multiple processes overwriting the same file. + +If tuning is enabled and new tunings are discovered during the course of your workload, it will also write out to this +same filename with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can +be used, for example, to build up a tunings file across many workloads by reusing the same file. The output file is +automatically created when the application terminates. This behavior can be controlled by the C++ and Python APIs but +not the environment variables. Assuming you specified a filename, you'll end up with a CSV file with contents like so: @@ -75,8 +38,8 @@ GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 ``` -Note the "Validator" lines. If you change a library verison, or rocm version, or pytorch version, TunableOp will detect -this and not load the tunings because they are likely affected by other software changes. +Note the "Validator" lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect +this and reject the tunings file because the prior tunings are likely affected by other software changes. The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of 4 comma-separated fields: operator name, operator parameters, solution name, and average execution time. The execution @@ -86,3 +49,102 @@ hipBLAS or hipBLASLt libraries, if you know the specific solution index you can selected by replacing the value. The operator name and parameters (fields 1 and 2) are internally named and should not be modified. In the case of GemmTunableOp, field 1 indicates the datatype and whether the inputs are transposed (T) or not (N) and field 2 indicates the M, N, K input shapes. + +There is an option to enable verbose output but it is only recommended for debugging purposes. This will produce a lot +of diagnostic messages but may be useful to see if TunableOp is being used at all. Otherwise, TunableOp is completely +silent, besides file output, unless there is a warning or error during its use. + +## A Note on Tuning Behavior, Warmup, and Cache Effects +Tuning an operator consists of iterating through the list of registered implementations and profiling each one. The +profile is established by running a single implementation in a loop multiple times and taking the average execution +time. There is also an optional warmup phase prior to tuning that can help with reaching stable power states of the +hardware. During tuning of a workload, the various hardware caches are more likely to produce hits than when not tuning. +There are options for flushing the instruction cache and rotating the input tensors, which might help produce a more +faithful profile of the tuned operator as if the operator were run within a larger workload instead of in a tight, +repetitive loop.
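To make the behavior described in this section concrete, here is a minimal sketch of a tuning run configured purely through the environment variables and `torch.cuda.tunable` calls documented in the tables below. It assumes a ROCm build of PyTorch with at least one GPU; the tensor shapes, dtype, rotating-buffer size, and output filename are arbitrary examples chosen for illustration.

```python
# Minimal sketch of a TunableOp tuning run driven by environment variables.
# Assumes a ROCm build of PyTorch and at least one GPU; shapes, dtype,
# buffer size, and filename are arbitrary examples.
import os

# Environment variables are cached on first read, so set them before the
# first TunableOp call (here: before importing torch, to be safe).
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "256"  # MiB
os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "my_tunings.csv"   # device ordinal is inserted, e.g. my_tunings0.csv

import torch
from torch.cuda import tunable

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)
c = a @ b  # first GEMM with this (dtype, layout, m, n, k) triggers tuning

print(tunable.is_enabled())   # True
print(tunable.get_results())  # (op name, params, solution, avg time) entries
tunable.write_file()          # optional; the file is also written at exit
```

The same settings (enabling, filenames, tuning and warmup budgets) can also be adjusted at runtime through the C++ `TuningContext` and Python interfaces listed below.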
+ +By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations as +can be run within 30ms, whichever is smaller, and its average execution time will be calculated. The fastest solution among +all that were successfully profiled will be chosen. A profile might fail if the given solution doesn't achieve the same +accuracy as the default implementation or if the solution returns an error code. + +## Current Tunable Operators + +### TunableGemm for ROCm +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of PyTorch will function correctly when +using TunableOp but the only solution available to CUDA builds is the 'Default' implementation, i.e. the original cuBLAS +default, now called through TunableOp. Any call to at::cuda::blas::gemm() or ::bgemm() will be routed through TunableOp +when enabled. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt. + +## Tuning Context +The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of +at::cuda::tunable::getTuningContext(), or the `torch.cuda.tunable` Python interfaces. The environment variables take +precedence over any setting you manipulate using the C++ or Python APIs. + +### Environment Variable Interface +Environment variables are cached the first time they are read. You cannot use the environment variable interface +programmatically since the settings become fixed. Use the C++ or Python APIs instead. + +| Environment Variable | Description | +| -------------------- | ----------- | +| PYTORCH_TUNABLEOP_ENABLED | Default is 0. Set to 1 to enable. | +| PYTORCH_TUNABLEOP_TUNING | Default is 1. Set to 0 to disable. | +| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. | +| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. | +| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. | +| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. | +| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. | +| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. | +| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. | +| PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS | Default is 100. | +| PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS | Default is 0, meaning it is not used. Unit is milliseconds. | +| PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS | Default is 0, meaning it is not used. | +| PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED | Default is 1. Set to 0 to disable. | +| PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE | Default is to query L2 cache size. Set to 0 to disable. Otherwise, set to the number of MiB to use for the pool of operator parameters. For example, setting this to the size of your device's memory cache will guarantee that every tuning iteration will use a cold cache. | + +### Python Interface +All Python APIs exist in the `torch.cuda.tunable` module. + +| Python API | Description | +| ---------- | ----------- | +| enable(val: bool = True) -> None | | +| is_enabled() -> bool | | +| tuning_enable(val: bool = True) -> None | Default is True.
| +| tuning_is_enabled() -> bool | | +| set_max_tuning_duration(duration: int) -> None | | +| get_max_tuning_duration() -> int | | +| set_max_tuning_iterations(iterations: int) -> None | | +| get_max_tuning_iterations() -> int | | +| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | | +| get_filename() -> str | | +| get_results() -> Tuple[str, str, str, float] | | +| get_validators() -> Tuple[str, str] | | +| write_file_on_exit(val: bool) -> None | Default is True. | +| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | +| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | + +### C++ Interface +Example: +```C++ +#include + +at::cuda::tunable::getTuningContext()->EnableTunableOp(true); +``` + +| C++ API | Description | +| ------- | ----------- | +| void EnableTunableOp(bool value); | | +| bool IsTunableOpEnabled() const; | | +| void EnableTuning(bool value); | | +| bool IsTuningEnabled() const; | | +| void SetMaxTuningDurationMs(int max_duration_ms); | | +| int GetMaxTuningDurationMs() const; | | +| void SetMaxTuningIterations(int max_iter); | | +| int GetMaxTuningIterations() const; | | +| TuningResults GetTuningResults(); | | +| void SetFilename(const std::string& filename, bool insert_device_ordinal=false); | | +| std::string GetFilename() const; | | +| void WriteFileOnExit(bool value); | | +| bool ReadFile(const std::string& filename={}); | | +| bool WriteFile(const std::string& filename={}); | | diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 22bde7f4c427..d3d2333323e7 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -65,14 +65,14 @@ ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::scoped_lock l{lock_}; auto kernel_map_it = results_.find(op_signature); if (kernel_map_it == results_.cend()) { - TUNABLE_LOG("missing op_signature, returning null ResultEntry"); + TUNABLE_LOG3("missing op_signature, returning null ResultEntry"); return ResultEntry::Null(); } const auto& km = kernel_map_it->second; auto it = km.find(params_signature); if (it == km.cend()) { - TUNABLE_LOG("missing params_signature, returning null ResultEntry"); + TUNABLE_LOG3("missing params_signature, returning null ResultEntry"); return ResultEntry::Null(); } return it->second; @@ -85,14 +85,14 @@ inline void TuningResultsManager::AddImpl(const std::string& op_signature, auto it = kernel_map.find(params_signature); if (it != kernel_map.end()) { if (it->second != best) { - TUNABLE_LOG(op_signature, "(", params_signature, ") already has a best kernel ", + TUNABLE_LOG1(op_signature, "(", params_signature, ") already has a best kernel ", "id=", it->second, " selected, want to add a different best kernel ", best, ", the new kernel id will be ignored."); } return; } - TUNABLE_LOG(op_signature, "(", params_signature, ") -> ", best); + TUNABLE_LOG2(op_signature, "(", params_signature, ") -> ", best); kernel_map.emplace(params_signature, best); } @@ -120,7 +120,7 @@ void TuningResultsManager::Delete(const std::string& op_signature, const std::st return; } - TUNABLE_LOG(op_signature, "(", params_signature, ")"); + TUNABLE_LOG2(op_signature, "(", params_signature, ")"); it->second.erase(it2); } @@ -131,7 +131,7 @@ inline void TuningResultsManager::DisjointMergeImpl( auto it = results.find(op_signature); if (it == results.end()) { for (const auto& [param_sig, 
kernel_id] : kernel_map) { - TUNABLE_LOG(op_signature, "(", param_sig, ") -> ", kernel_id); + TUNABLE_LOG2(op_signature, "(", param_sig, ") -> ", kernel_id); } results[op_signature] = kernel_map; return; @@ -143,7 +143,7 @@ inline void TuningResultsManager::DisjointMergeImpl( } void TuningResultsManager::Load(const std::unordered_map& results_to_load) { - TUNABLE_LOG("Loading results"); + TUNABLE_LOG1("Loading results"); std::scoped_lock l{lock_}; for (const auto& [op_signature, kernel_map] : results_to_load) { DisjointMergeImpl(op_signature, kernel_map, results_); @@ -194,12 +194,12 @@ static bool CheckMandatoryKeys( for (const auto& k : TuningResultsValidator::mandatory_keys) { if (gv_funcs.find(k) == gv_funcs.end()) { passed = false; - TUNABLE_LOG("key=\"", k, "\" is not registered for Get and Validate. "); + TUNABLE_LOG1("key=\"", k, "\" is not registered for Get and Validate. "); } if (to_check.find(k) == to_check.end()) { passed = false; - TUNABLE_LOG("key=\"", k, "\" is not provided for validation. "); + TUNABLE_LOG1("key=\"", k, "\" is not provided for validation. "); } } return passed; @@ -294,10 +294,14 @@ TuningContext::TuningContext() : enable_{false}, tuning_enable_{true}, manager_initialized_{false}, + write_file_on_exit_{true}, + numerics_check_enable_{false}, max_tuning_duration_ms_{30}, max_tuning_iterations_{100}, max_warmup_duration_ms_{0}, max_warmup_iterations_{0}, + icache_flush_{true}, + rotating_buffer_size_{-1}, filename_{}, results_count_from_input_file_{0} { @@ -311,115 +315,158 @@ TuningContext::~TuningContext() { return; } auto filename = GetFilename(); - if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty()) { + if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) { if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) { if (results_count_from_input_file_ > 0) { - TUNABLE_LOG("additional tuning results available, rewriting file ", filename); + TUNABLE_LOG1("additional tuning results available, rewriting file ", filename); } else { - TUNABLE_LOG("writing file ", filename); + TUNABLE_LOG1("writing file ", filename); } if (!WriteFile(filename)) { - TUNABLE_LOG("failed to write file ", filename); + TUNABLE_LOG1("failed to write file ", filename); } } } } -void TuningContext::EnableTunableOp() { - TUNABLE_LOG("Enable TunableOp"); - enable_ = true; -} - -void TuningContext::DisableTunableOp() { - TUNABLE_LOG("Disable TunableOp"); - enable_ = false; +void TuningContext::EnableTunableOp(bool value) { + enable_ = value; + if (value) { + TUNABLE_LOG1("Enable TunableOp"); + } + else { + TUNABLE_LOG1("Disable TunableOp"); + } } bool TuningContext::IsTunableOpEnabled() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED"); if (env != nullptr && strcmp(env, "1") == 0) { - //TUNABLE_LOG("PYTORCH_TUNABLEOP_ENABLED=1"); return true; } return enable_; } -void TuningContext::EnableTuning() { - TUNABLE_LOG("Enable Tuning for TunableOp"); - tuning_enable_ = true; -} - -void TuningContext::DisableTuning() { - TUNABLE_LOG("Disable Tuning for TunableOp"); - tuning_enable_ = false; +void TuningContext::EnableTuning(bool value) { + tuning_enable_ = value; + if (value) { + TUNABLE_LOG1("Enable Tuning for TunableOp"); + } + else { + TUNABLE_LOG1("Disable Tuning for TunableOp"); + } } bool TuningContext::IsTuningEnabled() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING"); if (env != nullptr && strcmp(env, "0") == 0) { - //TUNABLE_LOG("PYTORCH_TUNABLEOP_TUNING=1"); return 
false; } return tuning_enable_; } +void TuningContext::WriteFileOnExit(bool value) { + write_file_on_exit_ = value; +} + +void TuningContext::EnableNumericsCheck(bool value) { + numerics_check_enable_ = value; +} + +bool TuningContext::IsNumericsCheckEnabled() const { + static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); + if (env != nullptr && strcmp(env, "1") == 0) { + return true; + } + return numerics_check_enable_; +} + void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) { - max_tuning_duration_ms_ = max_duration_ms; + max_tuning_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; } int TuningContext::GetMaxTuningDurationMs() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_tuning_duration_ms_; } void TuningContext::SetMaxTuningIterations(int max_iter) { - max_tuning_iterations_ = max_iter; + max_tuning_iterations_ = max_iter < 0 ? 0 : max_iter; } int TuningContext::GetMaxTuningIterations() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_tuning_iterations_; } void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) { - max_warmup_duration_ms_ = max_duration_ms; + max_warmup_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms; } int TuningContext::GetMaxWarmupDurationMs() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_warmup_duration_ms_; } void TuningContext::SetMaxWarmupIterations(int max_iter) { - max_warmup_iterations_ = max_iter; + max_warmup_iterations_ = max_iter < 0 ? 0 : max_iter; } int TuningContext::GetMaxWarmupIterations() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS"); if (env != nullptr) { - return atoi(env); + int val = atoi(env); + return val < 0 ? 0 : val; } return max_warmup_iterations_; } -void TuningContext::EnableTunableOpAndTuning() { - EnableTunableOp(); - EnableTuning(); +void TuningContext::EnableICacheFlush(bool value) { + icache_flush_ = value; } -void TuningContext::DisableTunableOpAndTuning() { - DisableTunableOp(); - DisableTuning(); +bool TuningContext::IsICacheFlushEnabled() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED"); + if (env != nullptr && strcmp(env, "0") == 0) { + return false; + } + return icache_flush_; +} + +void TuningContext::SetRotatingBufferSize(int size) { + rotating_buffer_size_ = size < 0 ? 0 : size; +} + +int TuningContext::GetRotatingBufferSize() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"); + if (env != nullptr) { + constexpr int MB = 1024 * 1024; + int val = atoi(env); + return val < 0 ? 
0 : val * MB; // env var is specified as MB, returned as bytes + } + else { + if (rotating_buffer_size_ < 0) { + // negative buffer size (default) means query for L2 cache size + int l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; + return l2_cache_size; + } + else { + return rotating_buffer_size_; + } + } } TuningResultsManager& TuningContext::GetTuningResultsManager() { @@ -429,7 +476,7 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() { // if SetFilename() was not already called, call it now with the default or env var const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME"); std::string filename = (env == nullptr) ? "tunableop_results.csv" : env; - SetFilename(filename); + SetFilename(filename, true); } auto filename = GetFilename(); if (!filename.empty()) { @@ -461,32 +508,34 @@ TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) { return OK; } -void TuningContext::SetFilename(const std::string& filename) { +void TuningContext::SetFilename(const std::string& filename, bool insert_device_ordinal) { filename_ = filename; if (filename_.empty()) { return; } - // differentiate filename based on device ordinal to avoid - // use case of one process per device writing to same file - std::string device = c10::str(int(c10::cuda::current_device())); + if (insert_device_ordinal) { + // differentiate filename based on device ordinal to avoid + // use case of one process per device writing to same file + std::string device = c10::str(int(c10::cuda::current_device())); - // does filename contain %d to insert device ordinal in specific location? - const std::string TOKEN("%d"); - std::size_t found = filename_.find(TOKEN); - if (found != std::string::npos) { - filename_.replace(found, TOKEN.length(), device); - } - else { - // no %d present, so append device ordinal before final '.' - found = filename_.rfind("."); + // does filename contain %d to insert device ordinal in specific location? + const std::string TOKEN("%d"); + std::size_t found = filename_.find(TOKEN); if (found != std::string::npos) { - filename_.insert(found, device); + filename_.replace(found, TOKEN.length(), device); } else { - // all else fails, just append - filename_.append(device); + // no %d present, so append device ordinal before final '.' + found = filename_.rfind("."); + if (found != std::string::npos) { + filename_.insert(found, device); + } + else { + // all else fails, just append + filename_.append(device); + } } } } @@ -495,14 +544,15 @@ std::string TuningContext::GetFilename() const { return filename_; } -bool TuningContext::ReadFile(const std::string& filename) { - TUNABLE_LOG("reading tuning results from ", filename); +bool TuningContext::ReadFile(const std::string& filename_) { + std::string filename = filename_.empty() ? 
GetFilename() : filename_; + TUNABLE_LOG1("reading tuning results from ", filename); ResultsMap results; std::unordered_map validators; std::string line; std::ifstream file(filename); if (!file) { - TUNABLE_LOG("could not open ", filename, " for reading tuning results"); + TUNABLE_LOG1("could not open ", filename, " for reading tuning results"); return false; } while (std::getline(file, line)) { @@ -517,7 +567,7 @@ bool TuningContext::ReadFile(const std::string& filename) { } if (parts[0] == "Validator" && parts.size() >= 3) { validators[parts[1]] = parts[2]; - TUNABLE_LOG("Validator ", parts[1], "=", parts[2]); + TUNABLE_LOG1("Validator ", parts[1], "=", parts[2]); } else if (parts.size() >= 4) { results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str()))); @@ -527,7 +577,7 @@ bool TuningContext::ReadFile(const std::string& filename) { results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0)); } else { - TUNABLE_LOG("could not parse line: ", line); + TUNABLE_LOG1("could not parse line: ", line); } } if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) { @@ -535,16 +585,17 @@ bool TuningContext::ReadFile(const std::string& filename) { results_count_from_input_file_ = manager_.GetSize(); } else { - TUNABLE_LOG("results validator check failed"); + TUNABLE_LOG1("results validator check failed"); return false; } return true; } -bool TuningContext::WriteFile(const std::string& filename) { +bool TuningContext::WriteFile(const std::string& filename_) { + std::string filename = filename_.empty() ? GetFilename() : filename_; std::ofstream file(filename, std::ios::out | std::ios::trunc); if (!file.good()) { - TUNABLE_LOG("error opening tuning results file for writing ", filename); + TUNABLE_LOG1("error opening tuning results file for writing ", filename); return false; } auto validators = GetTuningResultsValidator().GetAllValidators(); diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index eb849a213fe5..243031cf3da2 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -23,27 +24,58 @@ namespace at::cuda::tunable { -static void TunableLog(const std::string& msg) { - static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE"); - if (env != nullptr && strcmp(env, "1") == 0) { - std::cerr << msg << std::endl; +namespace detail { + +struct MaybeDelete { + bool owns_pointer; + void operator()(std::ostream* os) const { if (owns_pointer) delete os; } +}; + +using OstreamPtr = std::unique_ptr; + +static OstreamPtr get_stream(std::string filename) { + if (filename.compare("out") == 0) { + return OstreamPtr { &std::cout, MaybeDelete {false} }; } + else if (filename.compare("err") == 0) { + return OstreamPtr { &std::cerr, MaybeDelete {false} }; + } + else { + return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} }; + } +} + } -#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__)) -enum TuningStatus { +static void TunableLog(int level, const std::string& msg) { + static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME"); + static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE"); + static int level_user = env_verbose ? atoi(env_verbose) : 0; + static auto streamptr = detail::get_stream(env_file ? 
env_file : "err"); + if (level_user >= level) { + (*streamptr) << msg < KernelMap; typedef std::unordered_map ResultsMap; -struct TuningResults { +struct TORCH_CUDA_CPP_API TuningResults { // Validates if these results are compatible with the libraries std::unordered_map validators; @@ -64,7 +96,7 @@ struct TuningResults { ResultsMap results; }; -class TuningResultsManager { +class TORCH_CUDA_CPP_API TuningResultsManager { public: TuningResultsManager() = default; ~TuningResultsManager() = default; @@ -102,7 +134,7 @@ class TuningResultsManager { ResultsMap results_; }; -class TuningResultsValidator { +class TORCH_CUDA_CPP_API TuningResultsValidator { public: using GetFunc = std::function; using ValidateFunc = std::function; @@ -126,7 +158,7 @@ class TuningResultsValidator { GetValidateFuncs validators_; }; -class TuningContext { +class TORCH_CUDA_CPP_API TuningContext { public: TuningContext(); ~TuningContext(); @@ -135,14 +167,15 @@ class TuningContext { TuningContext &operator=(TuningContext &) = delete; TuningContext &operator=(TuningContext &&) = delete; - void EnableTunableOp(); - void DisableTunableOp(); + void EnableTunableOp(bool value); bool IsTunableOpEnabled() const; - void EnableTuning(); - void DisableTuning(); + void EnableTuning(bool value); bool IsTuningEnabled() const; + void EnableNumericsCheck(bool value); + bool IsNumericsCheckEnabled() const; + void SetMaxTuningDurationMs(int max_duration_ms); int GetMaxTuningDurationMs() const; @@ -155,8 +188,11 @@ class TuningContext { void SetMaxWarmupIterations(int max_iter); int GetMaxWarmupIterations() const; - void EnableTunableOpAndTuning(); - void DisableTunableOpAndTuning(); + void EnableICacheFlush(bool value); + bool IsICacheFlushEnabled() const; + + void SetRotatingBufferSize(int size); + int GetRotatingBufferSize() const; TuningResultsManager& GetTuningResultsManager(); @@ -166,21 +202,26 @@ class TuningContext { TuningStatus LoadTuningResults(const TuningResults& tr); - void SetFilename(const std::string& filename); + void SetFilename(const std::string& filename, bool insert_device_ordinal=false); std::string GetFilename() const; - protected: - bool ReadFile(const std::string& filename); - bool WriteFile(const std::string& filename); + void WriteFileOnExit(bool value); + + bool ReadFile(const std::string& filename={}); + bool WriteFile(const std::string& filename={}); private: bool enable_; bool tuning_enable_; bool manager_initialized_; + bool write_file_on_exit_; + bool numerics_check_enable_; int max_tuning_duration_ms_; int max_tuning_iterations_; int max_warmup_duration_ms_; int max_warmup_iterations_; + bool icache_flush_; + int rotating_buffer_size_; mutable TuningResultsManager manager_; mutable c10::once_flag manager_init_once_; TuningResultsValidator validator_; @@ -188,7 +229,7 @@ class TuningContext { size_t results_count_from_input_file_; }; -TuningContext* getTuningContext(); +TORCH_CUDA_CPP_API TuningContext* getTuningContext(); class ITimer { public: diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 1eaf251caad7..6b02e26ade4d 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -48,6 +48,28 @@ class DefaultGemmOp : public Callable> { } }; +static bool _transposeBoolFromChar(char op) { + return op == 't' || op == 'T'; +} + +template +class DefaultGemmAndBiasOp : public Callable> { + public: + TuningStatus Call(const GemmAndBiasParams* params) override { + at::cuda::blas::gemm_and_bias( + 
_transposeBoolFromChar(params->transa), + _transposeBoolFromChar(params->transb), + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->bias, + params->c, params->ldc, + params->activation); + return OK; + } +}; + template class DefaultGemmStridedBatchedOp : public Callable> { public: @@ -175,6 +197,56 @@ inline std::string TypeName(c10::complex v) { return "c10::complex"; } +#ifdef USE_ROCM +static void AddRocblasValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("ROCBLAS_VERSION") == validators.end()) { + std::string rocblas_version = c10::str( + XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", + XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCBLAS_VERSION", + [rocblas_version]() { return rocblas_version; }, + [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + } +} + +static void AddHipblasltValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("HIPBLASLT_VERSION") == validators.end()) { + std::string hipblaslt_version = c10::str( + XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", + XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "HIPBLASLT_VERSION", + [hipblaslt_version]() { return hipblaslt_version; }, + [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + } +} + +static void AddRocmValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("ROCM_VERSION") == validators.end()) { + std::string rocm_version = ROCM_BUILD_INFO; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCM_VERSION", + [rocm_version]() { return rocm_version; }, + [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + } + + if (validators.find("GCN_ARCH_NAME") == validators.end()) { + std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "GCN_ARCH_NAME", + [gcn_arch_name]() { return gcn_arch_name; }, + [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + } +} +#endif template class GemmTunableOp : public TunableOp, StreamTimer> { @@ -182,71 +254,78 @@ class GemmTunableOp : public TunableOp, StreamTimer> { GemmTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - #ifdef USE_ROCM - for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } + bool rocm_validators = false; - if (validators.find("ROCM_VERSION") == validators.end()) { - std::string rocm_version = ROCM_BUILD_INFO; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCM_VERSION", - [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? 
OK : FAIL; }); + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + rocm_validators = true; + for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + AddRocblasValidator(); } - if (validators.find("GCN_ARCH_NAME") == validators.end()) { - std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "GCN_ARCH_NAME", - [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + AddHipblasltValidator(); } - if (validators.find("ROCBLAS_VERSION") == validators.end()) { - std::string rocblas_version = c10::str( - XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", - XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCBLAS_VERSION", - [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + if (rocm_validators) { + AddRocmValidator(); } #endif + } + + std::string Signature() override { + static std::string val = c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; + } +}; + +template +class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { + public: + GemmAndBiasTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); #if defined(USE_ROCM) - static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env == nullptr || strcmp(env, "1") == 0) { + bool rocm_validators = false; + + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && !std::is_same_v>) { - for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } } + AddHipblasltValidator(); + } - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? 
OK : FAIL; }); - } + if (rocm_validators) { + AddRocmValidator(); } #endif } std::string Signature() override { - return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + static std::string val = c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; @@ -256,45 +335,21 @@ class GemmStridedBatchedTunableOp : public TunableOp GemmStridedBatchedTunableOp() { this->RegisterOp(std::string("Default"), std::make_unique>()); - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - #ifdef USE_ROCM - for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } - - if (validators.find("ROCM_VERSION") == validators.end()) { - std::string rocm_version = ROCM_BUILD_INFO; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCM_VERSION", - [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); - } - - if (validators.find("GCN_ARCH_NAME") == validators.end()) { - std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "GCN_ARCH_NAME", - [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); - } + bool rocm_validators = false; - if (validators.find("ROCBLAS_VERSION") == validators.end()) { - std::string rocblas_version = c10::str( - XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", - XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", - XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "ROCBLAS_VERSION", - [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + rocm_validators = true; + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + AddRocblasValidator(); } -#endif -#if defined(USE_ROCM) - static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env == nullptr || strcmp(env, "1") == 0) { + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -303,24 +358,18 @@ class GemmStridedBatchedTunableOp : public TunableOp this->RegisterOp(std::move(name), std::move(op)); } } + AddHipblasltValidator(); + } - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? 
OK : FAIL; }); - } + if (rocm_validators) { + AddRocmValidator(); } #endif } std::string Signature() override { - return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + static std::string val = c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; @@ -336,27 +385,18 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } - - if (validators.find("HIPBLASLT_VERSION") == validators.end()) { - std::string hipblaslt_version = c10::str( - XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", - XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", - XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); - getTuningContext()->GetTuningResultsValidator().RegisterValidator( - "HIPBLASLT_VERSION", - [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); - } + AddHipblasltValidator(); + AddRocmValidator(); #endif } std::string Signature() override { - return c10::str("ScaledGemmTunableOp", + static std::string val = c10::str("ScaledGemmTunableOp", "_", TypeName(AT{}), "_", TypeName(BT{}), "_", TypeName(CT{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return val; } }; diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h index 65257974ab0c..f158e11cef0a 100644 --- a/aten/src/ATen/cuda/tunable/TunableOp.h +++ b/aten/src/ATen/cuda/tunable/TunableOp.h @@ -10,6 +10,7 @@ #pragma once #include +#include #include #ifndef _WIN32 @@ -62,7 +63,7 @@ class TunableOp { result = ResultEntry::Default(); } if (result == ResultEntry::Null()) { - TUNABLE_LOG("no result, using default"); + TUNABLE_LOG2("no result, using default"); result = ResultEntry::Default(); } auto iter = ops_.find(result); @@ -87,88 +88,120 @@ class TunableOp { } private: - static void WarmUp(Callable *op, ParamsT* param, size_t num_iter) { + static void WarmUp(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); for (size_t i = 0; i < num_iter; i++) { - TORCH_CHECK(op->Call(param) == OK); + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); } } - static double Profile(Callable *op, ParamsT* param, size_t num_iter) { + static double Profile(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); TimerT timer{}; timer.Start(); for (size_t i = 0; i < num_iter; i++) { - TORCH_CHECK(op->Call(param) == OK); + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); } timer.End(); return timer.Duration() / num_iter; } protected: - bool IsNumericsCheckEnabled() { - static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); - if (env != nullptr && strcmp(env, "0") == 0) { - return false; - } - return true; - } - virtual ResultEntry FindFastest(const ParamsT* params) { TuningContext* ctx = getTuningContext(); auto op_sig = Signature(); auto params_sig = params->Signature(); - TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + TUNABLE_LOG2("finding fastest for 
", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); auto min_duration_ms = std::numeric_limits::infinity(); std::string id_name = "Default"; + ParamsT* reference_params = nullptr; // calcaulte a reference answer for numerical check - ParamsT* reference_params = params->DeepCopy(); - TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + if (ctx->IsNumericsCheckEnabled()) { + reference_params = params->DeepCopy(false); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + } + + // need copies of params to reuse + // make as many copies as will fill the requested rotating buffer size, if requested + // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int + size_t rotating_size = ctx->GetRotatingBufferSize(); + bool use_buffer_rotation = (rotating_size > 0); + size_t param_size = params->GetSize(use_buffer_rotation); + size_t param_count = (rotating_size / param_size) + 1; + constexpr size_t MB = 1024*1024; + if (use_buffer_rotation) { + TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", + "Needed Size: ", param_size/MB, " MiB. ", + "Needed number of param copies: ", param_count); + } + TORCH_CHECK(param_count > 0); + + std::vector reusable_params(param_count); + for (size_t i = 0; i < param_count; i++) { + reusable_params[i] = params->DeepCopy(use_buffer_rotation); + } - // need a copy of params to reuse - ParamsT* reusable_params = params->DeepCopy(); + // for rotating buffer + size_t offset = 0; for (size_t i = 0; i < op_names_.size(); i++) { auto* candidate = ops_[op_names_[i]].get(); // borrow pointer - auto status = candidate->Call(reusable_params); - if (status != OK) { - TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); - continue; - } - if (IsNumericsCheckEnabled()) { - ParamsT* numerical_params = params->DeepCopy(); - WarmUp(candidate, numerical_params, 1); + if (ctx->IsNumericsCheckEnabled()) { + ParamsT* numerical_params = params->DeepCopy(false); + auto status = candidate->Call(numerical_params); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } status = reference_params->NumericalCheck(numerical_params); numerical_params->Delete(); if (status != OK) { - TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + else { + auto status = candidate->Call(reusable_params[0]); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); continue; } } // collect a small profile constexpr const int approx_num_iter = 3; - auto approx_duration = Profile(candidate, reusable_params, approx_num_iter); + auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset); // bail if too slow if (approx_duration > 2 * min_duration_ms) { - TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); continue; } // for warmup does user set max duration, max iters, or both? 
+ // warmup is allowed to be skipped by setting either iterations or duration to 0 double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); int max_warmup_iter = ctx->GetMaxWarmupIterations(); int warmup_iter = 1; // default - if (max_warmup_duration > 0) { + if (max_warmup_duration >= 0) { int duration_iters = max_warmup_duration / approx_duration; - if (max_warmup_iter > 0) { + if (max_warmup_iter >= 0) { warmup_iter = std::min(max_warmup_iter, duration_iters); } else { warmup_iter = duration_iters; } } - else if (max_warmup_iter > 0) { + else if (max_warmup_iter >= 0) { warmup_iter = max_warmup_iter; } @@ -188,27 +221,34 @@ class TunableOp { else if (max_tuning_iter > 0) { tuning_iter = max_tuning_iter; } + // tuning must run at least 1 iteration + tuning_iter = std::max(1, tuning_iter); // do the full warmup followed by tuning double warmup_ms = warmup_iter * approx_duration; double tuning_ms = tuning_iter * approx_duration; - TUNABLE_LOG("├──tuning using " + TUNABLE_LOG3("├──tuning using " "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); - WarmUp(candidate, reusable_params, warmup_iter); - auto duration_ms = Profile(candidate, reusable_params, tuning_iter); + TUNABLE_LOG3("├──offset at ", offset); + WarmUp(candidate, reusable_params, warmup_iter, offset); + auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset); if (duration_ms < min_duration_ms) { - TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); + TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); min_duration_ms = duration_ms; id_name = op_names_[i]; } } - reusable_params->Delete(); - reference_params->Delete(); + for (size_t i = 0; i < reusable_params.size(); i++) { + reusable_params[i]->Delete(); + } + if (reference_params) { + reference_params->Delete(); + } - TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); return ResultEntry(id_name, min_duration_ms); } diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index dea35e671267..814b5aeb72d8 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -57,6 +57,9 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { virtual size_t getDriverAllocatedMemory() const { FAIL_MPSHOOKS_FUNC(__func__); } + virtual size_t getRecommendedMaxMemory() const { + FAIL_MPSHOOKS_FUNC(__func__); + } virtual void setMemoryFraction(double /*ratio*/) const { FAIL_MPSHOOKS_FUNC(__func__); } diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 3e064d6c39dc..a0007aa18a00 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -324,6 +324,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); + OP_DECOMPOSE(alias_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(swapdims_); diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index d8213a1b9e0d..85210d0b214c 100644 --- 
a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -59,13 +59,6 @@ view_as_complex_batch_rule(const Tensor& self, optional self_bdim) { return std::make_tuple(result, 0); } -std::tuple> -to_other_batch_rule(const Tensor& self, optional self_bdim, - const Tensor& other, optional other_bdim, - bool non_blocking, - bool copy, std::optional memory_format) { - return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim); -} } TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { diff --git a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp index ce3f20ef97ef..e9e7b2a99553 100644 --- a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp +++ b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp @@ -31,46 +31,6 @@ Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, in return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad); } -static optional> unwrap(const Tensor& tensor) { - auto* wrapped = maybeGetTensorWrapper(tensor); - if (wrapped) { - if (wrapped->level().has_value()) { - return std::make_tuple(wrapped->value(), *wrapped->level()); - } - return unwrap(wrapped->value()); - } - auto* batched = maybeGetBatchedImpl(tensor); - if (batched) { - return std::make_tuple(batched->value(), batched->level()); - } - return nullopt; -} - -static bool can_perform_inplace(const Tensor& a, const Tensor& b) { - // TODO: generalize this to more transforms - auto a_ = unwrap(a); - auto b_ = unwrap(b); - if (!a_.has_value() && b_.has_value()) { - return false; - } - if (!a_.has_value() && !b_.has_value()) { - return true; - } - if (a_.has_value() && !b_.has_value()) { - return true; - } - TORCH_INTERNAL_ASSERT(a_.has_value() && b_.has_value()); - - // If b has any wrapper that a does not, then we cannot do a.inplace_(b) - if (std::get<1>(*a_) < std::get<1>(*b_)) { - return false; - } - if (std::get<1>(*a_) > std::get<1>(*b_)) { - return can_perform_inplace(std::get<0>(*a_), b); - } - return can_perform_inplace(std::get<0>(*a_), std::get<0>(*b_)); -} - // TODO: linear is pretty important for performance, but I'm not sure how to work // around the in-place. Tensor linear_hack(const Tensor& input, const Tensor& weight, const std::optional& bias_opt) { diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index bdf19e8d7362..1dc8c434f85b 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -308,6 +308,8 @@ class MPSHeapAllocatorImpl { // total GPU memory allocated in the process by Metal driver; including // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl. 
size_t getDriverAllocatedMemory() const { return current_allocated_size(); } + // recommended Max memory for Metal + size_t getRecommendedMaxMemory() const { return max_device_size(); } // (see enum DebugVerbosity for description) uint32_t getDebugVerbosity() const { return m_debug_verbosity; } // returns the device that we allocate from diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 76280fb469e5..0c2a86948a4c 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -794,6 +794,9 @@ size_t getCurrentAllocatedMemory() const override { size_t getDriverAllocatedMemory() const override { return _getAllocImpl().getDriverAllocatedMemory(); } + size_t getRecommendedMaxMemory() const override { + return _getAllocImpl().getRecommendedMaxMemory(); + } ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); } diff --git a/aten/src/ATen/mps/MPSAllocatorInterface.h b/aten/src/ATen/mps/MPSAllocatorInterface.h index e30a02c3fb21..cce232fd6937 100644 --- a/aten/src/ATen/mps/MPSAllocatorInterface.h +++ b/aten/src/ATen/mps/MPSAllocatorInterface.h @@ -33,6 +33,7 @@ class IMPSAllocator : public c10::Allocator { virtual size_t getTotalAllocatedMemory() const = 0; virtual size_t getCurrentAllocatedMemory() const = 0; virtual size_t getDriverAllocatedMemory() const = 0; + virtual size_t getRecommendedMaxMemory() const = 0; virtual std::pair getSharedBufferPtr(const void* ptr) const = 0; virtual bool recordEvents(c10::ArrayRef buffers) const = 0; virtual bool waitForEvents(c10::ArrayRef buffers) const = 0; diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 667430eaf811..dea8f25fa7fd 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -32,6 +32,7 @@ struct MPSHooks : public at::MPSHooksInterface { void emptyCache() const override; size_t getCurrentAllocatedMemory() const override; size_t getDriverAllocatedMemory() const override; + size_t getRecommendedMaxMemory() const override; void setMemoryFraction(double ratio) const override; // MPSProfiler interface diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 387359592a74..285c0771c3c6 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -80,6 +80,10 @@ return at::mps::getIMPSAllocator()->getDriverAllocatedMemory(); } +size_t MPSHooks::getRecommendedMaxMemory() const { + return at::mps::getIMPSAllocator()->getRecommendedMaxMemory(); +} + void MPSHooks::setMemoryFraction(double ratio) const { at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio); } diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index dca6a39a0970..dc84547b7fe1 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -23,7 +23,7 @@ enum class GeluType { END }; -static GeluType get_gelutype_enum(const c10::string_view approximate) { +inline GeluType get_gelutype_enum(const c10::string_view approximate) { if (approximate == "none") { return GeluType::None; } else if (approximate == "tanh") { @@ -33,7 +33,7 @@ static GeluType get_gelutype_enum(const c10::string_view approximate) { } } -static std::string gelutype_to_string(const GeluType type) { +inline std::string gelutype_to_string(const GeluType type) { switch(type) { case GeluType::None: return "none"; case GeluType::Tanh: return "tanh"; diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp 
index bbd4f68d40d0..54cde5aad4c0 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -310,7 +310,7 @@ Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_siz TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3"); TORCH_CHECK( (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), - "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ", + "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ", "but received {", output_size[0], ", ", output_size[1], ",", output_size[2], "}"); if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1 && !input.is_xpu()) { diff --git a/aten/src/ATen/native/AdaptivePooling.h b/aten/src/ATen/native/AdaptivePooling.h index bb2fda9906ab..6c49fd38d940 100644 --- a/aten/src/ATen/native/AdaptivePooling.h +++ b/aten/src/ATen/native/AdaptivePooling.h @@ -28,15 +28,15 @@ using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, con DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel); DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel); -static inline int64_t start_index(int64_t a, int64_t b, int64_t c) { +inline int64_t start_index(int64_t a, int64_t b, int64_t c) { return (a / b) * c + ((a % b) * c) / b; } -static inline int64_t end_index(int64_t a, int64_t b, int64_t c) { +inline int64_t end_index(int64_t a, int64_t b, int64_t c) { return 1 + ((a + 1) * c - 1) / b; } -static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { +inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { int64_t ndim = gradOutput_.ndimension(); for (const auto i : c10::irange(1, ndim)) { TORCH_CHECK(gradOutput_.size(i) > 0, diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 19c70672fb93..3fe3ac2b4a25 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -1480,23 +1480,14 @@ Tensor& not_equal_(Tensor& self, const Scalar& other) { return self.ne_(other); Tensor& logical_and_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_and_stub); } Tensor logical_and(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_and_out)); } Tensor& logical_and_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_and_out)); } -static Tensor& logical_and_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_and_out)); } -static Tensor logical_and(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_and_out)); } -static Tensor& logical_and_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_and_out)); } Tensor& logical_or_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_or_stub); } Tensor logical_or(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_or_out)); } Tensor& logical_or_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_or_out)); } -static Tensor& logical_or_out(Tensor& result, const 
Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_or_out)); } -static Tensor logical_or(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_or_out)); } -static Tensor& logical_or_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_or_out)); } Tensor& logical_xor_out(const Tensor& self, const Tensor& other, Tensor& result) { return comparison_op_out(result, self, other, logical_xor_stub); } Tensor logical_xor(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast(at::logical_xor_out)); } Tensor& logical_xor_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast(at::logical_xor_out)); } -static Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Scalar& other) { return comparison_op_out(result, self, other, static_cast(at::logical_xor_out)); } -static Tensor logical_xor(const Tensor& self, const Scalar& other) { return comparison_op(self, other, static_cast(at::logical_xor_out)); } -static Tensor& logical_xor_(Tensor& self, const Scalar& other) { return comparison_op_(self, other, static_cast(at::logical_xor_out)); } // binary max, alias for maximum Tensor& max_out(const Tensor& self, const Tensor& other, Tensor& result) { diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index bc601885b54e..97f04c9968c8 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -105,6 +105,28 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy); + +float fp16_dot_with_fp32_arith( + const float16_t* vec1, + const float16_t* vec2, + int64_t len); + +void bf16_gemv_trans( + const int m, + const int n, + const at::BFloat16 alpha, + const at::BFloat16* a, + const int lda, + const at::BFloat16* x, + const int incx, + const at::BFloat16 beta, + at::BFloat16* y, + const int incy); + +float bf16_dot_with_fp32_arith( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + int64_t len); #endif template @@ -113,8 +135,11 @@ bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { } template -bool gemv_use_fast_path(C10_UNUSED int64_t m, C10_UNUSED int64_t n, - C10_UNUSED int64_t lda, C10_UNUSED int64_t incx, C10_UNUSED int64_t incy) { +bool gemv_use_fast_path(C10_UNUSED char trans, C10_UNUSED int64_t m, + C10_UNUSED int64_t n, C10_UNUSED scalar_t alpha, + C10_UNUSED int64_t lda, + C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta, + C10_UNUSED int64_t incy) { return false; } @@ -133,7 +158,7 @@ void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_U #define INSTANTIATE(scalar_t) \ template bool scal_use_fast_path(int64_t n, int64_t incx); \ -template bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy); \ +template bool gemv_use_fast_path(char trans, int64_t m, int64_t n, scalar_t alpha, int64_t lda, int64_t incx, scalar_t beta, int64_t incy); \ template void gemv_fast_path(const char *trans, const int *m, const int *n, const scalar_t *alpha, const scalar_t *a, const int *lda, const scalar_t *x, const int *incx, const scalar_t *beta, scalar_t *y, const int *incy); \ template void scal_fast_path(int *n, scalar_t *a, scalar_t *x, int *incx); @@ -160,15 +185,15 @@ void scal_fast_path(int *n, float *a, float *x, int *incx) { } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy) { +bool 
gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) { auto intmax = std::numeric_limits::max(); return (m <= intmax) && (n <= intmax) && (lda <= intmax) && (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); } template <> -bool gemv_use_fast_path(int64_t m, int64_t n, int64_t lda, int64_t incx, int64_t incy) { - return gemv_use_fast_path(m, n, lda, incx, incy); +bool gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) { + return gemv_use_fast_path(trans, m, n, (float)alpha, lda, incx, (float)beta, incy); } template <> @@ -190,7 +215,6 @@ INSTANTIATE(int8_t); INSTANTIATE(int16_t); INSTANTIATE(int); INSTANTIATE(int64_t); -INSTANTIATE(c10::BFloat16); #if defined(__aarch64__) && !defined(C10_MOBILE) template <> bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { @@ -199,14 +223,32 @@ bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) template <> bool gemv_use_fast_path( + C10_UNUSED char trans, C10_UNUSED int64_t m, C10_UNUSED int64_t n, + at::Half alpha, C10_UNUSED int64_t lda, C10_UNUSED int64_t incx, + at::Half beta, C10_UNUSED int64_t incy) { - return true; + return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f && + c10::detail::fp16_from_bits(beta.x) == 0.0f; } +template <> +bool gemv_use_fast_path( + C10_UNUSED char trans, + C10_UNUSED int64_t m, + C10_UNUSED int64_t n, + at::BFloat16 alpha, + C10_UNUSED int64_t lda, + C10_UNUSED int64_t incx, + at::BFloat16 beta, + C10_UNUSED int64_t incy) { + return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0; +} + + #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC static inline float16_t reduce(float16x4_t x) { auto sum = vpadd_f16(x, x); @@ -218,10 +260,9 @@ static inline float16_t reduce(float16x8_t x) { /* * NOTE [ GGML Copyright Notice ] - * The below reduce overload and - * fp16_gemv_trans_fp16_arith_by_dot_products function is adapted from - * llama.cpp's ggml_vec_dot_f16 and surrounding utility functions, so - * here is the required copyright notice: + * The below reduce overload and fp16_dot_with_fp16_arith function is + * adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility + * functions, so here is the required copyright notice: * * MIT License * @@ -279,29 +320,33 @@ static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) { #endif } +static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, int len) { + float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)}; + + const auto len_aligned = len & ~(kF16ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF16ElementsPerIteration) { + for (int k = 0; k < kF16RegistersPerIteration; ++k) { + const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister); + const auto temp_a = vld1q_f16(a + j + k * kF16ElementsPerRegister); + sum[k] = f16_fma(sum[k], temp_x, temp_a); + } + } + auto reducedSum = reduce(sum); + + for (int j = len_aligned; j < len; ++j) { + reducedSum += x[j] * a[j]; + } + return reducedSum; +} + // Rather than unrolling to process multiple rows (transposed columns) // of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll // along an individual dot product. 
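For orientation, the Half and BFloat16 fast paths gated above only fire when incx == 1, alpha == 1 and beta == 0, where y = alpha * A^T * x + beta * y degenerates into one dot product per column of A. A plain scalar sketch of that shape (illustrative helper name, float used for clarity, not part of the patch):

    #include <cstdint>

    // Reference semantics of the transposed-GEMV fast path: y[i] = dot(x, A[:, i]).
    void gemv_trans_reference(int64_t m, int64_t n, const float* a, int64_t lda,
                              const float* x, float* y, int64_t incy) {
      for (int64_t i = 0; i < n; ++i) {      // one output element per column of A
        float sum = 0.0f;
        for (int64_t j = 0; j < m; ++j) {
          sum += x[j] * a[i * lda + j];      // incx == 1, alpha == 1, beta == 0
        }
        y[i * incy] = sum;
      }
    }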
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n, 1, [&](int begin, int end) { - for (int i = begin; i < end; ++i) { - float16x8_t sum[kF16RegistersPerIteration] = {vdupq_n_f16(0)}; - - const auto m_aligned = m & ~(kF16ElementsPerIteration - 1); - for (int j = 0; j < m_aligned ; j += kF16ElementsPerIteration) { - for (int k = 0; k < kF16RegistersPerIteration; ++k) { - const auto temp_x = vld1q_f16(x + j + k * kF16ElementsPerRegister); - const auto temp_a = vld1q_f16(a + lda * i + j + k * kF16ElementsPerRegister); - sum[k] = f16_fma(sum[k], temp_x, temp_a); - } - } - auto reducedSum = reduce(sum); - - for (int j = m_aligned; j < m; ++j) { - reducedSum += x[j] * a[lda * i + j]; - } - y[i * incy] = reducedSum; - } + for (int i = begin; i < end; ++i) { + y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m); + } }); } @@ -341,10 +386,14 @@ static inline float32x4_t f32_fma_high_f16(float32x4_t a, float16x8_t b, float16 #endif } -// The below reduce overload and -// fp16_gemv_trans_fp32_arith_by_dot_products are adapted from -// llama.cpp's ggml_vec_dot_f32 and surrounding utility functions. See -// NOTE [ GGML Copyright Notice ] above for the required notice. +static inline float32x4_t f32_fma_f16(float32x4_t a, float16x4_t b, float16x4_t c) { + return f32_fma_low_f16(a, vcombine_f16(b, vdup_n_f16(0)), vcombine_f16(c, vdup_n_f16(0))); +} + +// The below reduce overload and fp16_dot_with_fp32_arith are adapted +// from llama.cpp's ggml_vec_dot_f32 and surrounding utility +// functions. See NOTE [ GGML Copyright Notice ] above for the +// required notice. // We need the shift for reduce(), hence the extra constants. static constexpr auto kF32ElementsPerIterationShift = 5; @@ -372,32 +421,117 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( + const float16_t* vec1, + const float16_t* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + // Load a pair of f32 registers at a time. + const auto temp_vec1 = vld1q_f16(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]); + const auto temp_vec2 = vld1q_f16(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]); + + sum[2 * registerPairIndex] = f32_fma_low_f16(sum[2 * registerPairIndex], temp_vec1, temp_vec2); + sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( + const float16_t* vec1, + const float16_t* vec2, + float32x4_t* tailSum, + int idx) { + const auto temp_vec1 = vld1_f16(&vec1[idx]); + const auto temp_vec2 = vld1_f16(&vec2[idx]); + *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2); +} + +static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) { + int32x4_t shift = vdupq_n_s32(16); + return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift)); +} + +static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { + return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + // TODO: detect intrinsic availability, use them if they're available. 
__ARM_FEATURE_BF16 + // Load a pair of f32 registers at a time. + const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + + sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2)); + sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2)); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t* tailSum, + int idx) { + const auto temp_vec1 = vld1_u16(reinterpret_cast(&vec1[idx])); + const auto temp_vec2 = vld1_u16(reinterpret_cast(&vec2[idx])); + *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); +} + +template +float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { + float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + const auto* vec1_ = vec1 + j; + const auto* vec2_ = vec2 + j; + c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) { + dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); + }); + } + auto reducedSum = reduce(sum); + + // First-tier tail fixup: make sure we handle workloads that can + // benefit from vectorization, but don't fit into our fully unrolled + // loop above. + float32x4_t tailSum = vdupq_n_f32(0); + const auto len_aligned_4 = len & ~3; + for (int j = len_aligned; j < len_aligned_4; j += 4) { + dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); + } + auto reducedTail = vpaddq_f32(tailSum, tailSum); + reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); + + // Second-tier tail fixup: handle all workloads. + for (int j = len_aligned_4; j < len; ++j) { + reducedSum += vec1[j] * vec2[j]; + } + return reducedSum; +} + +float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) { + return dot_with_fp32_arith(vec1, vec2, len); +} + +float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) { + return dot_with_fp32_arith(vec1, vec2, len); +} + // On my Apple M1 Macbook (which is ARM v8.5 and thus has the // instructions f32_fma_{low,high}_f16 is targeting), this kernel has // equivalent performance to the fp16-native kernel. static void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) { parallel_for(0, n, 1, [&](int begin, int end) { - for (int i = begin; i < end; ++i) { - float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; - - const auto m_aligned = m & ~(kF32ElementsPerIteration - 1); - for (int j = 0; j < m_aligned ; j += kF32ElementsPerIteration) { - c10::ForcedUnroll{}([x, a, lda, i, j, &sum](auto k) { - // Load a pair of f32 registers at a time. 
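The to_bfloat16() helper above widens bf16 lanes by shifting them into the high half of a 32-bit word before the fused multiply-add. A scalar sketch of the same trick (standalone names, not the NEON code):

    #include <cstdint>
    #include <cstring>

    // A bfloat16 payload is exactly the upper 16 bits of the equivalent float32,
    // so widening is a 16-bit left shift with the low mantissa bits zeroed.
    float bf16_bits_to_float(uint16_t bits) {
      const uint32_t widened = static_cast<uint32_t>(bits) << 16;
      float out;
      std::memcpy(&out, &widened, sizeof(out));
      return out;
    }
    // Example: 0x3F80 (bf16 1.0) widens to 0x3F800000, i.e. 1.0f.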
- const auto temp_x = vld1q_f16(x + j + k * 2 * kF32ElementsPerRegister); - const auto temp_a = vld1q_f16(a + lda * i + j + k * 2 * kF32ElementsPerRegister); - - sum[2 * k] = f32_fma_low_f16(sum[2 * k], temp_x, temp_a); - sum[2 * k + 1] = f32_fma_high_f16(sum[2 * k + 1], temp_x, temp_a); - }); - } - auto reducedSum = reduce(sum); + for (int i = begin; i < end; ++i) { + y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m); + } + }); +} - for (int j = m_aligned; j < m; ++j) { - reducedSum += x[j] * a[lda * i + j]; - } - y[i * incy] = reducedSum; - } +static void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) { + parallel_for(0, n, 1, [&](int begin, int end) { + for (int i = begin; i < end; ++i) { + y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m); + } }); } @@ -412,26 +546,28 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy) { - if (incx == 1 && alpha == 1.0 && beta == 0.0) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0); #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - if (at::globalContext().allowFP16ReductionCPU()) { - return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); - } -#endif - return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); - } - for (const auto i : c10::irange(n)) { - float sum = 0; - const auto row_ = a + lda * i; - for (const auto j : c10::irange(m)) { - sum += x[j * incx] * row_[j]; - } - if (beta == 0.0) { - y[i * incy] = alpha * sum; - } else { - y[i * incy] = beta * y[i * incy] + alpha * sum; - } + if (at::globalContext().allowFP16ReductionCPU()) { + return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy); } +#endif + return fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); +} + +void bf16_gemv_trans( + const int m, + const int n, + const at::BFloat16 alpha, + const at::BFloat16* a, + const int lda, + const at::BFloat16* x, + const int incx, + const at::BFloat16 beta, + at::BFloat16* y, + const int incy) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0); + return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy); } @@ -548,9 +684,37 @@ void gemv_fast_path( *incy); } } -#else + +template <> +void gemv_fast_path( + const char* trans, + const int* m, + const int* n, + const at::BFloat16* alpha, + const at::BFloat16* a, + const int* lda, + const at::BFloat16* x, + const int* incx, + const at::BFloat16* beta, + at::BFloat16* y, + const int* incy) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(trans[0] == 'T' || trans[0] == 't'); + bf16_gemv_trans( + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy); +} +#else // defined(__aarch64__) && !defined(C10_MOBILE) INSTANTIATE(c10::Half); -#endif +INSTANTIATE(c10::BFloat16); +#endif // defined(__aarch64__) && !defined(C10_MOBILE) #undef INSTANTIATE } // namespace blas_impl @@ -559,12 +723,14 @@ template inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx) { if (n == 1) incx = 1; +#if AT_BUILD_WITH_BLAS() if (blas_impl::scal_use_fast_path(n, incx)) { int i_n = (int)n; int i_incx = (int)incx; blas_impl::scal_fast_path(&i_n, &a, x, &i_incx); return; } +#endif for (const auto i : c10::irange(n)) { if (a == scalar_t(0)) { x[i * incx] = 0; @@ -578,7 +744,8 @@ template void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, 
scalar_t *y, int64_t incy) { if(n == 1) lda = m; - if (blas_impl::gemv_use_fast_path(m, n, lda, incx, incy)) { +#if AT_BUILD_WITH_BLAS() + if (blas_impl::gemv_use_fast_path(trans, m, n, alpha, lda, incx, beta, incy)) { TORCH_CHECK(lda >= std::max(1L, m), "lda should be at least max(1,", m, "), but have ", lda); int i_m = (int)m; int i_n = (int)n; @@ -588,6 +755,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i blas_impl::gemv_fast_path(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); return; } +#endif using opmath_t = at::opmath_type; if ((trans == 'T') || (trans == 't')) { diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 446bbeccc223..4c77c983c295 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -75,7 +75,7 @@ namespace { } } -static inline bool cudnnv8_enabled_check_debug() { +inline bool cudnnv8_enabled_check_debug() { static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true; static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true; static uint8_t cudnnv8_debugcount = 0; @@ -86,7 +86,7 @@ static inline bool cudnnv8_enabled_check_debug() { return cudnnv8_flag == 1; } -static inline bool cudnnv8_use_heur_mode_b() { +inline bool cudnnv8_use_heur_mode_b() { return is_cudnnv8_heuristic_mode_b(); } @@ -186,7 +186,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co // (which the user can change) and computed inputs (which the user can // only indirectly affect). It would be an interesting exercise to // come up with a general framework to handle such situations.) -static void convolution_shape_check( +inline void convolution_shape_check( CheckedFrom c, const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) @@ -212,7 +212,7 @@ static void convolution_shape_check( // takes an extra output_padding argument to resolve the ambiguity. 
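A concrete instance of that ambiguity, using the per-dimension size formula that the conv_output_size helpers just below encode (illustrative one-dimensional helper): with kernel 3, stride 2, no padding and no dilation, input lengths 7 and 8 both produce an output of length 3, which is exactly why the transposed-convolution input-size computation needs output_padding to pick one of them.

    #include <cstdint>

    // floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1, per dimension.
    constexpr int64_t conv_out(int64_t in, int64_t k, int64_t s,
                               int64_t p = 0, int64_t d = 1) {
      return (in + 2 * p - d * (k - 1) - 1) / s + 1;
    }
    static_assert(conv_out(7, 3, 2) == 3 && conv_out(8, 3, 2) == 3,
                  "two different input sizes map to the same output size");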
template -static inline std::vector _conv_output_size( +inline std::vector _conv_output_size( ArrayRef input_size, ArrayRef weight_size, ArrayRef padding, ArrayRef stride, ArrayRef dilation = ArrayRef() ) { @@ -231,14 +231,14 @@ static inline std::vector _conv_output_size( return output_size; } -static inline std::vector conv_output_size( +inline std::vector conv_output_size( IntArrayRef input_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { return _conv_output_size(input_size, weight_size, padding, stride, dilation); } -static inline std::vector conv_output_size( +inline std::vector conv_output_size( SymIntArrayRef input_size, SymIntArrayRef weight_size, SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef() ) { @@ -264,14 +264,14 @@ std::vector _conv_input_size( return input_size; } -static inline std::vector conv_input_size( +inline std::vector conv_input_size( SymIntArrayRef output_size, SymIntArrayRef weight_size, SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups ) { return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); } -static inline std::vector conv_input_size( +inline std::vector conv_input_size( IntArrayRef output_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { @@ -295,27 +295,27 @@ std::vector _conv_weight_size( return weight_size; } -static inline std::vector conv_weight_size( +inline std::vector conv_weight_size( SymIntArrayRef input_size, SymIntArrayRef output_size, SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); } -static inline std::vector conv_weight_size( +inline std::vector conv_weight_size( IntArrayRef input_size, IntArrayRef output_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); } -static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { +inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; return bias.reshape(shape); } -static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) { +inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. if (!at::detail::getCUDAHooks().compiledWithCuDNN() || input.scalar_type() == at::kDouble || @@ -351,7 +351,7 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable); TORCH_API bool _cudnn_get_conv_benchmark_empty_cache(); -static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. 
if (!at::detail::getCUDAHooks().compiledWithMIOpen() || @@ -378,7 +378,7 @@ static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d; } -static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // disable NHWC for float64 input. if (input.scalar_type() == at::kDouble || @@ -405,7 +405,7 @@ static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d; } -static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { auto input_memory_format = input.suggest_memory_format(); auto weight_memory_format = weight.suggest_memory_format(); @@ -417,7 +417,7 @@ static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at return can_use_thnn_channels_last_2d; } -static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { +inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { // check layout only for xpu tensor. if (!input.is_xpu() || !weight.is_xpu()) { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index e1952795843c..b35ad072d0cf 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -393,7 +393,7 @@ struct RegisterPRIVATEUSE1Dispatch { // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. #ifdef CPU_CAPABILITY_AVX512 -#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, nullptr) +#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) #else #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h index 2c334157eba9..664e2db3b2dc 100644 --- a/aten/src/ATen/native/Distributions.h +++ b/aten/src/ATen/native/Distributions.h @@ -254,7 +254,7 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler -C10_DEVICE static inline scalar_t digamma_one(scalar_t x) { +C10_DEVICE inline scalar_t digamma_one(scalar_t x) { constexpr accscalar_t PSI_10 = 2.25175258906672110764; if (x == 0) { return INFINITY; @@ -376,7 +376,7 @@ C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. // Assumes x is close to zero and uses a Taylor expansion. template -C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) { const scalar_t factor = digamma_one(alpha) - digamma_one(alpha + beta) - compat_log(x); scalar_t numer = 1; @@ -394,7 +394,7 @@ C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t al // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta. // Assumes x is close to zero and uses a Taylor expansion. 
template -C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) { const scalar_t factor = digamma_one(alpha + beta) - digamma_one(beta); scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha; for (int i = 1; i <= 8; ++i) { @@ -412,7 +412,7 @@ C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alp // Assumes alpha and beta are both large and uses a Rice saddle point expansion. // To ensure numerical stability, this computation is performed at higher precision. template -C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) { +C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) { const accscalar_t total = alpha + beta; const accscalar_t mean = alpha / total; const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total; @@ -452,7 +452,7 @@ C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_ // This function inputs total=alpha+beta to make it easy to implement // Dirichlet reparameterized gradients in terms of Betas. template -C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) { +C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) { accscalar_t x_ = static_cast(x); accscalar_t alpha_ = static_cast(alpha); accscalar_t total_ = static_cast(total); diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index 34c71a886862..9656e2aa4f72 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +450,15 @@ std::vector foreach_tensor_norm_slow( return result; } +std::vector foreach_tensor_max_slow(TensorList tensors) { + check_foreach_api_restrictions(tensors); + std::vector result; + for (const auto& t : tensors) { + result.emplace_back(at::max(t)); + } + return result; +} + std::vector foreach_scalar_pow_list_kernel_slow( const Scalar& self, TensorList exponent) { diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index 0839dd9a1560..f5c0672402f3 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -102,12 +102,13 @@ inline void check_foreach_api_restrictions( // corresponding tensors (aligning in index across the tensorLists) share the // same device and dtype. 
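A small usage sketch for the new foreach max fallback above. The at::_foreach_max dispatcher name is assumed here, following the usual _foreach_* convention; the slow path shown defines each result as a per-tensor at::max, which is what the checks below express.

    #include <vector>
    #include <ATen/ATen.h>

    void foreach_max_demo() {
      std::vector<at::Tensor> ts = {at::rand({8}), at::rand({4, 4})};
      auto maxes = at::_foreach_max(ts);               // assumed dispatcher name
      TORCH_CHECK(maxes[0].equal(at::max(ts[0])));     // matches the slow-path loop
      TORCH_CHECK(maxes[1].equal(at::max(ts[1])));
    }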
inline bool _check_tensors_share_device_and_dtype( - ArrayRef tensorLists) { + ArrayRef tensorLists, + const bool skip_dtype_check = false) { const auto expected_dtype = tensorLists[0][0].dtype(); const auto expected_device = tensorLists[0][0].device(); auto is_tensor_okay = [&](const Tensor& tensor) { - return tensor.dtype() == expected_dtype && + return (skip_dtype_check || tensor.dtype() == expected_dtype) && tensor.device() == expected_device && tensor.layout() == at::kStrided && tensor.is_non_overlapping_and_dense(); }; diff --git a/aten/src/ATen/native/FractionalMaxPooling.h b/aten/src/ATen/native/FractionalMaxPooling.h index cb5438a03e70..95c05618caef 100644 --- a/aten/src/ATen/native/FractionalMaxPooling.h +++ b/aten/src/ATen/native/FractionalMaxPooling.h @@ -6,7 +6,7 @@ namespace at::native { template -static inline std::vector generate_intervals( +inline std::vector generate_intervals( scalar_t sample, int64_t inputSize, int64_t outputSize, @@ -28,7 +28,7 @@ static inline std::vector generate_intervals( } template -static inline void fractional_max_pool_check_shape( +inline void fractional_max_pool_check_shape( const Tensor& input, const Tensor& randomSamples) { diff --git a/aten/src/ATen/native/GridSamplerUtils.h b/aten/src/ATen/native/GridSamplerUtils.h index eea21ddf5e37..f783043c7961 100644 --- a/aten/src/ATen/native/GridSamplerUtils.h +++ b/aten/src/ATen/native/GridSamplerUtils.h @@ -18,10 +18,8 @@ enum class GridSamplerPadding {Zeros, Border, Reflection}; using detail::GridSamplerInterpolation; using detail::GridSamplerPadding; -namespace { - // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_common( +inline void check_grid_sampler_common( const TensorBase& input, const TensorBase& grid ) { @@ -60,7 +58,7 @@ void check_grid_sampler_common( } // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_2d( +inline void check_grid_sampler_2d( const TensorBase& input, const TensorBase& grid ) { @@ -72,7 +70,7 @@ void check_grid_sampler_2d( } // See NOTE [ grid_sampler Native Functions ]. -void check_grid_sampler_3d( +inline void check_grid_sampler_3d( const TensorBase& input, const TensorBase& grid, int64_t interpolation_mode @@ -91,7 +89,7 @@ void check_grid_sampler_3d( // See NOTE [ grid_sampler Native Functions ]. // cudnn does not support inputs larger than 1024. -bool cond_cudnn_grid_sampler( +inline bool cond_cudnn_grid_sampler( const TensorBase& input, const TensorBase& grid ) { @@ -104,6 +102,4 @@ bool cond_cudnn_grid_sampler( input.sym_size(1) <= 1024); } -} // anonymous namespace - } // namespace at::native diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 3389033ac985..6015a3b509b0 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -856,7 +856,7 @@ namespace { /** * @brief Computes the optimal matrix chain multiplication order * - * Follows the dynamic programming algorithm from Cormen et al, + * Follows the dynamic programming algorithm from Cormen et al., * "Introduction to Algorithms, Third Edition", Chapter 15.2, * p. 370-378. Note that the book uses 1-based indexing. 
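As a worked instance of the cost that the chain-order DP minimizes (the matrix sizes are purely illustrative): multiplying A (10x30), B (30x5), C (5x60) as (AB)C costs 1500 + 3000 = 4500 scalar multiplications, while A(BC) costs 9000 + 18000 = 27000, so the parenthesization the algorithm returns matters.

    #include <cstdint>

    // Multiplying an (m x k) matrix by a (k x n) matrix costs m*k*n scalar products.
    constexpr int64_t mm_cost(int64_t m, int64_t k, int64_t n) { return m * k * n; }
    static_assert(mm_cost(10, 30, 5) + mm_cost(10, 5, 60) == 4500,  "(AB)C");
    static_assert(mm_cost(30, 5, 60) + mm_cost(10, 30, 60) == 27000, "A(BC)");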
* diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 0b05d5162e66..52f5e1cb6555 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -27,7 +27,7 @@ namespace at::native { -static inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) { +inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) { if (tensor.is_conj()) { return c10::MaybeOwned::owned(tensor.resolve_conj()); } else { @@ -35,7 +35,7 @@ static inline c10::MaybeOwned expect_resolved_conj(const Tensor& tensor) } } -static inline DimVector batched_matrix_contiguous_strides( +inline DimVector batched_matrix_contiguous_strides( const IntArrayRef sizes, const bool f_contig = false) { // f_contig chooses between the strides of a batch of Fortran (F-contiguous) @@ -62,7 +62,7 @@ static inline DimVector batched_matrix_contiguous_strides( * P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N' * matrix starting at Q.data_ptr()[B * M' * N']. */ -static inline Tensor cloneBatchedColumnMajor(const Tensor& src) { +inline Tensor cloneBatchedColumnMajor(const Tensor& src) { // If src is already in batched column major format, then // this will be efficient (no reordering of the data will occur) // because the first transpose will make the tensor contiguous, @@ -75,7 +75,7 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) { /* * contig chooses between C-contig (true) and F-contig (false) */ -static inline c10::MaybeOwned borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) { +inline c10::MaybeOwned borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) { return cond ? c10::MaybeOwned::borrowed(borrow) : c10::MaybeOwned::owned(contig ? clone.clone(MemoryFormat::Contiguous) : cloneBatchedColumnMajor(clone)); @@ -92,7 +92,7 @@ static inline c10::MaybeOwned borrow_else_clone(const bool cond, const T * which is either the original batch size of the input, or its larger * broadcasted shape. */ -static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1, +inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1, at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) { nrows = (nrows == -1) ? src.size(-2) : nrows; auto copy_sizes = desired_batch_sizes.has_value() @@ -109,7 +109,7 @@ static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = - * Given batches of matrices with arbitrary batch dim, * computes the number of batches. 
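A quick numeric sketch for batchCount and matrixStride as defined just below (assuming this internal header is reachable from the translation unit, as it is inside ATen): for a batched-matrix tensor of shape (2, 3, 4, 5), everything before the trailing two dimensions is batch.

    #include <ATen/ATen.h>
    #include <ATen/native/LinearAlgebraUtils.h>

    void batch_helpers_demo() {
      auto t = at::zeros({2, 3, 4, 5});
      TORCH_CHECK(at::native::batchCount(t) == 2 * 3);    // product of batch dims
      TORCH_CHECK(at::native::matrixStride(t) == 4 * 5);  // elements per matrix
    }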
*/ -static inline int64_t batchCount(const Tensor& batched_matrices) { +inline int64_t batchCount(const Tensor& batched_matrices) { int64_t result = 1; for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) { result *= batched_matrices.size(i); @@ -118,15 +118,15 @@ static inline int64_t batchCount(const Tensor& batched_matrices) { } // Computes the number of elements of a matrix in a batched matrix tensor -static inline int64_t matrixStride(const Tensor& batched_matrices) { +inline int64_t matrixStride(const Tensor& batched_matrices) { return batched_matrices.size(-1) * batched_matrices.size(-2); } // Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig) -static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") { +inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") { TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions."); } -static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") { +inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") { checkIsMatrix(self, f_name, arg_name); TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2), f_name, @@ -134,7 +134,7 @@ static inline void squareCheckInputs(const Tensor& self, const char* const f_nam "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices"); } -static inline void checkInputsSolver(const Tensor& A, +inline void checkInputsSolver(const Tensor& A, const Tensor& B, const bool left, const char* const f_name) { @@ -146,14 +146,14 @@ static inline void checkInputsSolver(const Tensor& A, " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")"); } -static inline bool is_row_or_column_contiguous(const Tensor& t) { +inline bool is_row_or_column_contiguous(const Tensor& t) { // This could be made more general, similar to how it's checked in matmul, which would allow to // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky. 
// We choose to be conservative for simplicity return t.is_contiguous() || t.transpose(-2, -1).is_contiguous(); } -static inline TransposeType to_transpose_type(const bool contig, const bool conj) { +inline TransposeType to_transpose_type(const bool contig, const bool conj) { if (conj) { if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } else { return TransposeType::ConjTranspose; } @@ -261,7 +261,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu } // Returns the epsilon value for floating types except half -static inline double _get_epsilon(const ScalarType& sc_type) { +inline double _get_epsilon(const ScalarType& sc_type) { switch (sc_type) { case at::ScalarType::Float: return static_cast(std::numeric_limits::epsilon()); @@ -274,7 +274,7 @@ static inline double _get_epsilon(const ScalarType& sc_type) { // Validates input shapes and devices // for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve) -static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) { +inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) { TORCH_CHECK(self.device() == A.device(), "Expected b and A to be on the same device, but found b on ", self.device(), " and A on ", A.device(), " instead."); @@ -293,7 +293,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, c " but each b matrix is ", self.size(-2), " by ", self.size(-1)); } -static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { +inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) { auto dtype = t.scalar_type(); TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)), f_name, ": Expected a floating point or complex tensor as input. Got ", dtype); @@ -305,13 +305,13 @@ static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_n // Checks if all the Tensors in a TensorList are of the same dimensions -static inline void checkAllSameDim(TensorList tensors, int64_t dim) { +inline void checkAllSameDim(TensorList tensors, int64_t dim) { for (auto &t : tensors) { TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead."); } } -static inline std::tuple, std::vector> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { +inline std::tuple, std::vector> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { // broadcast the batch dimensions of arg1 and arg2. 
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2); IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2); @@ -325,7 +325,7 @@ static inline std::tuple, std::vector> _linalg_bro return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size)); } -static inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) { +inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) { // If there's no name we assume we don't want to check the errors if (name != nullptr) { linearSolveCheckInputs(arg1, arg2, name); @@ -338,7 +338,7 @@ static inline std::tuple _linalg_broadcast_batch_dims(const Tenso return std::make_tuple(arg1_broadcasted, arg2_broadcasted); } -static inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { +inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims); IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims); auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes); @@ -346,7 +346,7 @@ static inline std::vector broadcast_batch_size(const Tensor& t1, const } // Return a permutation with the given axes moved to the end. -static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { +inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { const std::vector a = axes.vec(); const int64_t ndim = self.ndimension(); std::vector perm; @@ -368,7 +368,7 @@ static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { } // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) -static inline std::tuple _parse_qr_mode(c10::string_view mode) { +inline std::tuple _parse_qr_mode(c10::string_view mode) { bool compute_q; bool reduced; if (mode == "reduced") { @@ -388,7 +388,7 @@ static inline std::tuple _parse_qr_mode(c10::string_view mode) { } // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition -static inline std::tuple _compute_geometry_for_Q( +inline std::tuple _compute_geometry_for_Q( const Tensor& input, bool reduced) { int64_t m = input.size(-2), n = input.size(-1); @@ -407,7 +407,7 @@ static inline std::tuple _compute_geometry_for_Q( return std::make_tuple(q_sizes, q_strides, n_columns_q); } -static inline bool svd_uses_cusolver(const Tensor& A) { +inline bool svd_uses_cusolver(const Tensor& A) { // if cusolver is available, it is used unconditionally return A.is_cuda() && at::globalContext().hasCuSOLVER() @@ -417,7 +417,7 @@ static inline bool svd_uses_cusolver(const Tensor& A) { // Function used instead of .to so that the original strides are retained // .to doesn't retain strides and make the output tensor contiguous -static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) { +inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) { auto strided_to = at::empty_strided(original_tensor.sizes(), original_tensor.strides(), options); @@ -433,7 +433,7 @@ static inline Tensor same_stride_to(const Tensor& original_tensor, const at::Ten // For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by // calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will // be `vec(0, 2, 1, 3)`. 
-static inline std::vector create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) { +inline std::vector create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) { TORCH_CHECK( (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0), "duplicate or invalid dimensions"); @@ -453,7 +453,7 @@ static inline std::vector create_dim_backshift_permutation(int64_t dim0 // will reverse a given permutation. // The reverse permutation array is created by swapping the indices and their // associated values from the given permutation array. -static inline std::vector create_reverse_permutation(std::vector permutation) { +inline std::vector create_reverse_permutation(std::vector permutation) { int64_t ndim = permutation.size(); std::vector reverse_permutation(ndim); for (const auto dim_ind : c10::irange(ndim)) { @@ -464,7 +464,7 @@ static inline std::vector create_reverse_permutation(std::vector(std::toupper(static_cast(uplo[0]))); TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), "Expected UPLO argument to be 'L' or 'U', but got ", uplo); } -static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { +inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { TORCH_CHECK( result.device() == input.device(), fn_name, @@ -504,7 +504,7 @@ static inline void checkSameDevice(const std::string& fn_name, Tensor result, Te // (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype. // According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch // c10::canCast is used for checking the "safe copy" dtype requirements. -static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { +inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") { bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type()); TORCH_CHECK( can_cast, @@ -514,7 +514,7 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor } // Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type) -static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") { +inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") { bool can_cast = c10::canCast(result_type, out_type); TORCH_CHECK( can_cast, @@ -523,7 +523,7 @@ static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Scalar out_name, " with dtype ", out_type); } -static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) { +inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) { TORCH_CHECK(!at::isComplexType(tol.scalar_type()), f_name, ": ", tol_name, " tensor of complex type is not supported. 
Got ", tol.scalar_type()); } @@ -538,7 +538,7 @@ static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m). This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389 */ -static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) { +inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) { auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1] bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape)); return vector_case; @@ -547,7 +547,7 @@ static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& /* Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor. */ -static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) { +inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) { TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous(); } @@ -578,7 +578,7 @@ class BroadcastLinearIndices { } }; -static inline bool is_blas_compatible_column_major_order(const Tensor& input) { +inline bool is_blas_compatible_column_major_order(const Tensor& input) { IntArrayRef input_strides = input.strides(); IntArrayRef input_sizes = input.sizes(); auto ndim = input.dim(); @@ -599,7 +599,7 @@ static inline bool is_blas_compatible_column_major_order(const Tensor& input) { batch_stride_compatible; } -static inline bool is_blas_compatible_row_major_order(const Tensor& input) { +inline bool is_blas_compatible_row_major_order(const Tensor& input) { IntArrayRef input_strides = input.strides(); IntArrayRef input_sizes = input.sizes(); auto ndim = input.dim(); diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index b13ed7e2ce92..6848abe70ec7 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -2,9 +2,9 @@ // Licensed under the BSD-3-Clause license // This is the CPU implementation of the Connectionist Temporal Loss. // We mostly follow Graves. -// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf +// 1. Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf // We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based. -// Graves et al call the probabilities y, we use log_probs (also calling them inputs) +// Graves et al. 
call the probabilities y, we use log_probs (also calling them inputs) #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include diff --git a/aten/src/ATen/native/LossMulti.h b/aten/src/ATen/native/LossMulti.h index 27697815ad59..8877b05a54cc 100644 --- a/aten/src/ATen/native/LossMulti.h +++ b/aten/src/ATen/native/LossMulti.h @@ -5,8 +5,7 @@ #include namespace at::native { -namespace { - static C10_UNUSED void multilabel_margin_loss_shape_check( + inline void multilabel_margin_loss_shape_check( int64_t& nframe, int64_t& dim, const int64_t& ndims, @@ -35,7 +34,7 @@ namespace { } } - static C10_UNUSED void multi_margin_loss_shape_check( + inline void multi_margin_loss_shape_check( int64_t& nframe, int64_t& dim, const int64_t& ndims, @@ -67,6 +66,4 @@ namespace { } } - -} // anonymous namespace } // namespace at::native diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index b7809ab21dd5..35ae21c32736 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -675,15 +675,6 @@ Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::op return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index))); } -// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages. -static Tensor nll_loss(const Tensor & self, const Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index)); -} - Tensor nll_loss_nd_symint( const Tensor& self, const Tensor& target, diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp index 6f27884b8f24..13c575a1a7bb 100644 --- a/aten/src/ATen/native/LossNLL2d.cpp +++ b/aten/src/ATen/native/LossNLL2d.cpp @@ -499,13 +499,4 @@ Tensor nll_loss2d_symint(const Tensor & self, const Tensor & target, const std:: return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, std::move(ignore_index))); } -// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages. -static Tensor nll_loss2d(const Tensor & self, const Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, ignore_index)); -} - } // namespace at::native diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 092ee00992e9..e86a9aea411a 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -147,7 +147,7 @@ jiterator_also_stringify_as(jiterator_code( #define CENTRAL_RANGE 0.7 template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_erfinv(T y) { /* Function to calculate inverse error function. Rational approximation is used to generate an initial approximation, which is then improved to @@ -232,7 +232,7 @@ Date: February 1996 * See note [3-Clause BSD License for the Cephes Math Library]. 
*/ template -C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { +C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { using acc_t = at::acc_type; const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; constexpr acc_t zero = acc_t{0.0}; @@ -324,7 +324,7 @@ C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_igno * N 0 */ template -C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { +C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) { T result = 0; for (size_t i = 0; i <= len; i++) { result = result * x + A[i]; @@ -332,7 +332,7 @@ C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { return result; } -static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { +inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { double sign = +1; double result = 0; if (x < 0.5) { @@ -350,7 +350,7 @@ static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { return sign * result; } -static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { +inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { float sign = +1; float result = 0; if (x < 0.5f) { @@ -372,7 +372,7 @@ static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { * This function is derived from the implementation of the digamma function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. */ -static inline double calc_digamma(double x) { +inline double calc_digamma(double x) { // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma static double PSI_10 = 2.25175258906672110764; if (x == 0) { @@ -430,7 +430,7 @@ static inline double calc_digamma(double x) { * This function is derived from the implementation of the digamma function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. */ -static inline float calc_digamma(float x) { +inline float calc_digamma(float x) { // See [C++ Standard Reference: Gamma Function] static float PSI_10 = 2.25175258906672110764f; if (x == 0) { @@ -485,16 +485,16 @@ static inline float calc_digamma(float x) { return result + logf(x) - (0.5f / x) - y; } -static inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { +inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { return calc_digamma(static_cast(a)); } -static inline c10::Half calc_digamma(c10::Half a) { +inline c10::Half calc_digamma(c10::Half a) { return calc_digamma(static_cast(a)); } template -static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { +inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { // already blocked if n <= 1 const auto one = scalar_t{1}; return ((n % 2) ? one : -one) * @@ -508,7 +508,7 @@ static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { /* References * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov - * [igam2] Maddock et. al., "Incomplete Gamma Functions", + * [igam2] Maddock et al., "Incomplete Gamma Functions", * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html */ @@ -519,7 +519,7 @@ static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { * See NOTICE for the licenses. 
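The polevl() helper above is Horner's rule over coefficients ordered from the highest power down; a self-contained sketch with a tiny check (standalone names, not the library code):

    #include <cstddef>

    // Evaluates A[0]*x^len + A[1]*x^(len-1) + ... + A[len] via Horner's rule,
    // mirroring the loop body in polevl().
    template <typename T>
    T horner(T x, const T A[], size_t len) {
      T result = 0;
      for (size_t i = 0; i <= len; i++) {
        result = result * x + A[i];
      }
      return result;
    }
    // Example: A = {2, -3, 1} with len = 2 encodes 2*x*x - 3*x + 1,
    // so horner(2.0, A, 2) == 3.0.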
*/ template -static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, +scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, const scalar_t denom[], int64_t N) { // evaluating rational function, i.e., the ratio of two polynomials // the coefficients for numerator are given by `num` while coeffs for @@ -1061,7 +1061,7 @@ static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { } template -static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { +inline scalar_t calc_igammac(scalar_t a, scalar_t x) { /* the calculation of the regularized upper incomplete gamma function * is done differently based on the values of a and x: * - if x and/or a is at the boundary of defined region, then assign the @@ -1141,7 +1141,7 @@ static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { } template -static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { +scalar_t calc_igamma(scalar_t a, scalar_t x) { /* the calculation of the regularized lower incomplete gamma function * is done differently based on the values of a and x: * - if x and/or a is at the boundary of defined region, then assign the @@ -1203,39 +1203,39 @@ static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { } template <> -C10_UNUSED c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { +C10_UNUSED inline c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED c10::Half calc_igamma(c10::Half a, c10::Half x) { +C10_UNUSED inline c10::Half calc_igamma(c10::Half a, c10::Half x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { +C10_UNUSED inline c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { return calc_igammac(float(a), float(x)); } template <> -C10_UNUSED c10::Half calc_igammac(c10::Half a, c10::Half x) { +C10_UNUSED inline c10::Half calc_igammac(c10::Half a, c10::Half x) { return calc_igammac(float(a), float(x)); } inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } template -static T abs_impl(T v) { +inline T abs_impl(T v) { return std::abs(v); } template <> -C10_UNUSED uint8_t abs_impl(uint8_t v) { +C10_UNUSED inline uint8_t abs_impl(uint8_t v) { return v; } template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_gcd(T a, T b) { a = abs_impl(a); b = abs_impl(b); @@ -1284,7 +1284,7 @@ C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type chbevl(const T x, const T array[], size_t len) { T b0, b1, b2; @@ -1310,7 +1310,7 @@ chbevl(const T x, const T array[], size_t len) { * of all inputs to convert them into the domain of the approximation. */ template -static inline std::tuple chebyshev_coefficients_i0e_A() { +inline std::tuple chebyshev_coefficients_i0e_A() { /* Chebyshev coefficients for exp(-x) I0(x) * in the interval [0,8]. * @@ -1336,7 +1336,7 @@ static inline std::tuple chebyshev_coefficients_i0e_A() { }; template -static inline std::tuple chebyshev_coefficients_i0e_B() { +inline std::tuple chebyshev_coefficients_i0e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) * in the inverted interval [8,infinity]. 
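`chbevl`, whose signature is touched just above, is the Cephes Chebyshev-series evaluator behind the `chebyshev_coefficients_*` tables: it runs the recurrence b0 = x*b1 - b2 + c[i] over the coefficient table and returns 0.5*(b0 - b2). A scalar sketch of that recurrence, assuming the standard Cephes formulation and using a made-up coefficient table:

```cpp
// Clenshaw-style evaluation as in Cephes chbevl; the real coefficient tables
// are the A[]/B[] arrays referenced in the surrounding code.
#include <cstddef>
#include <cstdio>

template <typename T>
T cheb_series(T x, const T array[], size_t len) {
  T b0 = array[0];
  T b1 = T(0);
  T b2 = T(0);
  for (size_t i = 1; i < len; ++i) {
    b2 = b1;
    b1 = b0;
    b0 = x * b1 - b2 + array[i];
  }
  return T(0.5) * (b0 - b2);
}

int main() {
  const double c[] = {2.0, 3.0};                // placeholder coefficients
  std::printf("%f\n", cheb_series(0.5, c, 2));  // 0.5 * ((0.5*2 + 3) - 0) = 2.0
}
```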
* @@ -1361,7 +1361,7 @@ static inline std::tuple chebyshev_coefficients_i0e_B() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1388,7 +1388,7 @@ chebyshev_coefficients_i1e_A() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1417,7 +1417,7 @@ chebyshev_coefficients_i1e_A() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1443,7 +1443,7 @@ chebyshev_coefficients_i1e_B() { }; template -static inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if::value, std::tuple>::type chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1463,7 +1463,7 @@ chebyshev_coefficients_i1e_B() { }; template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i0(T _x) { T x = std::abs(_x); @@ -1481,7 +1481,7 @@ calc_i0(T _x) { } // Upcast bfloat16 input to float for numerical accuracy purposes -static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } /* * This function is derived from the implementation of the i1 function in the Cephes Math Library. @@ -1493,7 +1493,7 @@ static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cas * of all inputs to convert them into the domain of the approximation. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i1(T _x) { T x = std::abs(_x); @@ -1522,7 +1522,7 @@ calc_i1(T _x) { * of all inputs to convert them into the domain of the approximation. */ template -static inline typename std::enable_if::value, T>::type +inline typename std::enable_if::value, T>::type calc_i1e(T _x) { T x = std::abs(_x); @@ -1549,7 +1549,7 @@ calc_i1e(T _x) { * (integrated from minus infinity to x) is equal to y. */ template -static inline C10_HOST_DEVICE T calc_ndtri(T y0) { +inline C10_HOST_DEVICE T calc_ndtri(T y0) { /* sqrt(2pi) */ constexpr T s2pi = 2.50662827463100050242E0; @@ -1737,7 +1737,7 @@ static inline C10_HOST_DEVICE T calc_ndtri(T y0) { template -C10_HOST_DEVICE static inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type erfcx_y100(T y100) { switch (static_cast(y100)) { @@ -2148,7 +2148,7 @@ return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682 } template -C10_HOST_DEVICE static inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type calc_erfcx(T x) { if (at::_isnan(x)) { @@ -2188,7 +2188,7 @@ calc_erfcx(T x) * See NOTICE for the licenses. 
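Per the comments above, `calc_i0`, `calc_i1`, and `calc_i1e` follow the usual Cephes two-interval scheme: one Chebyshev expansion (table A) covers |x| <= 8 after mapping x to x/2 - 2, and a second expansion (table B) covers the inverted interval [8, inf) after mapping x to 32/x - 2, with a 1/sqrt(x) factor. The sketch below shows only that control flow under those assumptions; the series evaluator and the coefficient tables are stubbed out:

```cpp
#include <cmath>
#include <cstdio>

// Stand-in for chbevl over the real A/B tables; returns a dummy value.
static double fake_series(double /*y*/) { return 1.0; }

double bessel_i0_like(double x_in) {
  double x = std::fabs(x_in);
  if (x <= 8.0) {
    double y = (x / 2.0) - 2.0;                         // map [0,8] onto [-2,2]
    return std::exp(x) * fake_series(y);                // exp(x) * chbevl(y, A, ...)
  }
  double y = (32.0 / x) - 2.0;                          // map [8,inf) onto (-2,2]
  return std::exp(x) * fake_series(y) / std::sqrt(x);   // ... * chbevl(y, B, ...) / sqrt(x)
}

int main() { std::printf("%f\n", bessel_i0_like(1.0)); }
```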
*/ template -static inline C10_HOST_DEVICE T calc_log_ndtr(T x) { +inline C10_HOST_DEVICE T calc_log_ndtr(T x) { T t = x * c10::frac_sqrt_2; if (x < T{-1.0}) { return std::log(calc_erfcx(-t) / 2) - t * t; @@ -2198,7 +2198,7 @@ static inline C10_HOST_DEVICE T calc_log_ndtr(T x) { } template -static inline C10_HOST_DEVICE T airy_ai_forward(T x) { +inline C10_HOST_DEVICE T airy_ai_forward(T x) { static const T AN[] = { +3.46538101525629032477e-01, +1.20075952739645805542e+01, @@ -2377,7 +2377,7 @@ static inline C10_HOST_DEVICE T airy_ai_forward(T x) { } // T airy_ai(T x) template -static inline C10_HOST_DEVICE T bessel_j0_forward(T x) { +inline C10_HOST_DEVICE T bessel_j0_forward(T x) { static const T PP[] = { +7.96936729297347051624e-04, +8.28352392107440799803e-02, @@ -2489,7 +2489,7 @@ static inline C10_HOST_DEVICE T bessel_j0_forward(T x) { } // bessel_j0_forward(T x) template -static inline C10_HOST_DEVICE T bessel_j1_forward(T x) { +inline C10_HOST_DEVICE T bessel_j1_forward(T x) { static const T PP[] = { +7.62125616208173112003e-04, +7.31397056940917570436e-02, @@ -2597,7 +2597,7 @@ static inline C10_HOST_DEVICE T bessel_j1_forward(T x) { } // bessel_j1_forward(T x) template -static inline C10_HOST_DEVICE T bessel_y0_forward(T x) { +inline C10_HOST_DEVICE T bessel_y0_forward(T x) { static const T PP[] = { +7.96936729297347051624e-04, +8.28352392107440799803e-02, @@ -2712,7 +2712,7 @@ static inline C10_HOST_DEVICE T bessel_y0_forward(T x) { } // bessel_y0_forward(T x) template -static inline C10_HOST_DEVICE T bessel_y1_forward(T x) { +inline C10_HOST_DEVICE T bessel_y1_forward(T x) { static const T PP[] = { +7.62125616208173112003e-04, +7.31397056940917570436e-02, @@ -2826,7 +2826,7 @@ static inline C10_HOST_DEVICE T bessel_y1_forward(T x) { } // bessel_y1_forward(T x) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2865,12 +2865,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { } // chebyshev_polynomial_t_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { return chebyshev_polynomial_t_forward(x, static_cast(n)); } // chebyshev_polynomial_t_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2913,12 +2913,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { } // chebyshev_polynomial_u_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { return chebyshev_polynomial_u_forward(x, static_cast(n)); } // chebyshev_polynomial_u_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -2969,12 +2969,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { } // chebyshev_polynomial_v_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { return 
chebyshev_polynomial_v_forward(x, static_cast(n)); } // chebyshev_polynomial_v_forward(T x, T n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3029,12 +3029,12 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { } // chebyshev_polynomial_w_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { return chebyshev_polynomial_w_forward(x, static_cast(n)); } // chebyshev_polynomial_w_forward(T x, T n) template -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3061,17 +3061,17 @@ static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { } // hermite_polynomial_h_forward(T x, int64_t n) template::value, int> = 0> -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { return hermite_polynomial_h_forward(x, static_cast(n)); } // hermite_polynomial_h_forward(T x, T n) template::value, int> = 0> -static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast(n) : static_cast(-1)); } // hermite_polynomial_h_forward(T x, T n) template -static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3098,12 +3098,12 @@ static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { } // hermite_polynomial_he_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { return hermite_polynomial_he_forward(x, static_cast(n)); } // hermite_polynomial_he_forward(T x, T n) template -static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3134,12 +3134,12 @@ static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { } // laguerre_polynomial_l_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { return laguerre_polynomial_l_forward(x, static_cast(n)); } // laguerre_polynomial_l_forward(T x, T n) template -static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3174,12 +3174,12 @@ static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { } // legendre_polynomial_p_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { return legendre_polynomial_p_forward(x, static_cast(n)); } // legendre_polynomial_p_forward(T x, T n) template -static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { +inline 
C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { static const T A[] = { -4.41534164647933937950e-18, +3.33079451882223809783e-17, @@ -3268,7 +3268,7 @@ static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { } // modified_bessel_i0_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { static const T A[] = { +2.77791411276104639959e-18, -2.11142121435816608115e-17, @@ -3364,7 +3364,7 @@ static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { } // modified_bessel_i1_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { static const T A[] = { +1.37446543561352307156e-16, +4.25981614279661018399e-14, @@ -3441,7 +3441,7 @@ static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { } // modified_bessel_k0_forward(T x) template -static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { +inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { static const T A[] = { -7.02386347938628759343e-18, -2.42744985051936593393e-15, @@ -3519,7 +3519,7 @@ static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { } // modified_bessel_k1_forward(T x) template -static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { +inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { static const T A[] = { +1.37446543561352307156e-16, +4.25981614279661018399e-14, @@ -3596,7 +3596,7 @@ static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { } // T scaled_modified_bessel_k0_forward(T x) template -static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { +inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { static const T A[] = { -7.02386347938628759343e-18, -2.42744985051936593393e-15, @@ -3674,7 +3674,7 @@ static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { } // T scaled_modified_bessel_k1_forward(T x) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3717,12 +3717,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int6 } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_t_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3769,12 +3769,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int6 } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_u_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { if (n < 0) 
{ return T(0.0); } @@ -3825,12 +3825,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int6 } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_v_forward(T x, T n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { if (n < 0) { return T(0.0); } @@ -3881,12 +3881,12 @@ static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int6 } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) template -static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); } // shifted_chebyshev_polynomial_w_forward(T x, T n) template -static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { +inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { if (std::isinf(x)) { return T(0.0); } diff --git a/aten/src/ATen/native/MaxPooling.h b/aten/src/ATen/native/MaxPooling.h index 3c6760ca6886..7044b6ee3dc2 100644 --- a/aten/src/ATen/native/MaxPooling.h +++ b/aten/src/ATen/native/MaxPooling.h @@ -7,7 +7,7 @@ namespace at::native { -static void check_max_pool1d( +inline void check_max_pool1d( const Tensor& self, IntArrayRef kernel_size, IntArrayRef stride, diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index 518466df84ce..302a3f45bdf4 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -28,18 +28,6 @@ Tensor empty_meta_symint( size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } -// Kept only for BC with XLA -static Tensor empty_strided_meta( - IntArrayRef size, - IntArrayRef stride, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt -) { - return empty_strided_meta_symint(c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype_opt, layout_opt, device_opt, pin_memory_opt); -} - Tensor empty_strided_meta_symint( SymIntArrayRef size, SymIntArrayRef stride, diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp index fbac5d4cc72c..7da1ec9b1998 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp @@ -802,55 +802,6 @@ TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cpu) dilation); } -static std::tuple slow_conv_transpose2d_backward_out_cpu(const Tensor& grad_output, - const Tensor& input, - const Tensor& weight, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef output_padding, - IntArrayRef dilation, - Tensor& grad_input, - Tensor& grad_weight, - Tensor& grad_bias) { - if (grad_input.defined()) { - slow_conv_transpose2d_backward_out_cpu_template( - input, - grad_output, - grad_input, - weight, - kernel_size, - stride, - padding, - output_padding, - dilation); - } - - if (grad_bias.defined()) { - at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3}); - } - - if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes(), 
weight.suggest_memory_format()); - grad_weight.zero_(); - slow_conv_transpose2d_acc_grad_parameters_cpu( - input, - weight, - grad_output, - grad_weight, - grad_bias, - kernel_size, - stride, - padding, - output_padding, - dilation, - 1); - } - - return std::tuple( - grad_input, grad_weight, grad_bias); -} - static std::tuple slow_conv_transpose2d_backward_cpu( const Tensor& grad_output, const Tensor& input, diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index f82354ace3b8..9ef236d4dab9 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -871,58 +871,6 @@ Tensor slow_conv_transpose3d_cpu( return output; } -static std::tuple slow_conv_transpose3d_backward_out_cpu(const Tensor& grad_output, - const Tensor& input, - const Tensor& weight, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef output_padding, - IntArrayRef dilation, - Tensor& grad_input, - Tensor& grad_weight, - Tensor& grad_bias) { - if (grad_input.defined()) { - slow_conv_transpose3d_backward_out_cpu_template( - input, - grad_output, - grad_input, - weight, - kernel_size, - stride, - padding, - output_padding, - dilation); - } - - if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes()); - grad_weight.zero_(); - } - - if (grad_bias.defined()) { - grad_bias.resize_({weight.size(1)}); - grad_bias.zero_(); - } - - if (grad_weight.defined() || grad_bias.defined()) { - slow_conv_transpose3d_acc_grad_parameters_cpu( - input, - grad_output, - grad_weight, - grad_bias, - kernel_size, - stride, - padding, - output_padding, - dilation, - 1); - } - - return std::tuple( - grad_input, grad_weight, grad_bias); -} - static std::tuple slow_conv_transpose3d_backward_cpu( const Tensor& grad_output, const Tensor& input, diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index 709d63bae636..70fb94cc6f45 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -339,12 +339,6 @@ Tensor& gather_out(const Tensor& self, Dimname dim, const Tensor& index, bool sp Tensor index_add(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) { reportNYIDimnameOverload("index_add"); } -static Tensor& index_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar &alpha) { - reportNYIDimnameOverload("index_add"); -} -static Tensor& index_add_out(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source, const Scalar& alpha, Tensor& result) { - reportNYIDimnameOverload("index_add"); -} Tensor index_fill(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { return at::index_fill(self, dimname_to_position(self, dim), index, source); } @@ -372,21 +366,12 @@ Tensor index_select(const Tensor& self, Dimname dim, const Tensor& index) { Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { reportNYIDimnameOverload("scatter"); } -static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { - reportNYIDimnameOverload("scatter"); -} Tensor scatter(const Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { reportNYIDimnameOverload("scatter"); } -static Tensor& scatter_(Tensor& self, Dimname dim, const Tensor& index, const Scalar& source) { - reportNYIDimnameOverload("scatter"); -} Tensor scatter_add(const 
Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { reportNYIDimnameOverload("scatter_add"); } -static Tensor& scatter_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) { - reportNYIDimnameOverload("scatter_add"); -} std::tuple sort_out(const Tensor& self, std::optional stable, Dimname dim, bool keepdim, Tensor& values, Tensor& indices) { reportNYIDimnameOverload("sort"); } diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index 97c35599f791..ffd19b2e93a9 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -43,7 +43,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { // non-empty tensor if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS && - self.device().type() != at::kPrivateUse1) { + self.device().type() != at::kPrivateUse1 && self.device().type() != at::kXLA) { // for cuda, rely on device assert thrown by scatter TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative."); } @@ -51,7 +51,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { num_classes = self.max().item().toLong() + 1; } else { if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS && - self.device().type() != at::kPrivateUse1) { + self.device().type() != at::kPrivateUse1 && self.device().type() != at::kXLA) { // rely on device asserts from scatter to avoid sync here TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); } else { diff --git a/aten/src/ATen/native/Padding.h b/aten/src/ATen/native/Padding.h index 083436134282..53a054027f33 100644 --- a/aten/src/ATen/native/Padding.h +++ b/aten/src/ATen/native/Padding.h @@ -26,7 +26,7 @@ DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel); namespace padding { template -static inline void check_valid_input(const Tensor& input, IntArrayRef padding) { +inline void check_valid_input(const Tensor& input, IntArrayRef padding) { TORCH_CHECK(padding.size() == 2 * dim, "padding size is expected to be ", 2 * dim, diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index df73299ea230..df677019e897 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -48,7 +48,7 @@ DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel); namespace { template -static inline dest_t +inline dest_t safe_downcast(src_t v) { TORCH_CHECK(std::numeric_limits::min() <= v && v <= std::numeric_limits::max(), @@ -58,7 +58,7 @@ safe_downcast(src_t v) } template -static inline T pooling_output_shape_pad_lr( +inline T pooling_output_shape_pad_lr( T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation, bool ceil_mode) { T outputSize = div_rtn( @@ -75,7 +75,7 @@ static inline T pooling_output_shape_pad_lr( } template -static inline T pooling_output_shape( +inline T pooling_output_shape( T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) { TORCH_CHECK(stride != 0, "stride should not be zero"); TORCH_CHECK(pad >= 0, @@ -117,7 +117,7 @@ inline std::pair pooling_same_mode_padding_lr( } // AveragePool2d/DilatedMaxPool2d (forward) -static inline void +inline void pool2d_shape_check( const Tensor& input, int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, @@ -164,7 +164,7 @@ pool2d_shape_check( } // DilatedMaxPool2d (backward) -static inline void +inline void max_pool2d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -192,7 +192,7 @@ 
max_pool2d_backward_shape_check( } // AveragePool2d (backward) -static inline void +inline void avg_pool2d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -218,7 +218,7 @@ avg_pool2d_backward_shape_check( } // AveragePool3d/DilatedMaxPool3d (forward) -static inline void +inline void pool3d_shape_check( const Tensor& input, int64_t nslices, @@ -280,7 +280,7 @@ pool3d_shape_check( "Output size is too small"); } -static inline void +inline void max_pool3d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -317,7 +317,7 @@ max_pool3d_backward_shape_check( check_dim_size(indices, ndim, ndim-1, owidth); } -static inline void +inline void avg_pool3d_backward_shape_check( const Tensor& input, const Tensor& gradOutput, diff --git a/aten/src/ATen/native/Pow.h b/aten/src/ATen/native/Pow.h index 068482ee300c..76ddda846a59 100644 --- a/aten/src/ATen/native/Pow.h +++ b/aten/src/ATen/native/Pow.h @@ -24,7 +24,7 @@ namespace native { // only non-zero result. template ::value, T>::type* = nullptr> -static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { +inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { T result = 1; while (b) { if (b & 1) { @@ -38,13 +38,13 @@ static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, template ::value && !std::is_signed::value, T>::type* = nullptr> -static inline HOST_DEVICE T powi(T a, T b) { +inline HOST_DEVICE T powi(T a, T b) { return powi_impl(a, b); } template ::value && std::is_signed::value, T>::type* = nullptr> -static inline HOST_DEVICE T powi(T a, T b) { +inline HOST_DEVICE T powi(T a, T b) { if ( b < 0 ) { if ( a == 1 ) { return 1; diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 2e870dc83ee1..4718e824fad8 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -2276,11 +2276,6 @@ bool cpu_equal(const Tensor& self, const Tensor& other) { return result.load(); } -static Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, at::IntArrayRef sizes, bool keepdim) { - return at::native::value_selecting_reduction_backward_symint(grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim); -} - - // max(dim), min(dim), topk(dim), mode(dim), are examples of reduction // functions that select values. value_selecting_reduction_backward is the // backward function for those operators; it propagates the grad to the diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 505cf3bb3a77..cfb4776fa846 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -31,7 +31,7 @@ constexpr scalar_t lower_bound() { return lim::has_infinity ? 
-lim::infinity() : lim::lowest(); } -static inline Tensor restride_dim( +inline Tensor restride_dim( const Tensor& src, int64_t dim, IntArrayRef replacement_shape ) { @@ -96,13 +96,13 @@ inline std::optional _allreduce_return_trivial( " but found ", out.option())\ } -static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { +inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self); OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options()); OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); } -static inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { +inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { ScalarType scalarType = self.scalar_type(); TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented"); ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType); @@ -111,7 +111,7 @@ static inline Tensor integer_upcast(const Tensor& self, std::optional get_zero_numel_tensor_size( +inline std::vector get_zero_numel_tensor_size( const Tensor& self, const int64_t dim, const bool keepdim, @@ -313,7 +313,7 @@ static std::vector get_zero_numel_tensor_size( // This function should be called when you are reducing a zero-numel tensor and want to // resize the output and return it. This function exists for resizing zero-numel // tensors when the size of the reduction dimension is non-zero. -static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, +inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, const Tensor& self, const int64_t dim, const bool keepdim, const char *fn_name) { auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name); @@ -349,7 +349,7 @@ inline ScalarType get_dtype_from_result(Tensor& result, std::optional const Tensor& _resize_( const Tensor& self, diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 0a1f21298957..40d3bce80f42 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -40,7 +40,7 @@ TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes); TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes); TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes); -static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) { +inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) { // It does not make sense to try to resize a storage // to hold 0 elements, and this can break // if storage_offset is positive but @@ -79,7 +79,7 @@ template <> inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); } template -static inline void checkInBoundsForStorage( +inline void checkInBoundsForStorage( ArrayRef size, ArrayRef stride, T storage_offset, @@ -111,7 +111,7 @@ static inline void checkInBoundsForStorage( } template -static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, +inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, ArrayRef size, ArrayRef stride) { // FIXME: stride should be optional if (stride.data()) { diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index fa7be5a698e9..aa9173154a14 
100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -440,15 +440,6 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) ( } } -static Tensor softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; @@ -505,15 +496,6 @@ Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 5f9ff1b83822..db8acf193199 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1195,15 +1195,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho #undef REPR } -static Tensor istft(const Tensor& self, const int64_t n_fft, const optional hop_lengthOpt, - const optional win_lengthOpt, const Tensor& window, - const bool center, const bool normalized, const optional onesidedOpt, - const optional lengthOpt) { - return at::native::istft( - self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, - onesidedOpt, lengthOpt, /*return_complex=*/false); -} - void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { const auto input_sizes = input.sizes(); const auto input_strides = input.strides(); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 6f132a6ea814..6d6db1477f1f 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -792,12 +792,6 @@ std::tuple max(const Tensor& self, Dimname dim, bool keepdim) { std::tuple max_out(const Tensor& self, Dimname dim, bool keepdim, Tensor& max, Tensor& max_indices) { return at::max_out(max, max_indices, self, dimname_to_position(self, dim), keepdim); } -static Tensor argmax(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { - reportNYIDimnameOverload("argmax"); -} -static Tensor argmin(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { - reportNYIDimnameOverload("argmin"); -} Tensor argsort(const Tensor& /*self*/, Dimname /*dim*/, bool /*keepdim*/) { reportNYIDimnameOverload("argsort"); } diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 5e7c9cf8a5f8..55961d9e0be9 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -172,18 +172,10 @@ Tensor arange( return at::arange_out(result, start, end, step); } -static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) { - return at::arange_out(result, start, end, /*step=*/1); -} - Tensor& arange_out(const Scalar& end, Tensor& result) { return at::arange_out(result, /*start=*/0, end, /*step=*/1); } -static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) { - return at::arange_out(result, start, end, /*step=*/1); -} - Tensor _dim_arange(const Tensor& like, int64_t dim) { return at::arange(like.size(dim), like.options().dtype(at::kLong)); } diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 899cf68a7a5a..95c88f4572cb 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -105,10 +105,6 @@ Tensor & 
detach_(Tensor & self) { return self; } -static Tensor contiguous(const Tensor & self) { - return contiguous(self, MemoryFormat::Contiguous); -} - Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { if (self.is_contiguous(memory_format)) { return self; diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index bdab4ce24551..adcddead041b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -210,7 +210,6 @@ #include #endif -#include #include #include #include @@ -1181,14 +1180,6 @@ Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef s return result; } -static Tensor as_strided_tensorimpl_meta(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { - auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto result = at::detail::make_tensor( - c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); - setStrided(result, size, stride, storage_offset); - return result; -} - template inline void setStridedUnchecked( const Tensor& self, @@ -1249,10 +1240,6 @@ const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymInt return self; } -static Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) { - return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); -} - // Should just use narrow_copy_out, but this API is used internally at Meta: // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ @@ -3587,10 +3574,6 @@ Tensor view_as(const Tensor& self, const Tensor& other) { return self.view_symint(other.sym_sizes()); } -static int64_t numel(const Tensor& self) { - return self.unsafeGetTensorImpl()->numel(); -} - std::vector unbind(const Tensor &self, int64_t dim) { dim = maybe_wrap_dim(dim, self.dim()); int64_t size = self.size(dim); diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp index 9cb75a0eccf4..62440b956c80 100644 --- a/aten/src/ATen/native/TriangularOps.cpp +++ b/aten/src/ATen/native/TriangularOps.cpp @@ -180,10 +180,6 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) { compute_triu_tril(self, k, result); } -static Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) { - return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes)); -} - Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 4afc7619c2eb..6e694109a21f 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -24,10 +24,6 @@ namespace at::native { -static bool is_cuda(const Tensor& self) { - return self.is_cuda(); -} - bool is_distributed(const Tensor& self) { return false; } @@ -60,18 +56,6 @@ bool is_neg(const Tensor& self) { return self.is_neg(); } -static bool is_sparse(const Tensor& self) { - return self.is_sparse(); -} - -static bool is_sparse_csr(const Tensor& self) { - return self.is_sparse_csr(); -} - -static bool is_quantized(const Tensor& self) { - return self.is_quantized(); -} - // True if `self` and `from` have compatible tensor type so that `from`'s // TensorImpl can be copied to `self`. 
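The deletions in these .cpp files all target the same kind of function: a file-local (`static`) wrapper that either forwarded a concrete-int overload to its `*_symint` counterpart or reported a not-yet-implemented Dimname overload, and that nothing references any longer, so it is dead static code. A simplified illustration of that wrapper shape (names and types are invented, not real ATen signatures):

```cpp
#include <cstdint>
#include <vector>

using SymInt = int64_t;  // stand-in for c10::SymInt in this sketch

static std::vector<SymInt> to_symint(const std::vector<int64_t>& sizes) {
  return std::vector<SymInt>(sizes.begin(), sizes.end());
}

static double trace_backward_symint_like(double grad, const std::vector<SymInt>& sizes) {
  return grad * static_cast<double>(sizes.size());  // placeholder body
}

// The forwarding-only wrapper pattern this patch removes: it merely converts
// concrete sizes to SymInt and calls the symint variant.
static double trace_backward_like(double grad, const std::vector<int64_t>& sizes) {
  return trace_backward_symint_like(grad, to_symint(sizes));
}

int main() {
  return trace_backward_like(1.0, {2, 2}) == 2.0 ? 0 : 1;
}
```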
bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) { diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index 6f063a0dc2fb..9542d9953ed6 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -103,7 +103,7 @@ DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel); -static C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +inline C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 1, "It is expected output_size equals to 1, but got size ", @@ -131,7 +131,7 @@ static C10_UNUSED std::array upsample_1d_common_check(IntArrayRef in return {nbatch, channels, output_width}; } -static C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +inline C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 2, "It is expected output_size equals to 2, but got size ", @@ -167,7 +167,7 @@ static C10_UNUSED std::array upsample_2d_common_check(IntArrayRef in return {nbatch, channels, output_height, output_width}; } -static C10_UNUSED +inline C10_UNUSED std::array upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 3, @@ -210,7 +210,7 @@ std::array upsample_3d_common_check(IntArrayRef input_size, IntArray return {nbatch, channels, output_depth, output_height, output_width}; } -static inline void upsample_2d_shape_check( +inline void upsample_2d_shape_check( const Tensor& input, const Tensor& grad_output, int64_t nbatch, @@ -251,7 +251,7 @@ static inline void upsample_2d_shape_check( } template -static inline scalar_t compute_scales_value( +inline scalar_t compute_scales_value( const std::optional scale, int64_t input_size, int64_t output_size) { @@ -263,7 +263,7 @@ static inline scalar_t compute_scales_value( } template -static inline scalar_t area_pixel_compute_scale( +inline scalar_t area_pixel_compute_scale( int64_t input_size, int64_t output_size, bool align_corners, @@ -281,7 +281,7 @@ static inline scalar_t area_pixel_compute_scale( } template -static inline scalar_t area_pixel_compute_source_index( +inline scalar_t area_pixel_compute_source_index( scalar_t scale, int64_t dst_index, bool align_corners, @@ -308,7 +308,7 @@ static inline scalar_t area_pixel_compute_source_index( } } -static inline int64_t nearest_neighbor_compute_source_index( +inline int64_t nearest_neighbor_compute_source_index( const float scale, int64_t dst_index, int64_t input_size) { @@ -319,7 +319,7 @@ static inline int64_t nearest_neighbor_compute_source_index( return src_index; } -static inline int64_t nearest_neighbor_exact_compute_source_index( +inline int64_t nearest_neighbor_exact_compute_source_index( const float scale, int64_t dst_index, int64_t input_size) { @@ -331,7 +331,7 @@ static inline int64_t nearest_neighbor_exact_compute_source_index( return src_index; } -static inline int64_t nearest_idx( +inline int64_t nearest_idx( int64_t output_index, int64_t input_size, int64_t output_size, @@ -352,7 +352,7 @@ static inline int64_t nearest_idx( } } -static inline int64_t nearest_exact_idx( +inline int64_t nearest_exact_idx( int64_t output_index, int64_t input_size, 
int64_t output_size, @@ -365,7 +365,7 @@ static inline int64_t nearest_exact_idx( typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional); template -static scalar_t upsample_get_value_bounded( +scalar_t upsample_get_value_bounded( scalar_t* data, int64_t width, int64_t height, @@ -377,7 +377,7 @@ static scalar_t upsample_get_value_bounded( } template -static void upsample_increment_value_bounded( +void upsample_increment_value_bounded( scalar_t* data, int64_t width, int64_t height, @@ -392,17 +392,17 @@ static void upsample_increment_value_bounded( // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm template -static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) { +scalar_t cubic_convolution1(scalar_t x, scalar_t A) { return ((A + 2) * x - (A + 3)) * x * x + 1; } template -static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) { +scalar_t cubic_convolution2(scalar_t x, scalar_t A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } template -static inline void get_cubic_upsample_coefficients( +void get_cubic_upsample_coefficients( scalar_t coeffs[4], scalar_t t) { scalar_t A = -0.75; @@ -418,7 +418,7 @@ static inline void get_cubic_upsample_coefficients( } template -static inline scalar_t cubic_interp1d( +inline scalar_t cubic_interp1d( scalar_t x0, scalar_t x1, scalar_t x2, @@ -434,7 +434,7 @@ static inline scalar_t cubic_interp1d( // type can accurately represent, the type casting to `int64_t` might exceed // `input_size`, causing overflow. So we guard it with `std::min` below. template -static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { +inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { input_index = std::min(static_cast(floorf(real_input_index)), input_size - 1); lambda = std::min( std::max(real_input_index - input_index, static_cast(0)), @@ -443,7 +443,7 @@ static inline void guard_index_and_lambda(const opmath_t& real_input_index, cons } template -static inline void compute_source_index_and_lambda( +inline void compute_source_index_and_lambda( int64_t& input_index0, int64_t& input_index1, scalar_t& lambda0, diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 587809ea57c8..b664cdf262ad 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -33,6 +33,16 @@ void fp16_gemv_trans( const float beta, float16_t* y, const int incy); + +float fp16_dot_with_fp32_arith( + const float16_t* x, + const float16_t* a, + int64_t len); + +float bf16_dot_with_fp32_arith( + const at::BFloat16* x, + const at::BFloat16* a, + int64_t len); } #endif @@ -308,31 +318,21 @@ void gemm_notrans_( } -inline float32x4_t load_as_float32x4(const Half* ptr) { - return vcvt_f32_f16(vld1_f16(reinterpret_cast(ptr))); -} - inline float32x4_t load_as_float32x4(const BFloat16* ptr) { int32x4_t shift = vdupq_n_s32(16); uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast(ptr))); return vreinterpretq_f32_u32(vshlq_u32(as_int, shift)); } -template -static float compute_dot(const T* a, const T* b, int64_t l) { - if ((l&3) != 0) { - return sum(l, [&](int64_t i) -> float { - return float(a[i]) * float(b[i]); - }); - } - float32x4_t rcv = vdupq_n_f32(0); - for (int64_t idx = 0; idx < l; idx += 4) { - float32x4_t aVec = load_as_float32x4(a + idx); - float32x4_t bVec 
= load_as_float32x4(b + idx); - rcv = vaddq_f32(rcv, vmulq_f32(aVec, bVec)); - } - auto sum = vpaddq_f32(rcv, rcv); - return vgetq_lane_f32(vpaddq_f32(sum, sum), 0); +static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) { + return at::native::blas_impl::fp16_dot_with_fp32_arith( + reinterpret_cast(a), + reinterpret_cast(b), + len); +} + +static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) { + return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len); } template <> diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 08c3bbe43500..7d87df45c1c5 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -82,7 +82,7 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o template ::result_type>::value>::type* = nullptr> -static inline void +inline void execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; using result_type = typename traits::result_type; @@ -97,7 +97,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t template ::result_type>::value>::type* = nullptr> -static inline void +inline void execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; for (; i < n; i++) { @@ -111,7 +111,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t // Basic loop operation (one output, N inputs). May be auto-vectorized // by the compiler. Supports inputs and outputs of different types. template -static inline void +inline void basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; constexpr int ntensors = traits::arity + 1; @@ -166,7 +166,7 @@ void handle_tuple_outputs(char* C10_RESTRICT data[], // 2. Iterate over the members of the returned tuple, set the corresponding // output tensor by the tuple member in `handle_tuple_outputs` function. template -static inline void +inline void multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; @@ -195,7 +195,7 @@ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_ // a scalar (stride 0). It's position is indicated by the argument `S`. If `S` // is 0, then there are no scalar inputs. 
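The `basic_loop` / `vectorized_loop` helpers above walk raw byte pointers with a per-operand stride, and a stride of 0 turns an operand into a broadcast scalar (the `S` position mentioned in the comment). A simplified, scalar-only sketch of that pattern, with invented names and a fixed two-input float op:

```cpp
#include <cstdint>
#include <cstdio>

template <typename func_t>
void basic_loop_like(char* data[3], const int64_t strides[3], int64_t n, func_t op) {
  // data[0]/strides[0] describe the output; data[1..2]/strides[1..2] the inputs.
  for (int64_t i = 0; i < n; ++i) {
    float a = *reinterpret_cast<float*>(data[1] + i * strides[1]);
    float b = *reinterpret_cast<float*>(data[2] + i * strides[2]);
    *reinterpret_cast<float*>(data[0] + i * strides[0]) = op(a, b);
  }
}

int main() {
  float out[4], a[4] = {1, 2, 3, 4}, scalar = 10.0f;
  char* data[3] = {reinterpret_cast<char*>(out), reinterpret_cast<char*>(a),
                   reinterpret_cast<char*>(&scalar)};
  int64_t strides[3] = {sizeof(float), sizeof(float), 0};  // stride 0 == broadcast scalar
  basic_loop_like(data, strides, 4, [](float x, float y) { return x + y; });
  std::printf("%.0f %.0f %.0f %.0f\n", out[0], out[1], out[2], out[3]);  // 11 12 13 14
}
```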
template -static inline void +inline void vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) { using traits = function_traits; using scalar_t = typename function_traits::result_type; @@ -228,7 +228,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve template -static inline void unroll_contiguous_scalar_checks( +inline void unroll_contiguous_scalar_checks( const int64_t* /*strides*/, std::index_sequence<>, cb_t&& cb) { @@ -236,7 +236,7 @@ static inline void unroll_contiguous_scalar_checks( } template -static inline void unroll_contiguous_scalar_checks( +inline void unroll_contiguous_scalar_checks( const int64_t* strides, std::index_sequence, cb_t&& cb) { diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 26155373be58..37bd32d1c4c1 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -21,21 +21,21 @@ using namespace vec; // reduction that is contiguous over the input in dim 0 template -static inline bool is_contiguous_reduction(const int64_t* strides) { +inline bool is_contiguous_reduction(const int64_t* strides) { return strides[0] == 0 && strides[1] == sizeof(typename traits::arg2_t); } // reduction that is contiguous over the input in dim 1 template -static inline bool is_outer_reduction(const int64_t* strides) { +inline bool is_outer_reduction(const int64_t* strides) { return strides[0] == 0 && strides[2] == sizeof(typename traits::result_type) && strides[3] == sizeof(typename traits::arg2_t); } template -static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, +inline void vectorized_reduction(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) { VEC_LOOP_HEADER(func_t, data) const char* in1_ptr = data[1]; @@ -69,7 +69,7 @@ static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, } template -static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { +inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { for (const auto j C10_UNUSED : c10::irange(n)) { f(); data[0] += strides[0]; @@ -79,7 +79,7 @@ static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int // computes the reduction out = op(out, in) template -static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { +inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t count = n / (4 * Vec::size()); @@ -93,7 +93,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, // computes the reduction out = op(out, in) template -static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { +inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) @@ -132,13 +132,13 @@ static void set_results(const res_t result, const TensorIteratorBase &iter, cons } template -static inline typename std::enable_if::type +inline typename std::enable_if::type for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { return i; } template -static 
inline typename std::enable_if::type +inline typename std::enable_if::type for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { if (i < (size_t)num_outputs) { set_result(i, std::get(t), iter, num_outputs); @@ -286,7 +286,7 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vo // when reduction is on most inner dimension (dim 0 in TensorIterator) // and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim` // can be used. -static inline bool is_reduce_lastdim(TensorIteratorBase& iter) { +inline bool is_reduce_lastdim(TensorIteratorBase& iter) { return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0) && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1); } diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index bcfc26c7df7d..95119b5ac085 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -655,7 +655,7 @@ std::pair radix_sort_parallel( const int64_t elements_count, const int64_t max_value) { TORCH_INTERNAL_ASSERT(false, "radix_sort_parallel: ATen not compiled with FBGEMM support"); - std::make_pair(nullptr, nullptr); + return std::make_pair(nullptr, nullptr); } } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 84c59a4fd0d7..f7997fe72712 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -175,12 +175,6 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa static bool getDisableAddmmCudaLt() { static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); #ifdef USE_ROCM - // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) - // note the current tunable op is not the hipblaslt path (gemm_and_bias) - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - return true; - } // allow both CUDA and HIP env var names for ROCm builds // also, current default for ROCm builds is disable by default if (env_value == nullptr) { @@ -214,6 +208,49 @@ static bool isSupportedHipLtROCmArch(int index) { } #endif +template +static void launchTunableGemmAndBias(cublasCommonArgs &args, Tensor& result, const Tensor& self, bool is_rocm) { + bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); + bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); + at::cuda::tunable::GemmAndBiasParams params; + params.transa = args.transa; + params.transb = args.transb; + params.m = args.m; + params.n = args.n; + params.k = args.k; + params.a = args.mata->const_data_ptr(); + params.lda = args.lda; + params.b = args.matb->const_data_ptr(); + params.ldb = args.ldb; + if (is_rocm) { + params.bias = (&result != &self) ? 
self.const_data_ptr() : nullptr; + } + else { + params.bias = self.const_data_ptr(); + } + params.c = args.result->data_ptr(); + params.ldc = args.result_ld; + if (transa_ && transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (transa_ && !transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && !transb_) { + static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; + gemm(¶ms); + } + else { + TORCH_CHECK(false, "unreachable"); + } +} + Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call @@ -341,6 +378,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + launchTunableGemmAndBias(args, result, self, true); + } + else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -359,7 +401,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_to_gemm_and_blas_arg(activation) ); - }); + }}); #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -377,6 +419,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + launchTunableGemmAndBias(args, result, self, false); + } + else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -393,7 +440,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_epilogue ); - }); + }}); #endif } else { diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index cf7e40115fef..533aa38c04cf 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -1,9 +1,11 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -252,20 +254,156 @@ FOREACH_BINARY_OP_LIST( power_functor, /*division_op*/ true); -template -struct Identity { - __device__ __forceinline__ T operator()(const T& x) { - return x; +template +struct Copy { + __device__ __forceinline__ dst_t operator()(const src_t& x) { + return static_cast(x); } }; +template +struct Copy> { + __device__ __forceinline__ dst_t operator()(const c10::complex& x) { + if constexpr (!(std::is_same_v> || + std::is_same_v>)) { + return static_cast(x.real()); + } else { + return static_cast(x); + } + } +}; + +template +struct Copy> { + __device__ __forceinline__ dst_t operator()(const c10::complex& x) { + if constexpr (!(std::is_same_v> || + std::is_same_v>)) { + return static_cast(x.real()); + } else { + return static_cast(x); + } + } +}; + +#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) 
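// Simplified stand-in (the enum and functor below are illustrative, not the
// real at::cuda::tunable types): the dispatch pattern used by the
// launchTunableGemmAndBias helper above -- fold the runtime (transa, transb)
// pair into one of four statically instantiated ops, each held as a
// function-local static so its tuning state persists across calls.
#include <cstdio>

enum class Layout { N, T };

template <Layout A, Layout B>
struct GemmLikeOp {
  void operator()(int m, int n, int k) const {
    std::printf("%c%c gemm, m=%d n=%d k=%d\n",
                A == Layout::T ? 'T' : 'N', B == Layout::T ? 'T' : 'N', m, n, k);
  }
};

void launch(bool transa, bool transb, int m, int n, int k) {
  if (transa && transb) {
    static GemmLikeOp<Layout::T, Layout::T> op; op(m, n, k);
  } else if (transa) {
    static GemmLikeOp<Layout::T, Layout::N> op; op(m, n, k);
  } else if (transb) {
    static GemmLikeOp<Layout::N, Layout::T> op; op(m, n, k);
  } else {
    static GemmLikeOp<Layout::N, Layout::N> op; op(m, n, k);
  }
}

int main() { launch(true, false, 128, 64, 32); }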
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Byte, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Char, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Short, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Int, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Double, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, src_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::ComplexDouble, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::ComplexFloat, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, \ + src_t, \ + __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Bool, \ + src_t, \ + __VA_ARGS__)) + +namespace { + +template < + typename T, + typename src_t, + int depth, + int r_args_depth, + int res_arg_index> +struct CopyFunctor { + static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1); + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc]; + src_ptr += chunk_idx * chunk_size; + T* self_ptr = (T*)tl.addresses[1][tensor_loc]; + self_ptr += chunk_idx * chunk_size; + + const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)}; + + n -= chunk_idx * chunk_size; + src_t src_args[kILP]; + T r_args[kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(src_args, src_ptr, 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[ii] = static_cast(op(src_args[ii])); + } + // store + load_store(self_ptr, r_args, i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + src_args[ii] = src_ptr[i]; + } + } +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[ii] = static_cast(op(src_args[ii])); + } + store_args(self_ptr, r_args, i_start, chunk_size, n); + } + } + } +}; + +} // anonymous namespace + void foreach_tensor_copy_list_kernel_cuda_( TensorList self, TensorList src, const bool non_blocking) { check_foreach_api_restrictions(self, src); - if (!can_use_fast_route( - self, src, /* does_op_promote_integer_inputs_to_float */ false)) { + if (!(_check_tensors_share_device_and_dtype( + {self, src}, /* skip_dtype_check */ true) && + std::all_of( + src.cbegin(), + src.cend(), + [&](const auto& t) -> bool { + return t.dtype() == src[0].dtype(); + }) && + _check_tensors_share_sizes_and_strides({self, src}))) { return at::native::foreach_tensor_copy_list_kernel_slow_( self, src, non_blocking); } @@ -280,16 +418,38 @@ void 
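// Much-simplified host-side sketch of the tiling in CopyFunctor above: a fast
// path that assumes whole kILP-element tiles (the kernel additionally requires
// alignment and uses a vectorized load_store), plus a guarded tail path.
// Names and the scalar inner loop are illustrative only.
#include <cstdio>
#include <vector>

constexpr int kILP = 4;

template <typename dst_t, typename src_t>
void tiled_copy(dst_t* dst, const src_t* src, int n) {
  if (n % kILP == 0) {                       // fast path: full tiles only
    for (int i = 0; i < n; i += kILP)
      for (int ii = 0; ii < kILP; ++ii)      // stands in for the unrolled vector op
        dst[i + ii] = static_cast<dst_t>(src[i + ii]);
  } else {                                   // general path with bounds checks
    for (int i = 0; i < n; ++i)
      dst[i] = static_cast<dst_t>(src[i]);
  }
}

int main() {
  std::vector<double> src{1.5, 2.5, 3.5, 4.5, 5.5};
  std::vector<float> dst(src.size());
  tiled_copy(dst.data(), src.data(), static_cast<int>(src.size()));
  std::printf("%.1f %.1f\n", dst[0], dst[4]);  // 1.5 5.5
}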
foreach_tensor_copy_list_kernel_cuda_( "foreach_tensor_copy", [&]() { using opmath_t = at::opmath_type; - multi_tensor_apply<2>( - tensor_lists, - UnaryOpFunctor< - scalar_t, - /* depth */ 2, - /* r_args_depth */ 1, - /* res_arg_index */ 1>(), - Identity()); + AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] { + if constexpr (std::is_same_v) { + multi_tensor_apply<2>( + tensor_lists, + UnaryOpFunctor< + scalar_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); + } else { + // Ref: + // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301 + if (!self[0].is_complex() && src[0].is_complex()) { + TORCH_WARN_ONCE( + "Casting complex values to real discards the imaginary part"); + } + multi_tensor_apply<2>( + tensor_lists, + CopyFunctor< + scalar_t, + src_t, + /* depth */ 2, + /* r_args_depth */ 1, + /* res_arg_index */ 1>(), + Copy()); + } + }); }); increment_version(self); } +#undef AT_DISPATCH_SOURCE_TYPES + } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 885c5d021e8c..7c2a389351a2 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -16,6 +16,7 @@ #include #include #else +#include #include #include @@ -44,6 +45,181 @@ struct TensorListAddresses { const void* addresses[MAX_TENSORS_PER_KERNEL]; }; +template < + typename T, + int depth = 1, + int r_args_depth = 1, + int res_arg_index = 0> +struct LpMaxFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + T* output_per_tensor_ptr, + const int max_chunks_per_tensor) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* x = (T*)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + n -= chunk_idx * chunk_size; + + __shared__ T s_vals[512]; + T vals[kILP]; + T r_x[kILP]; + for (int64_t i = 0; i < kILP; i++) { + vals[i] = T(std::numeric_limits::lowest()); + r_x[i] = T(std::numeric_limits::lowest()); + } + + if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + vals[ii] = max_propagate_nan(vals[ii], r_x[ii]); + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + vals[ii] = max_propagate_nan(vals[ii], x[i]); + } + } + } + } + + auto val = T(std::numeric_limits::lowest()); + for (int i = 0; i < kILP; i++) { + val = max_propagate_nan(val, vals[i]); + } + auto final_val = at::native::cuda_utils::BlockReduceMax(val, s_vals); + + if (threadIdx.x == 0) { + output_per_tensor_ptr + [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + + chunk_idx] = final_val; + } + } +}; + +template +__global__ void lpmax_cleanup( + const T* output_per_tensor, + TensorListAddresses addr_struct, + int max_chunks_per_tensor) { + __shared__ T vals[512]; + const T* output_this_tensor = + output_per_tensor + blockIdx.x * max_chunks_per_tensor; + T val = std::numeric_limits::lowest(); + for (size_t 
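// Host-side analogue (std::complex instead of c10::complex; names are
// illustrative): the casting rule implemented by the Copy specializations
// above -- complex -> complex casts normally, complex -> real keeps only the
// real part, which is why the kernel emits the one-time warning about
// discarding the imaginary part.
#include <complex>
#include <iostream>
#include <type_traits>

template <typename dst_t, typename src_t>
dst_t copy_cast(const src_t& x) {
  if constexpr (std::is_same_v<src_t, std::complex<float>> ||
                std::is_same_v<src_t, std::complex<double>>) {
    if constexpr (std::is_same_v<dst_t, std::complex<float>> ||
                  std::is_same_v<dst_t, std::complex<double>>) {
      return static_cast<dst_t>(x);          // complex -> complex
    } else {
      return static_cast<dst_t>(x.real());   // complex -> real: imaginary part dropped
    }
  } else {
    return static_cast<dst_t>(x);            // real source: plain cast
  }
}

int main() {
  std::complex<double> z{3.0, 4.0};
  std::cout << copy_cast<float>(z) << '\n';                // 3
  std::cout << copy_cast<std::complex<float>>(z) << '\n';  // (3,4)
}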
i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) { + val = max_propagate_nan(val, output_this_tensor[i]); + } + T final_val = at::native::cuda_utils::BlockReduceMax(val, vals); + if (threadIdx.x == 0) { + *(T*)addr_struct.addresses[blockIdx.x] = final_val; + } +} + +std::vector foreach_tensor_max_cuda(TensorList tensors) { + check_foreach_api_restrictions(tensors); + if (!can_use_fast_route(tensors)) { + return foreach_tensor_max_slow(tensors); + } + + // for parity with max in ReduceAllOps.cpp, as max(empty) is ??? + TORCH_CHECK( + std::all_of( + tensors.begin(), + tensors.end(), + [](const auto& t) { return t.numel() > 0; }), + "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); + + const size_t ntensors = tensors.size(); + int max_chunks_per_tensor = -1; + + for (const auto t : c10::irange(ntensors)) { + int max_chunks_this_tensor = + (tensors[t].numel() + kChunkSize - 1) / kChunkSize; + if (max_chunks_this_tensor > max_chunks_per_tensor) { + max_chunks_per_tensor = max_chunks_this_tensor; + } + } + const auto options = tensors[0].options(); + auto output_per_tensor = at::zeros( + {static_cast(ntensors) * max_chunks_per_tensor}, options); + + std::vector vec_res; + vec_res.reserve(ntensors); + for (const auto i : c10::irange(ntensors)) { + vec_res.push_back(at::empty({}, options)); + } + + auto tensor_lists = std::vector>{tensors.vec()}; + + AT_DISPATCH_ALL_TYPES_AND3( + kHalf, + kBFloat16, + kBool, + tensor_lists[0][0].scalar_type(), + "foreach_tensor_max_cuda_scalar_type", + [&]() { + multi_tensor_apply<1>( + tensor_lists, + LpMaxFunctor(), + output_per_tensor.mutable_data_ptr(), + max_chunks_per_tensor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + const at::cuda::OptionalCUDAGuard device_guard( + device_of(output_per_tensor)); + auto stream = at::cuda::getCurrentCUDAStream(); + + const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL); + for (const auto i : c10::irange(num_kernels)) { + const size_t num_tensors_this_kernel = + (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0) + ? MAX_TENSORS_PER_KERNEL + : (ntensors % MAX_TENSORS_PER_KERNEL); + + TensorListAddresses addr_struct; + for (const auto j : c10::irange(num_tensors_this_kernel)) { + addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j] + .mutable_data_ptr(); + } + + lpmax_cleanup<<>>( + output_per_tensor.const_data_ptr() + + i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor, + addr_struct, + max_chunks_per_tensor); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + // correctly assign values to only non-empty slots, as the empty slots should + // get skipped + std::vector result; + result.reserve(ntensors); + int i = 0; + for (const auto& t : tensors) { + if (t.numel() != 0) { + result.emplace_back(vec_res[i]); + i++; + } else { + result.emplace_back(at::empty({}, options)); + } + } + return result; +} + template < typename T, NormType norm_type, diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index b451592f1944..f559625e6b0a 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -2,9 +2,9 @@ // Licensed under the BSD-3-Clause license // This is the GPU implementation of the Connectionist Temporal Loss. // We mostly follow Graves. -// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf +// 1. 
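// Host-side sketch of the two-pass shape of foreach_tensor_max_cuda above:
// pass 1 (LpMaxFunctor) writes one partial maximum per chunk into a scratch
// buffer, pass 2 (lpmax_cleanup) reduces each row of partials to a single
// value. numeric_limits::lowest() is the identity, so all-negative inputs
// reduce correctly. Chunk size and data here are toy values.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const std::vector<float> tensor{-5.f, -1.f, -9.f, -2.f, -7.f};
  const int chunk = 2;
  const int n_chunks = static_cast<int>((tensor.size() + chunk - 1) / chunk);

  // pass 1: per-chunk partial maxima
  std::vector<float> partials(n_chunks, std::numeric_limits<float>::lowest());
  for (int c = 0; c < n_chunks; ++c) {
    const size_t end = std::min(tensor.size(), static_cast<size_t>((c + 1) * chunk));
    for (size_t i = static_cast<size_t>(c) * chunk; i < end; ++i)
      partials[c] = std::max(partials[c], tensor[i]);
  }

  // pass 2: "cleanup" reduction over the partials
  float result = std::numeric_limits<float>::lowest();
  for (float p : partials) result = std::max(result, p);
  std::printf("max = %f\n", result);  // -1.0
}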
Graves et al.: http://www.cs.toronto.edu/~graves/icml_2006.pdf // We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based. -// Graves et al call the probabilities y, we use log_probs (also calling them inputs) +// Graves et al. call the probabilities y, we use log_probs (also calling them inputs) // A few optimizations (similar to those here, but also some I didn't take) are described in // 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf #define TORCH_ASSERT_ONLY_METHOD_OPERATORS diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 1f67ee3ea63e..85bde8b5990f 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -807,6 +807,7 @@ struct ReduceOp { bool is_last_block_done = mark_block_finished(); if (is_last_block_done) { + __threadfence(); // complete the acquire pattern after atomic value = ident; if (config.should_block_x_reduce()) { index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; diff --git a/aten/src/ATen/native/cuda/Resize.h b/aten/src/ATen/native/cuda/Resize.h index 569b145fa61d..d5de128cac1d 100644 --- a/aten/src/ATen/native/cuda/Resize.h +++ b/aten/src/ATen/native/cuda/Resize.h @@ -29,18 +29,10 @@ static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_b inline TensorImpl* resize_impl_cuda_( TensorImpl* self, IntArrayRef size, - at::OptionalIntArrayRef stride, - bool device_guard = true) { + at::OptionalIntArrayRef stride) { if (self->sizes() == size && (!stride || self->strides() == stride)) { return self; } - - // NB: We don't need to hold the device guard when calling from TH - cuda::OptionalCUDAGuard guard; - if (device_guard) { - guard.set_index(self->storage().device().index()); - } - const auto itemsize = self->dtype().itemsize(); const auto storage_offset = self->storage_offset(); size_t storage_size = 1; diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 4aca753a510b..7616b7bdcc01 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -863,8 +863,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); @@ -899,8 +899,8 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); can_use_smem &= !(dim_size % ILP); diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index e8fd69c0aec9..df757a11761b 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ 
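// CPU analogue (std::atomic, not CUDA) of the "__threadfence(); // complete
// the acquire pattern" fix above: once the last block observes the
// is-last-block-done flag, it must also synchronize with the earlier blocks'
// writes before reading their partial results. With std::atomic that is the
// usual release/acquire pairing shown here.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

std::vector<int> partials(4, 0);
std::atomic<int> blocks_done{0};

void worker(int i) {
  partials[i] = (i + 1) * 10;                           // write this block's partial result
  blocks_done.fetch_add(1, std::memory_order_release);  // "mark_block_finished"
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) threads.emplace_back(worker, i);
  while (blocks_done.load(std::memory_order_acquire) != 4) { /* spin */ }
  // The acquire load pairs with the release increments, so every partial
  // written before a fetch_add is visible here -- the role the added
  // __threadfence() plays for the GPU reducer.
  int total = 0;
  for (int p : partials) total += p;
  std::printf("total = %d\n", total);  // 100
  for (auto& t : threads) t.join();
}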
b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) { shared[wid] = val; } __syncthreads(); - val = (tid < B::Warps()) ? shared[lid] : T(0); + val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits::lowest()); if (wid == 0) { val = WarpReduceMax(val); } diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 0d870cef5870..67b8d3e54ba5 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -1002,7 +1002,7 @@ std::string generate_code( std::string extra_args = ""; for (size_t i = 0; i < extra_args_typenames.size(); i++) { auto type = std::string(extra_args_typenames[i]); - auto name = "extra_arg_" + std::string(to_string(i)); + auto name = "extra_arg_" + std::to_string(i); extra_params += "," + type + " " + name; extra_args += ", " + name; } diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h index b8901d1d6f5d..9b17086646d8 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h @@ -18,7 +18,7 @@ namespace solver { template void getrf(CUDASOLVER_GETRF_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::solver::getrf: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrf: not implemented"); } template<> void getrf(CUDASOLVER_GETRF_ARGTYPES(float)); @@ -35,7 +35,7 @@ void getrf>(CUDASOLVER_GETRF_ARGTYPES(c10::complex)); template void getrs(CUDASOLVER_GETRS_ARGTYPES(Dtype)) { - TORCH_CHECK(false, "at::cuda::solver::getrs: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrs: not implemented"); } template<> void getrs(CUDASOLVER_GETRS_ARGTYPES(float)); @@ -51,10 +51,8 @@ void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); template void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::sytrf_bufferSize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::sytrf_bufferSize: not implemented"); } template <> void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float)); @@ -73,10 +71,8 @@ void sytrf_bufferSize>( template void sytrf(CUDASOLVER_SYTRF_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::sytrf: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::sytrf: not implemented"); } template <> void sytrf(CUDASOLVER_SYTRF_ARGTYPES(float)); @@ -93,7 +89,7 @@ void sytrf>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex)); template void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) { - TORCH_CHECK(false, "at::cuda::solver::gesvd_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd_buffersize: not implemented"); } template<> void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()); @@ -111,7 +107,7 @@ void gesvd_buffersize>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES template void gesvd(CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)) { - TORCH_CHECK(false, "at::cuda::solver::gesvd: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd: not implemented"); } template<> void gesvd(CUDASOLVER_GESVD_ARGTYPES(float, float)); @@ -129,7 +125,7 @@ void gesvd>(CUDASOLVER_GESVD_ARGTYPES(c10::complex, template void 
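// Minimal sketch of the pattern the cuSOLVER wrappers above switch to: make
// the unspecialized primary template fail at compile time rather than at
// runtime. `false && sizeof(T)` keeps the condition dependent on the template
// parameter, so the assert only fires if the primary template is actually
// instantiated. The function name is illustrative.
#include <cstdio>

template <typename T>
void solver_op(T) {
  static_assert(false && sizeof(T), "solver_op: not implemented for this type");
}

template <>
void solver_op<float>(float x) { std::printf("float path: %f\n", x); }

int main() {
  solver_op(1.0f);   // OK: resolves to the float specialization
  // solver_op(1);   // would fail to compile with "not implemented for this type"
}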
gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj_buffersize: not implemented"); } template<> void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(float, float)); @@ -147,7 +143,7 @@ void gesvdj_buffersize>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYP template void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(float, float)); @@ -165,7 +161,7 @@ void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented"); } template<> void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float)); @@ -183,7 +179,7 @@ void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10: template void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented"); } template<> void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(float, float)); @@ -203,7 +199,7 @@ void gesvdaStridedBatched_buffersize>(CUDASOLVER_GESVDA_STR template void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdaStridedBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched: not implemented"); } template<> void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(float, float)); @@ -220,7 +216,7 @@ void gesvdaStridedBatched>(CUDASOLVER_GESVDA_STRIDED_BATCHE template void potrf(CUDASOLVER_POTRF_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrf: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf: not implemented"); } template<> void potrf(CUDASOLVER_POTRF_ARGTYPES(float)); @@ -237,7 +233,7 @@ void potrf>(CUDASOLVER_POTRF_ARGTYPES(c10::complex) template void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrf_buffersize: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf_buffersize: not implemented"); } template<> void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(float)); @@ -254,7 +250,7 @@ void potrf_buffersize>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES template void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrfBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrfBatched: not implemented"); } template<> void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(float)); @@ -270,10 +266,8 @@ void 
potrfBatched>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::c template void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { - TORCH_CHECK( - false, - "at::cuda::solver::geqrf_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::geqrf_bufferSize: not implemented"); } template <> void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float)); @@ -292,10 +286,8 @@ void geqrf_bufferSize>( template void geqrf(CUDASOLVER_GEQRF_ARGTYPES(scalar_t)) { - TORCH_CHECK( - false, - "at::cuda::solver::geqrf: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::geqrf: not implemented"); } template <> void geqrf(CUDASOLVER_GEQRF_ARGTYPES(float)); @@ -312,7 +304,7 @@ void geqrf>( template void potrs(CUDASOLVER_POTRS_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrs: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrs: not implemented"); } template<> void potrs(CUDASOLVER_POTRS_ARGTYPES(float)); @@ -329,7 +321,7 @@ void potrs>(CUDASOLVER_POTRS_ARGTYPES(c10::complex) template void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::potrsBatched: not implemented for ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrsBatched: not implemented"); } template<> void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(float)); @@ -347,10 +339,7 @@ void potrsBatched>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::c template void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::orgqr_buffersize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr_buffersize: not implemented"); } template <> void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(float)); @@ -368,10 +357,7 @@ void orgqr_buffersize>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES template void orgqr(CUDASOLVER_ORGQR_ARGTYPES(Dtype)) { - TORCH_CHECK( - false, - "at::cuda::solver::orgqr: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr: not implemented"); } template <> void orgqr(CUDASOLVER_ORGQR_ARGTYPES(float)); @@ -389,10 +375,8 @@ void orgqr>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex) template void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::ormqr_bufferSize: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::ormqr_bufferSize: not implemented"); } template <> void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float)); @@ -412,10 +396,8 @@ void ormqr_bufferSize>( template void ormqr(CUDASOLVER_ORMQR_ARGTYPES(Dtype)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::ormqr: not implemented for ", - typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), + "at::cuda::solver::ormqr: not implemented"); } template <> void ormqr(CUDASOLVER_ORMQR_ARGTYPES(float)); @@ -431,7 +413,8 @@ void ormqr>( template cudaDataType get_cusolver_datatype() { - TORCH_CHECK(false, "cusolver doesn't support data type ", typeid(Dtype).name()); + static_assert(false&&sizeof(Dtype), "cusolver doesn't support data type"); + return {}; } template<> cudaDataType get_cusolver_datatype(); template<> cudaDataType get_cusolver_datatype(); @@ -459,10 +442,8 @@ void 
xpotrs( template void syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevd_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevd_bufferSize: not implemented"); } template <> @@ -485,10 +466,8 @@ void syevd_bufferSize, double>( template void syevd(CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevd: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevd: not implemented"); } template <> @@ -509,10 +488,8 @@ void syevd, double>( template void syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevj_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevj_bufferSize: not implemented"); } template <> @@ -535,10 +512,7 @@ void syevj_bufferSize, double>( template void syevj(CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevj: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj: not implemented"); } template <> @@ -560,10 +534,8 @@ void syevj, double>( template void syevjBatched_bufferSize( CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevjBatched_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevjBatched_bufferSize: not implemented"); } template <> @@ -586,10 +558,8 @@ void syevjBatched_bufferSize, double>( template void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::syevjBatched: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::syevjBatched: not implemented"); } template <> @@ -612,10 +582,8 @@ void syevjBatched, double>( template void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xgeqrf_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xgeqrf_bufferSize: not implemented"); } template <> @@ -637,10 +605,7 @@ void xgeqrf_bufferSize>( template void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xgeqrf: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf: not implemented"); } template <> @@ -663,10 +628,8 @@ void xgeqrf>( template void xsyevd_bufferSize( CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xsyevd_bufferSize: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevd_bufferSize: not implemented"); } template <> @@ -691,10 +654,8 @@ void xsyevd_bufferSize, double>( template void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)) { - TORCH_INTERNAL_ASSERT( - false, - "at::cuda::solver::xsyevd: not implemented for ", - typeid(scalar_t).name()); + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xsyevd: not implemented"); } template <> diff 
--git a/aten/src/ATen/native/cuda/reduction_template.cuh b/aten/src/ATen/native/cuda/reduction_template.cuh index a38edb538256..6d1e861493d4 100644 --- a/aten/src/ATen/native/cuda/reduction_template.cuh +++ b/aten/src/ATen/native/cuda/reduction_template.cuh @@ -595,6 +595,7 @@ struct ReduceJitOp { bool is_last_block_done = mark_block_finished(); if (is_last_block_done) { + __threadfence(); //complete acquire pattern value = ident; if (config.should_block_x_reduce()) { uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 1f6bdbf5305a..ab19b5d68a90 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -13,7 +13,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool isTraining, bool is_causal, @@ -34,7 +35,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -128,7 +130,8 @@ struct MHAParams { int64_t h; int64_t s_q; int64_t s_kv; - int64_t d; + int64_t d_qk; + int64_t d_v; double dropout_probability; bool is_causal; bool return_softmaxstats; @@ -140,7 +143,8 @@ void setMHAParams( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, const Tensor& q, const Tensor& k, const Tensor& v, @@ -155,7 +159,8 @@ void setMHAParams( } params.b = b; params.h = h; - params.d = d; + params.d_qk = d_qk; + params.d_v = d_v; params.s_q = s_q; params.s_kv = s_kv; params.dropout_probability = dropout_probability; @@ -193,7 +198,8 @@ struct MHACacheKeyWrapper : ParamsWrapper { int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, const Tensor& q, const Tensor& k, const Tensor& v, @@ -206,7 +212,8 @@ struct MHACacheKeyWrapper : ParamsWrapper { h, s_q, s_kv, - d, + d_qk, + d_v, q, k, v, @@ -249,7 +256,8 @@ auto build_graph_and_tensors( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool return_softmaxstats, bool is_causal, @@ -383,7 +391,8 @@ auto build_graph_and_tensors_backward( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -514,7 +523,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool return_softmaxstats, bool is_causal, @@ -528,7 +538,7 @@ void run_cudnn_SDP_fprop( Tensor& dropoutoffset) { cudnnHandle_t handle = getCudnnHandle(); o = at::empty_strided( - {b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options()); + {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options()); if (return_softmaxstats) { // TODO(eqy): verify that this is correct softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); @@ -539,7 +549,8 @@ void run_cudnn_SDP_fprop( h, s_q, s_kv, - d, + d_qk, + d_v, q, k, v, @@ -556,7 +567,8 @@ void run_cudnn_SDP_fprop( h, s_q, s_kv, - d, + d_qk, + d_v, scaling_factor, return_softmaxstats, is_causal, @@ -599,7 +611,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_qk, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, @@ -614,9 +627,27 @@ void run_cudnn_SDP_bprop( Tensor& dV, const Tensor& dropoutseed, const Tensor& dropoutoffset) { + Tensor dO_ = dO; + if 
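// Shape bookkeeping behind the d_qk / d_v split above: q and k share the
// query/key head dim d_qk, v uses d_v, and only d_v appears in the output's
// sizes and strides (matching the at::empty_strided call in
// run_cudnn_SDP_fprop). Plain C++ sketch; the helper names are made up.
#include <array>
#include <cstdint>
#include <cstdio>

std::array<int64_t, 4> sdpa_out_sizes(int64_t b, int64_t h, int64_t s_q, int64_t d_v) {
  return {b, h, s_q, d_v};
}
std::array<int64_t, 4> sdpa_out_strides(int64_t h, int64_t s_q, int64_t d_v) {
  return {s_q * h * d_v, d_v, h * d_v, 1};
}

int main() {
  // e.g. q,k: [2, 8, 128, 64] and v: [2, 8, 128, 128]
  const auto sizes = sdpa_out_sizes(2, 8, 128, 128);
  const auto strides = sdpa_out_strides(8, 128, 128);
  std::printf("o sizes   = [%lld, %lld, %lld, %lld]\n", (long long)sizes[0],
              (long long)sizes[1], (long long)sizes[2], (long long)sizes[3]);
  std::printf("o strides = [%lld, %lld, %lld, %lld]\n", (long long)strides[0],
              (long long)strides[1], (long long)strides[2], (long long)strides[3]);
}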
(!dO.strides()[dO.strides().size() - 1]) { + TORCH_WARN( + "cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous\ + tensor which will increase memory usage..."); + dO_ = dO.contiguous(); + } cudnnHandle_t handle = getCudnnHandle(); auto key = MHACacheKeyWrapper( - b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true); + b, + h, + s_q, + s_kv, + d_qk, + d_v, + q, + k, + v, + dropout_probability, + is_causal, + true); auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); graph_and_tensors_backward graph_and_tensors_backward_values; if (graph_and_tensors_backward_ptr) { @@ -627,7 +658,8 @@ void run_cudnn_SDP_bprop( h, s_q, s_kv, - d, + d_qk, + d_v, scaling_factor, is_causal, dropout_probability, @@ -635,7 +667,7 @@ void run_cudnn_SDP_bprop( k, v, o, - dO, + dO_, softmaxstats, dQ, dK, @@ -677,5 +709,4 @@ void run_cudnn_SDP_bprop( } // namespace native } // namespace at - #endif diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 0406cf783dc5..8b9315a5a3d8 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -9,7 +9,8 @@ void run_cudnn_SDP_fprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_k, + int64_t d_v, float scaling_factor, bool isTraining, bool is_causal, @@ -27,7 +28,8 @@ void run_cudnn_SDP_bprop( int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, + int64_t d_k, + int64_t d_v, float scaling_factor, bool is_causal, float dropout_probability, diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h index f7ae0854f78e..8a6fa47ba10f 100644 --- a/aten/src/ATen/native/im2col_shape_check.h +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -5,7 +5,7 @@ namespace at::native { -static inline void col2im_shape_check( +inline void col2im_shape_check( const Tensor& input, const Tensor& grad_output, int64_t output_height, @@ -135,7 +135,7 @@ static inline void col2im_shape_check( } } -static inline void im2col_shape_check( +inline void im2col_shape_check( const Tensor& input, const Tensor& grad_output, int64_t kernel_height, diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm index a1ee8e6f8ded..ec6156573e06 100644 --- a/aten/src/ATen/native/metal/MetalAten.mm +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -6,10 +6,9 @@ #include namespace at { -namespace native { -namespace metal { +namespace native::metal { -at::Tensor& copy_from_metal_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& copy_from_metal_(Tensor& dst, const Tensor& src) { TORCH_INTERNAL_ASSERT( src.device().type() == DeviceType::Metal, "copy_from_metal input tensor's device is not metal"); @@ -34,7 +33,7 @@ return dst; } -at::Tensor& copy_to_metal_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& copy_to_metal_(Tensor& dst, const Tensor& src) { TORCH_INTERNAL_ASSERT( dst.device().type() == DeviceType::Metal, "copy_to_metal_ output tensor's device is not metal"); @@ -54,7 +53,7 @@ return dst; } -at::Tensor& metal_copy_impl_(at::Tensor& dst, const at::Tensor& src) { +static Tensor& metal_copy_impl_(Tensor& dst, const Tensor& src) { if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) { return copy_from_metal_(dst, src); } @@ -69,7 +68,7 @@ #pragma mark - ATen Ops -Tensor empty( +static Tensor empty( c10::SymIntArrayRef sym_size, optional dtype, optional layout, @@ -88,7 +87,7 @@ Tensor empty( std::move(mt), 
at::device(at::kMetal).dtype(dtype)); }; -at::Tensor empty_strided( +static Tensor empty_strided( IntArrayRef size, IntArrayRef stride, optional dtype, @@ -109,8 +108,7 @@ Tensor empty( m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(empty_strided)); } -} // namespace metal -} // namespace native +} // namespace native::metal struct MetalImpl : public at::metal::MetalInterface { bool is_metal_available() const override { diff --git a/aten/src/ATen/native/metal/MetalConvParams.h b/aten/src/ATen/native/metal/MetalConvParams.h index 7b0bfc9670a1..55a8ea657e72 100644 --- a/aten/src/ATen/native/metal/MetalConvParams.h +++ b/aten/src/ATen/native/metal/MetalConvParams.h @@ -3,9 +3,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { struct Conv2DParams final { Conv2DParams() {} @@ -46,8 +44,6 @@ struct Conv2DParams final { int64_t OH; // output height }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal #endif /* MetalConvParams_h */ diff --git a/aten/src/ATen/native/metal/MetalDevice.h b/aten/src/ATen/native/metal/MetalDevice.h index 29d34246cc1b..42c3ae43cd02 100644 --- a/aten/src/ATen/native/metal/MetalDevice.h +++ b/aten/src/ATen/native/metal/MetalDevice.h @@ -5,9 +5,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { struct MetalDeviceInfo { std::string name; @@ -42,8 +40,6 @@ static inline MetalDeviceInfo createDeviceInfo(id device) { return device_info; } -} -} -} +} // namespace at::native::metal #endif diff --git a/aten/src/ATen/native/metal/MetalNeuronType.h b/aten/src/ATen/native/metal/MetalNeuronType.h index c5cb0b99502c..e1cada24a7fd 100644 --- a/aten/src/ATen/native/metal/MetalNeuronType.h +++ b/aten/src/ATen/native/metal/MetalNeuronType.h @@ -6,9 +6,7 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { enum class NeuronType { None, @@ -66,8 +64,6 @@ static inline MPSNNNeuronDescriptor* neuronDescriptor(NeuronType type) { } } -} -} -} +} // namespace at::native::metal #endif /* MetalNeuronType_h */ diff --git a/aten/src/ATen/native/metal/MetalPrepackOpContext.h b/aten/src/ATen/native/metal/MetalPrepackOpContext.h index 4481c879eec2..a484812d6874 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpContext.h +++ b/aten/src/ATen/native/metal/MetalPrepackOpContext.h @@ -3,9 +3,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using SerializationTypeConv2dPrePack = std::tuple< Tensor, @@ -197,6 +195,4 @@ class LinearOpContext : public torch::jit::CustomClassHolder { std::function releaseCallback_ = nullptr; }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp index ebf9b9daf626..d4a7e463d777 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp +++ b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp @@ -3,11 +3,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -c10::intrusive_ptr unpack( +static c10::intrusive_ptr unpack( Tensor&& weight, std::optional&& bias, std::vector&& stride, @@ -28,7 +26,7 @@ c10::intrusive_ptr unpack( output_max); } -c10::intrusive_ptr unpack( +static c10::intrusive_ptr unpack( Tensor&& weight, std::optional&& bias, const std::optional& output_min, @@ -94,7 +92,7 
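// Toy sketch of the cleanup applied across the Metal files above: the three
// nested namespace blocks collapse into a single C++17 nested-namespace
// definition, and functions only reachable through same-file registration
// gain `static` (internal linkage). Namespaces and helpers here are stand-ins,
// not the real ATen symbols.
namespace at { namespace native { namespace metal {  // pre-C++17 spelling
inline int legacy_helper() { return 1; }
}}} // namespace at::native::metal

namespace at::native::metal {                         // C++17 spelling, same namespace
static int file_local_op() { return legacy_helper() + 1; }  // internal linkage
} // namespace at::native::metal

int main() { return at::native::metal::file_local_op() == 2 ? 0 : 1; }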
@@ TORCH_LIBRARY(metal_prepack, m) { TORCH_SELECTIVE_SCHEMA("metal_prepack::linear_run(Tensor X, __torch__.torch.classes.metal.LinearOpContext W_prepack) -> Tensor Y")); } -c10::intrusive_ptr conv2d_prepack( +static c10::intrusive_ptr conv2d_prepack( Tensor&& weight, std::optional&& bias, std::vector&& stride, @@ -115,7 +113,7 @@ c10::intrusive_ptr conv2d_prepack( output_max); } -c10::intrusive_ptr linear_prepack( +static c10::intrusive_ptr linear_prepack( Tensor&& weight, std::optional&& bias, const std::optional& output_min, @@ -129,6 +127,4 @@ TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("metal_prepack::linear_prepack"), TORCH_FN(linear_prepack)); } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalTensorImplStorage.h b/aten/src/ATen/native/metal/MetalTensorImplStorage.h index 1ac7d126de95..975827aee15a 100644 --- a/aten/src/ATen/native/metal/MetalTensorImplStorage.h +++ b/aten/src/ATen/native/metal/MetalTensorImplStorage.h @@ -1,9 +1,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { class MPSImageWrapper; class MetalTensorImplStorage final { @@ -42,6 +40,4 @@ class MetalTensorImplStorage final { std::shared_ptr _impl; }; -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/MetalTensorUtils.h b/aten/src/ATen/native/metal/MetalTensorUtils.h index 318da09d86b2..9663e59fb74d 100644 --- a/aten/src/ATen/native/metal/MetalTensorUtils.h +++ b/aten/src/ATen/native/metal/MetalTensorUtils.h @@ -10,9 +10,7 @@ typedef float16_t fp16_t; typedef uint16_t fp16_t; #endif -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { uint32_t batchSize(const Tensor& tensor); uint32_t channelsSize(const Tensor& tensor); @@ -70,6 +68,4 @@ static inline MetalCommandBuffer* getCommandBuffer( return cmdBuffer; } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h index 13264d097e92..346d58ace539 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.h @@ -20,10 +20,7 @@ } \ } while (false) -namespace at { -namespace native { -namespace metal { -namespace mpscnn { +namespace at::native::metal::mpscnn { struct LaunchParams { MTLSize threadsPerThreadgroup; @@ -71,7 +68,4 @@ static inline int computeMPSAlignOffset(int kernel, int pad) { return mps_offset - pt_offset; } -} -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal::mpscnn diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm index ff8ad447dd0f..90f4ed030000 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm @@ -1,11 +1,8 @@ #import -namespace at { -namespace native { -namespace metal { -namespace mpscnn { +namespace at::native::metal::mpscnn { -auto divRoundUp(uint x, uint y) -> uint { +static auto divRoundUp(uint x, uint y) -> uint { return (x + y - 1) / y; } @@ -14,7 +11,7 @@ LaunchParams spatialPointwiseKernelLaunchParams( MPSImage* im) { return spatialPointwiseKernelLaunchParams( pipeline, im.numberOfImages, im.featureChannels, im.height, im.width); -}; +} LaunchParams 
spatialPointwiseKernelLaunchParams( id pipeline, @@ -33,9 +30,6 @@ LaunchParams spatialPointwiseKernelLaunchParams( const auto threadsPerGrid = MTLSizeMake( width, height, numberOfImages * divRoundUp(featureChannels, 4)); return {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid}; -}; - -} -} -} } + +} // namespace at::native::metal::mpscnn diff --git a/aten/src/ATen/native/metal/ops/MetalAddmm.mm b/aten/src/ATen/native/metal/ops/MetalAddmm.mm index e0c196ac68b3..b10b2a4b81f3 100644 --- a/aten/src/ATen/native/metal/ops/MetalAddmm.mm +++ b/aten/src/ATen/native/metal/ops/MetalAddmm.mm @@ -12,12 +12,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor addmm( +static Tensor addmm( const Tensor& bias, const Tensor& input, const Tensor& weight, @@ -63,7 +61,7 @@ Tensor addmm( namespace prepack { -Tensor linear(const Tensor& input, LinearOpContext& context) { +static Tensor linear(const Tensor& input, LinearOpContext& context) { TORCH_CHECK(input.is_metal()); TORCH_CHECK(context.get_weight().device() == kCPU); TORCH_CHECK(context.get_weight().dim() == 4); @@ -126,7 +124,7 @@ Tensor linear(const Tensor& input, LinearOpContext& context) { return output; } -Tensor linear_run( +static Tensor linear_run( const Tensor& input, const c10::intrusive_ptr& op_context) { return linear(input, *op_context); @@ -142,6 +140,4 @@ Tensor linear_run( m.impl(TORCH_SELECTIVE_NAME("metal_prepack::linear_run"), TORCH_FN(prepack::linear_run)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm index 0b5312632e1d..8505a89b9681 100644 --- a/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm +++ b/aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; @@ -58,7 +56,7 @@ static inline void checkInputs(const Tensor& input1, const Tensor& input2) { } } -Tensor binaryElementwiseShaderKernel( +static Tensor binaryElementwiseShaderKernel( const Tensor& input1, const Tensor& input2, const std::string& arrayKernel, @@ -98,7 +96,7 @@ Tensor binaryElementwiseShaderKernel( return output; } -Tensor& binaryElementwiseShaderKernel_( +static Tensor& binaryElementwiseShaderKernel_( Tensor& input1, const Tensor& input2, const std::string& arrayKernel, @@ -208,7 +206,7 @@ Tensor binaryElementwiseMPSCNNKernel( return input1; } -Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -219,7 +217,7 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor& add__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? 
input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -230,7 +228,7 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -241,7 +239,7 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { +static Tensor& sub__Tensor(Tensor& input1, const Tensor& input2, const Scalar& alpha) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -252,7 +250,7 @@ Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, const Scalar& alph } } -Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { +static Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -263,7 +261,7 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) { +static Tensor& mul__Tensor(Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -274,7 +272,7 @@ Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { +static Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -285,7 +283,7 @@ Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { } } -Tensor& div__Tensor(Tensor& input1, const Tensor& input2) { +static Tensor& div__Tensor(Tensor& input1, const Tensor& input2) { TORCH_CHECK(input1.is_metal()); auto input2_ = input2.is_metal() ? 
input2 : input2.metal(); if (@available(iOS 11.3, *)) { @@ -305,8 +303,6 @@ Tensor div_Tensor(const Tensor& input1, const Tensor& input2) { m.impl(TORCH_SELECTIVE_NAME("aten::sub_.Tensor"), TORCH_FN(sub__Tensor)); m.impl(TORCH_SELECTIVE_NAME("aten::div.Tensor"), TORCH_FN(div_Tensor)); m.impl(TORCH_SELECTIVE_NAME("aten::div_.Tensor"), TORCH_FN(div__Tensor)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalChunk.mm b/aten/src/ATen/native/metal/ops/MetalChunk.mm index ee02b269a580..0011b065bf81 100644 --- a/aten/src/ATen/native/metal/ops/MetalChunk.mm +++ b/aten/src/ATen/native/metal/ops/MetalChunk.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { // Split the input tensor into two on channel dimension // TODO: [T87567124] Fully implement chunk in Metal shader -std::vector chunk(const Tensor& input, int64_t chunks, int64_t dim) { +static std::vector chunk(const Tensor& input, int64_t chunks, int64_t dim) { TORCH_CHECK(chunks == 2 && dim == 1); TORCH_CHECK(input.dim() == 4); TORCH_CHECK(input.size(0) == 1); @@ -61,8 +59,6 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::chunk"), TORCH_FN(chunk)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalClamp.mm b/aten/src/ATen/native/metal/ops/MetalClamp.mm index b0eac2460ac3..4eedf3775028 100644 --- a/aten/src/ATen/native/metal/ops/MetalClamp.mm +++ b/aten/src/ATen/native/metal/ops/MetalClamp.mm @@ -8,11 +8,9 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) { +static Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) { TORCH_CHECK(input.is_metal()); MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); @@ -29,7 +27,7 @@ return input; } -Tensor hardtanh( +static Tensor hardtanh( const Tensor& input, const Scalar& min_val, const Scalar& max_val) { @@ -52,7 +50,7 @@ Tensor hardtanh( return output; } -at::Tensor clamp( +static at::Tensor clamp( const at::Tensor& input, const c10::optional& min, const c10::optional& max) { @@ -64,8 +62,6 @@ Tensor hardtanh( m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), TORCH_FN(hardtanh_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), TORCH_FN(hardtanh)); m.impl(TORCH_SELECTIVE_NAME("aten::clamp"), TORCH_FN(clamp)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConcat.mm b/aten/src/ATen/native/metal/ops/MetalConcat.mm index be9d87d8fe5a..5de99046f2d0 100644 --- a/aten/src/ATen/native/metal/ops/MetalConcat.mm +++ b/aten/src/ATen/native/metal/ops/MetalConcat.mm @@ -12,11 +12,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { +static Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); MPSImage* Y = mt.texture()->image(); ushort cat_dim4_pointer = 0; @@ -53,7 +51,7 @@ Tensor cat_batch(const Tensor& tensor, const ITensorListRef& tensors, MetalTenso return output; } -Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { +static Tensor cat_feature(const 
Tensor& tensor, const ITensorListRef& tensors, MetalTensorImplStorage& mt) { MetalCommandBuffer* commandBuffer = getCommandBuffer(tensor); MPSImage* Y = mt.texture()->image(); ushort channel_offset = 0; @@ -162,7 +160,7 @@ Tensor cat_feature(const Tensor& tensor, const ITensorListRef& tensors, MetalTen return output; } -Tensor cat(const ITensorListRef& tensors, int64_t dim) { +static Tensor cat(const ITensorListRef& tensors, int64_t dim) { TORCH_CHECK( dim == 0 || dim == 1, "Metal cat is implemented only for batch dimension"); @@ -203,6 +201,4 @@ Tensor cat(const ITensorListRef& tensors, int64_t dim) { m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.h b/aten/src/ATen/native/metal/ops/MetalConvolution.h index 77053448cbcb..dc8192812d8c 100644 --- a/aten/src/ATen/native/metal/ops/MetalConvolution.h +++ b/aten/src/ATen/native/metal/ops/MetalConvolution.h @@ -2,9 +2,7 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { Tensor conv2d( const Tensor& input, @@ -19,6 +17,4 @@ namespace prepack { Tensor conv2d(const Tensor& input, Conv2dOpContext& context); } -} // namespace metal -} // namespace native -} // namespace at +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.mm b/aten/src/ATen/native/metal/ops/MetalConvolution.mm index 46295abefae9..eb5d1f16fabb 100644 --- a/aten/src/ATen/native/metal/ops/MetalConvolution.mm +++ b/aten/src/ATen/native/metal/ops/MetalConvolution.mm @@ -9,9 +9,7 @@ #import -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; Tensor conv2d( @@ -97,7 +95,7 @@ Tensor conv2d(const Tensor& input, Conv2dOpContext& context) { return output; } -Tensor conv2d_prepack_run( +static Tensor conv2d_prepack_run( const Tensor& input, const c10::intrusive_ptr& op_context) { return conv2d(input, *op_context); @@ -115,6 +113,4 @@ Tensor conv2d_prepack_run( m.impl(TORCH_SELECTIVE_NAME("metal_prepack::conv2d_run"), prepack::conv2d_prepack_run); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalCopy.h b/aten/src/ATen/native/metal/ops/MetalCopy.h index fdee7acad4f4..2023d3c508e2 100644 --- a/aten/src/ATen/native/metal/ops/MetalCopy.h +++ b/aten/src/ATen/native/metal/ops/MetalCopy.h @@ -3,14 +3,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { Tensor copy_to_host(const Tensor& input); -} -} // namespace native -} // namespace at +} // namespace at::native::metal #endif diff --git a/aten/src/ATen/native/metal/ops/MetalCopy.mm b/aten/src/ATen/native/metal/ops/MetalCopy.mm index b1df48b5c89c..c4ce058f78ed 100644 --- a/aten/src/ATen/native/metal/ops/MetalCopy.mm +++ b/aten/src/ATen/native/metal/ops/MetalCopy.mm @@ -9,11 +9,9 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor copy_to_host(const Tensor& input) { +static Tensor copy_to_host(const Tensor& input) { TORCH_CHECK(input.is_metal()); MPSImage* X = imageFromTensor(input); if (X && !X.isTemporaryImage) { @@ -52,8 +50,6 @@ Tensor copy_to_host(const Tensor& input) { TORCH_LIBRARY_IMPL(metal, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("metal::copy_to_host"), TORCH_FN(copy_to_host)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalHardshrink.mm 
b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm index 4de506cb6526..05b6b585e7f0 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardshrink.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardshrink.mm @@ -9,15 +9,13 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; // NB: this is currently unused, but I've left it because in principle // it's useful -Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) { +static Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) { float l = lambda.toFloat(); MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); @@ -51,7 +49,7 @@ return input; } -Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { +static Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { float l = lambda.toFloat(); MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); @@ -87,8 +85,6 @@ Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm index 07706483c1ae..22d84d6c1bf0 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor& hardswish_(Tensor& input) { +static Tensor& hardswish_(Tensor& input) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); IntArrayRef outputSize = input.sizes(); @@ -47,7 +45,7 @@ return input; } -Tensor hardswish(const at::Tensor& input) { +static Tensor hardswish(const at::Tensor& input) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); MetalTensorImplStorage mt{outputSize.vec()}; @@ -82,8 +80,6 @@ Tensor hardswish(const at::Tensor& input) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), TORCH_FN(hardswish_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), TORCH_FN(hardswish)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm b/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm index 2034a64d82d5..0bd476ffa4f5 100644 --- a/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm +++ b/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) { +static Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) { MPSImage* X = imageFromTensor(input); MetalCommandBuffer* commandBuffer = getCommandBuffer(input); IntArrayRef outputSize = input.sizes(); @@ -49,7 +47,7 @@ return input; } -Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { +static Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); MetalTensorImplStorage mt{outputSize.vec()}; @@ -86,8 +84,6 @@ Tensor 
leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_)); m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalNeurons.mm b/aten/src/ATen/native/metal/ops/MetalNeurons.mm index ca925d9b841b..09944092f6a1 100644 --- a/aten/src/ATen/native/metal/ops/MetalNeurons.mm +++ b/aten/src/ATen/native/metal/ops/MetalNeurons.mm @@ -9,13 +9,11 @@ #import #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { using MetalTensorImpl = at::MetalTensorImpl; -Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { +static Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); if(input.numel() == 0){ @@ -33,7 +31,7 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { return output; } -Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) { +static Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) { MPSImage* X = imageFromTensor(input); IntArrayRef outputSize = input.sizes(); if(input.numel() == 0){ @@ -52,30 +50,30 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) { } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor relu(const Tensor& input) { +static Tensor relu(const Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel(input, [MPSCNNNeuronOp relu]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor& relu_(Tensor& input) { +static Tensor& relu_(Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel_(input, [MPSCNNNeuronOp relu]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor sigmoid(const Tensor& input) { +static Tensor sigmoid(const Tensor& input) { return neuronKernel(input, [MPSCNNNeuronOp sigmoid]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor& hardsigmoid_(Tensor& input) { +static Tensor& hardsigmoid_(Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel_(input, [MPSCNNNeuronOp hardSigmoid]); } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor tanh(const Tensor& input) { +static Tensor tanh(const Tensor& input) { TORCH_CHECK(input.is_metal()); return neuronKernel(input, [MPSCNNNeuronOp tanh]); } @@ -86,8 +84,6 @@ Tensor tanh(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::relu_"), TORCH_FN(relu_)); m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), TORCH_FN(sigmoid)); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalPadding.mm b/aten/src/ATen/native/metal/ops/MetalPadding.mm index 748fa8f4b653..c924c40cc62b 100644 --- a/aten/src/ATen/native/metal/ops/MetalPadding.mm +++ b/aten/src/ATen/native/metal/ops/MetalPadding.mm @@ -9,12 +9,10 @@ #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { +static Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { TORCH_CHECK(input.is_metal()); const int pad_dim = padding.size(); @@ -87,8 +85,6 @@ Tensor reflection_pad2d(const Tensor& input, IntArrayRef padding) { TORCH_LIBRARY_IMPL(aten, Metal, m) { m.impl(TORCH_SELECTIVE_NAME("aten::reflection_pad2d"), TORCH_FN(reflection_pad2d)); -}; - -} -} } + +} // namespace 
at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalPooling.mm b/aten/src/ATen/native/metal/ops/MetalPooling.mm index 5e3b9110756e..a4d5c07f39fd 100644 --- a/aten/src/ATen/native/metal/ops/MetalPooling.mm +++ b/aten/src/ATen/native/metal/ops/MetalPooling.mm @@ -11,12 +11,10 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor max_pool2d( +static Tensor max_pool2d( const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -71,7 +69,7 @@ Tensor max_pool2d( } API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { +static Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { // averages across the width and height, and outputs a 1x1xC image. TORCH_CHECK(output_size[0] == 1 && output_size[1] == 1); TORCH_CHECK(input.is_metal()); @@ -108,6 +106,4 @@ Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) { m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_avg_pool2d"), TORCH_FN(adaptive_avg_pool2d)); } -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalReduce.mm b/aten/src/ATen/native/metal/ops/MetalReduce.mm index b0da375809b8..3de3104f6f93 100644 --- a/aten/src/ATen/native/metal/ops/MetalReduce.mm +++ b/aten/src/ATen/native/metal/ops/MetalReduce.mm @@ -11,9 +11,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.3), macos(10.13)) static inline MPSNNReduceUnary* kernelForReducedDim(int dim) { @@ -28,7 +26,7 @@ return nil; } -Tensor wrapper_mean_dim( +static Tensor wrapper_mean_dim( const Tensor& input, OptionalIntArrayRef opt_dims, bool keepdim, @@ -82,6 +80,4 @@ Tensor wrapper_mean_dim( m.impl(TORCH_SELECTIVE_NAME("aten::mean.dim"), TORCH_FN(wrapper_mean_dim)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalReshape.mm b/aten/src/ATen/native/metal/ops/MetalReshape.mm index a4336d1b92d4..de224018eb7c 100644 --- a/aten/src/ATen/native/metal/ops/MetalReshape.mm +++ b/aten/src/ATen/native/metal/ops/MetalReshape.mm @@ -11,12 +11,10 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { +static Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { auto size = C10_AS_INTARRAYREF_SLOW(sym_size); TORCH_CHECK(input.is_metal()); auto inferred_size = at::infer_size(size, input.numel()); @@ -62,12 +60,12 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { return output; } -Tensor reshape(const Tensor& input, IntArrayRef shape) { +static Tensor reshape(const Tensor& input, IntArrayRef shape) { TORCH_CHECK(input.is_metal()); return view(input, c10::fromIntArrayRefSlow(shape)); } -Tensor flatten_using_ints( +static Tensor flatten_using_ints( const Tensor& input, int64_t start_dim, int64_t end_dim) { @@ -97,7 +95,7 @@ Tensor flatten_using_ints( return input.reshape(shape); } -Tensor detach(const Tensor& input) { +static Tensor detach(const Tensor& input) { TORCH_CHECK(input.is_metal()); return input; } @@ -107,8 +105,6 @@ Tensor detach(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::view"), TORCH_FN(view)); m.impl(TORCH_SELECTIVE_NAME("aten::reshape"), TORCH_FN(reshape)); m.impl(TORCH_SELECTIVE_NAME("aten::flatten.using_ints"), 
TORCH_FN(flatten_using_ints)); -}; - -} -} } + +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm index 11ebe255953f..6ec8f60f3ae7 100644 --- a/aten/src/ATen/native/metal/ops/MetalSoftmax.mm +++ b/aten/src/ATen/native/metal/ops/MetalSoftmax.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { template Tensor mpscnn_softmax( @@ -50,14 +48,14 @@ Tensor mpscnn_softmax( return output; } -Tensor log_softmax_int( +static Tensor log_softmax_int( const Tensor& input, int64_t dim, c10::optional dtype) { return mpscnn_softmax(input, dim, dtype); } -Tensor softmax_int( +static Tensor softmax_int( const Tensor& input, int64_t dim, c10::optional dtype) { @@ -69,6 +67,4 @@ Tensor softmax_int( m.impl(TORCH_SELECTIVE_NAME("aten::softmax.int"), TORCH_FN(metal::softmax_int)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalTranspose.mm b/aten/src/ATen/native/metal/ops/MetalTranspose.mm index e1b57a2a4019..d0df9f7596e6 100644 --- a/aten/src/ATen/native/metal/ops/MetalTranspose.mm +++ b/aten/src/ATen/native/metal/ops/MetalTranspose.mm @@ -10,9 +10,7 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { // TODO: Move this function to MetalContext template @@ -24,7 +22,7 @@ return buffer; } -Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { +static Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { TORCH_CHECK(input.is_metal()); auto ndims = input.dim(); // Support maximum eight channels on mobile @@ -87,7 +85,7 @@ Tensor transpose(const Tensor& input, int64_t dim0, int64_t dim1) { } } -Tensor t(const Tensor& input) { +static Tensor t(const Tensor& input) { TORCH_CHECK(input.is_metal()); TORCH_CHECK(input.dim() == 2); return metal::transpose(input, 0, input.dim() < 2 ? 
0 : 1); @@ -98,6 +96,4 @@ Tensor t(const Tensor& input) { m.impl(TORCH_SELECTIVE_NAME("aten::transpose.int"), TORCH_FN(transpose)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm index 39524569bae5..165e139c886d 100644 --- a/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm +++ b/aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm @@ -11,11 +11,9 @@ #include #include -namespace at { -namespace native { -namespace metal { +namespace at::native::metal { -Tensor upsample_nearest2d_vec( +static Tensor upsample_nearest2d_vec( const Tensor& input, at::OptionalIntArrayRef output_size, c10::optional> scale_factors) { @@ -70,6 +68,4 @@ Tensor upsample_nearest2d_vec( m.impl(TORCH_SELECTIVE_NAME("aten::upsample_nearest2d.vec"), TORCH_FN(upsample_nearest2d_vec)); }; -} -} -} +} // namespace at::native::metal diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 09dca06e2b5a..643bd7eed0a2 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -27,53 +27,7 @@ Tensor mkldnn_convolution( TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support"); } -static Tensor mkldnn_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_backward( - const Tensor& input, const Tensor& grad_output_t, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array output_mask) { - TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support"); -} - REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub); - -static Tensor mkldnn_convolution_transpose( - const Tensor& input, const Tensor& weight, const std::optional& bias_opt, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) { - TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support"); -} - -static Tensor mkldnn_convolution_transpose_backward_input( - IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_transpose_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool bias_defined) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support"); -} - -static std::tuple mkldnn_convolution_transpose_backward( - const Tensor& input, const Tensor& 
grad_output_t, const Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, std::array output_mask) { - TORCH_CHECK(false, "mkldnn_convolution_transpose_backward: ATen not compiled with MKLDNN support"); -} - REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub); REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub); diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 47dbe792d73a..6ed703c3b5fd 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -14,6 +14,7 @@ #include #include #endif +#include #if !AT_MKLDNN_ENABLED() @@ -37,7 +38,7 @@ std::tuple mkldnn_batch_norm_backward( TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support"); } -static std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( +std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( const Tensor& input, IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, double eps, bool inplace) { @@ -81,7 +82,6 @@ std::tuple _new_batch_norm_backward_mkldnn( #else // AT_MKLDNN_ENABLED #include -#include #include #include diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index 75f1b2c1b709..a63d9ebfa2c1 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -36,7 +36,7 @@ void check_mkldnn_binary_fusion_inputs( const Tensor& weight, const Tensor& bias); -static inline std::vector padding_r( +inline std::vector padding_r( IntArrayRef padding, IntArrayRef output_padding) { // ConvTranpose padding adjustment @@ -60,7 +60,7 @@ static inline std::vector padding_r( // Make sure input has default contiguous strides if it's contiguous tensors for better performance. // For example, for tensor of size = [1, 1280], stride = [0, 1], we'll convert it to size = [1, 1280], stride = [1280, 1] // before calling oneDNN for better performance. 
-static inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) { +inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) { auto input_size = input.sizes().vec(); auto input_stride = input.strides().vec(); auto input_default_contiguous_strides = c10::contiguous_strides(input_size); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h index c7e7a5e94b40..afef4552c153 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -12,7 +12,7 @@ namespace at::native::onednn { -TORCH_API dnnl::memory make_onednn_memory( +TORCH_XPU_API dnnl::memory make_onednn_memory( dnnl::memory::desc md, dnnl::engine& engine, void* ptr); @@ -21,7 +21,7 @@ TORCH_API dnnl::memory make_onednn_memory( bool set_onednn_verbose(int level); // GpuEngineManager singleton -struct TORCH_API GpuEngineManager { +struct TORCH_XPU_API GpuEngineManager { static GpuEngineManager& Instance(); // Singleton dnnl::engine& get_engine(const Device& device) { @@ -51,7 +51,7 @@ struct TORCH_API GpuEngineManager { }; // GpuStreamManager singleton -struct TORCH_API GpuStreamManager { +struct TORCH_XPU_API GpuStreamManager { static GpuStreamManager& Instance(); // Singleton dnnl::stream get_stream() { diff --git a/aten/src/ATen/native/mps/UnaryConstants.h b/aten/src/ATen/native/mps/UnaryConstants.h index 4adf1d0e333e..b1a92f688d12 100644 --- a/aten/src/ATen/native/mps/UnaryConstants.h +++ b/aten/src/ATen/native/mps/UnaryConstants.h @@ -18,26 +18,21 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]], /* coefficients in rational expansion */ float y_abs = abs(y); - if(y_abs > 1.0f){{ - output[index] = NAN; + if (y_abs >= 1.0f) {{ + output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y)); return; }} - if(y_abs == 1.0f){{ - output[index] = copysign(INFINITY, y); - return; - }} - if(y_abs <= 0.7f) {{ + if (y_abs <= 0.7f) {{ z = y * y; - num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); - dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f); + num = ((a[3] * z + a[2]) * z + a[1])*z + a[0]; + dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f; x = y * num / dem; - }} - else{{ + }} else {{ z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); - num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; - dem = (d[1]*z + d[0])*z + 1.0f; + num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; + dem = (d[1] * z + d[0]) * z + 1.0f; x = copysign(num, y) / dem; }} - output[index] = x; -}})METAL"; \ No newline at end of file + output[index] = {0}(x); +}})METAL"; diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index da11401c948d..741789c7eac9 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -143,7 +143,7 @@ Tensor relu_mps(const Tensor& self) { Tensor output_ = at::empty_like(self, executeGatherOp ? 
MemoryFormat::Contiguous : MemoryFormat::Preserve); @autoreleasepool { - string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to()); + string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -193,8 +193,8 @@ Tensor relu_mps(const Tensor& self) { Tensor output_ = at::empty_like(self, self.suggest_memory_format()); @autoreleasepool { - string key = - "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to()); + string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + + std::to_string(negative_slope.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); @@ -242,7 +242,7 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim); + string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -285,7 +285,7 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim); + string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output)); MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); @@ -539,8 +539,8 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to()) + ":" + - to_string(value.to()); + string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to()) + + ":" + std::to_string(value.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -587,7 +587,7 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c @autoreleasepool { string key = - "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to()); + "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to()) + ":" + - 
to_string(scale.to()) + ":" + to_string(input_scale.to()); + string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to()) + ":" + + std::to_string(scale.to()) + ":" + std::to_string(input_scale.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); @@ -916,8 +916,8 @@ static void elu_variants_out_mps(const Tensor& self, @autoreleasepool { string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - to_string(alpha.to()) + ":" + to_string(scale.to()) + ":" + - to_string(input_scale.to()) + ":" + to_string(is_result); + std::to_string(alpha.to()) + ":" + std::to_string(scale.to()) + ":" + + std::to_string(input_scale.to()) + ":" + std::to_string(is_result); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); @@ -1010,7 +1010,7 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim); + string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); NSArray* outputTensorsArray = [mpsGraph splitTensor:inputTensor @@ -1052,7 +1052,7 @@ static void elu_variants_out_mps(const Tensor& self, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim); + string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); MPSGraphTensor* gradOutputTensor = @@ -1855,8 +1855,8 @@ Tensor hardtanh_backward_mps(const Tensor& grad_output, const Tensor& self, cons MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to()) + - ":" + to_string(max.to()); + string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + + std::to_string(min.to()) + ":" + std::to_string(max.to()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 1714a8e7e2f8..25cc732c1e62 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -136,8 +136,8 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) { Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1); @autoreleasepool { - string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) + - ":" + to_string(alpha_.toDouble()); + string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + + std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble()); auto cachedGraph 
= LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 2e7d0881bb60..353978547186 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -33,7 +33,7 @@ }; @autoreleasepool { - string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble()); + string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index fbf5a67262be..08ad620a2028 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, string bias_shape_key; if (bias_defined) { - bias_shape_key = to_string(bias_shape[0]); + bias_shape_key = std::to_string(bias_shape[0]); } else { bias_shape_key = "nobias"; } string key; if (is3DConv) { - key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) + - ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" + - to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) + - ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" + - bias_shape_key; + key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + + mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; } else { - key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + - ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + - to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + - to_string(bias_defined) + ":" + bias_shape_key; + key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + + mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; } MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); @@ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key; if (is3DConv) { - key = 
"mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" + - to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + - ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + - to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" + - string([ns_shape_key UTF8String]); + key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + ":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + + getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); } else { - key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + - to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); } auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { @@ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key; if (is3DConv) { - key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + - to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + - to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + + std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + + std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); } else { - key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + - to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + - to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + + key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + + std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); } auto 
cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 7ed06c8bf437..303a7bda99f7 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -63,7 +63,7 @@ @autoreleasepool { string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" + - to_string(val1) + ":" + to_string(val2); + std::to_string(val1) + ":" + std::to_string(val2); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); @@ -469,7 +469,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(key, [&](auto mpsGraph, auto newCachedGraph) { MPSShape* prob_shape = getMPSShape(self_v); newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm index 1b6e650f51d4..a9ac70110617 100644 --- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm +++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm @@ -1,5 +1,6 @@ #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 6686c2bed06e..fc8253e341f2 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -236,7 +236,7 @@ static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& g MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + + string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" + getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index e0db2c1e8b9b..25405cf4d395 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -229,8 +229,8 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) @autoreleasepool { string key = (opType == ADDBMM_OP_TYPE) ? 
("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl"); - key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" + - to_string(alpha.toDouble()); + key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" + + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -331,8 +331,8 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) }; @autoreleasepool { - string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) + - ":" + to_string(alpha.toDouble()); + string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + + std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* selfTensor = nil; MPSGraphTensor* otherTensor = nil; @@ -615,8 +615,8 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons }; @autoreleasepool { - string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) + - ":" + to_string(alpha.toDouble()); + string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + + std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 3e58d2ca8a4b..65540c770db4 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -69,7 +69,7 @@ static string reductionToString(int64_t reduction) { }; @autoreleasepool { - string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) + + string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) + getTensorsStringKey({input, target, grad_output}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg, } @autoreleasepool { string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) + - to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + - to_string(isTargetCasted) + ":" + reductionToString(reduction); + std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + + ":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); @@ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output, NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; // TODO: Make the key - string key = "nllnd_loss_forward_impl:" 
+ to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + - reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + - getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight); + string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + + ":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + + getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); @@ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input, NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + - to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); + std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { // smooth_l1_loss_mps: // ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta @@ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output, @autoreleasepool { string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" + - reductionToString(reduction) + ":" + to_string(beta); + reductionToString(reduction) + ":" + std::to_string(beta); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 137c14be6ef4..9010dd3add24 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -38,6 +38,19 @@ static void addc_mul_div_out_mps(const Tensor& self, }; @autoreleasepool { + bool executeGatherOpOnSelf = + !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || + self.is_contiguous(MemoryFormat::ChannelsLast3d)); + Tensor output_ = at::empty_like(self, executeGatherOpOnSelf ? 
MemoryFormat::Contiguous : MemoryFormat::Preserve); + + bool executeGatherOpOnFirstTensor = + !(tensor1.is_contiguous(MemoryFormat::Contiguous) || tensor1.is_contiguous(MemoryFormat::ChannelsLast) || + tensor1.is_contiguous(MemoryFormat::ChannelsLast3d)); + + bool executeGatherOpOnSecondTensor = + !(tensor2.is_contiguous(MemoryFormat::Contiguous) || tensor2.is_contiguous(MemoryFormat::ChannelsLast) || + tensor2.is_contiguous(MemoryFormat::ChannelsLast3d)); + string key = op_name + getTensorsStringKey({self, tensor1, tensor2}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { @@ -72,10 +85,12 @@ static void addc_mul_div_out_mps(const Tensor& self, }); // Inputs as placeholders - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self); - Placeholder tensor1Placeholder = Placeholder(cachedGraph->firstTensor, tensor1); - Placeholder tensor2Placeholder = Placeholder(cachedGraph->secondTensor, tensor2); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor, self, nil, executeGatherOpOnSelf); + Placeholder tensor1Placeholder = Placeholder(cachedGraph->firstTensor, tensor1, nil, executeGatherOpOnFirstTensor); + Placeholder tensor2Placeholder = + Placeholder(cachedGraph->secondTensor, tensor2, nil, executeGatherOpOnSecondTensor); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor, executeGatherOpOnSelf ? output_ : output, nil, false); MPSScalar value_scalar = getMPSScalar(value_opt, self.scalar_type()); // Create dictionary of inputs and outputs @@ -87,6 +102,10 @@ static void addc_mul_div_out_mps(const Tensor& self, }; runMPSGraph(mpsStream, cachedGraph->graph(), feeds, outputPlaceholder); + + if (executeGatherOpOnSelf) { + output.copy_(output_); + } } } diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm index 102c54c251db..e558cb1d0d15 100644 --- a/aten/src/ATen/native/mps/operations/RangeFactories.mm +++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm @@ -106,7 +106,7 @@ auto stream = getCurrentMPSStream(); auto mpsDataType = getMPSDataType(result); @autoreleasepool { - string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); + string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { @@ -173,7 +173,7 @@ auto stream = getCurrentMPSStream(); auto mpsDataType = getMPSDataType(result); @autoreleasepool { - string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); + string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { @@ -221,8 +221,8 @@ bool start_less_end = (start.to() <= end.to()); @autoreleasepool { - string key = - "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end); + string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) + + std::to_string(start_less_end); auto cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 416c83f0d3b3..b5ebd959932d 
100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor, NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); - string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + - keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData); + string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" + + keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); @@ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; - string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; + string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0"; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value); @@ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t, auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_); + string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; @@ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t, @autoreleasepool { NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; string key = - func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); + func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputScalarType = input_t.scalar_type(); MPSGraphTensor* inputTensor = @@ -1217,7 +1217,7 @@ Tensor std_mps(const Tensor& input_t, @autoreleasepool { MPSShape* input_t_shape = getMPSShape(input_t); - string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + getMPSTypeString(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSDataType input_type = getMPSDataType(input_t); @@ -1313,7 +1313,7 @@ Tensor std_mps(const Tensor& input_t, @autoreleasepool { MPSShape* 
input_t_shape = getMPSShape(input_t); - string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + + string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + getMPSTypeString(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSDataType input_type = getMPSDataType(input_t); @@ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t, auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = - func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); + string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + + getTensorsStringKey(indices_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castInputTensor = diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 135041be1f41..c32553094855 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -108,8 +108,8 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) + - ":dim" + to_string(dim_) + ":largest" + to_string(largest); + string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) + + ":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); @@ -320,12 +320,12 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in }; @autoreleasepool { - string key = - "cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); + string key = "cat_out_mps:" + std::to_string(dimension) + ":" + + (memory_format == MemoryFormat::ChannelsLast ? 
"NHWC" : "NCHW"); if (!all_same_dtype) { key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); } else { - key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size()); + key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size()); } for (auto idx : skipped_tensor_indices) { key += "," + std::to_string(idx); diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index e3ee85cfe230..5b94240846da 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -60,8 +60,8 @@ // Input as placeholders MPSShape* input_shape = getMPSShape(self); NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) + - ":descending" + to_string(descending); + string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + + std::to_string(dim) + ":descending" + std::to_string(descending); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 4da5c302214d..6f8bfff53b8c 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t, @autoreleasepool { // the optional min/max refs could affect how we build the cached graph - string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + - (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); + string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") + + (has_max ? 
("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { if (has_min) newCachedGraph->minTensor = [mpsGraph diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm index 540fc6a26cd8..5c894efb89fd 100644 --- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm @@ -13,32 +13,6 @@ #include namespace at::native { -static const std::string& getMetalType(const c10::ScalarType& t) { - // Mapping from c10::ScalarType to integral type that can be used for unary ops - static std::unordered_map scalar_to_metal_type = { - {c10::ScalarType::Half, "half"}, - {c10::ScalarType::Float, "float"}, - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Bool, "bool"}, - {c10::ScalarType::Char, "int8_t"}, - {c10::ScalarType::Byte, "uint8_t"}, - }; - - auto it = scalar_to_metal_type.find(t); - TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t); - return it->second; -} - -static const std::string& getMetalType(const c10::Scalar& s) { - return getMetalType(s.type()); -} - -static const std::string& getMetalType(const Tensor& t) { - return getMetalType(t.scalar_type()); -} - static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2); TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) { @@ -57,7 +31,8 @@ } using namespace mps; @autoreleasepool { - auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)}); + auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", + {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)}); if (!self.is_contiguous()) { inputTensor = inputTensor.contiguous(); diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm index fc30c2d0b797..a9948183b04c 100644 --- a/aten/src/ATen/native/mps/operations/Unique.mm +++ b/aten/src/ATen/native/mps/operations/Unique.mm @@ -36,8 +36,8 @@ const bool consecutive, c10::optional dimOpt) { return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" + - (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" + - to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; + (dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" + + std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]"; } // dim arg not supported when non consecutive, ie sorted diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index f4973f600015..fca71ed346c5 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input, @autoreleasepool { string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + - getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + + getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" + (is_backward_pass ? 
getArrayRefString(input_size) : "Undefined") + "]"; auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index b583a19ef5e6..ae530ad12bde 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -42,7 +42,7 @@ } return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" + - getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; + getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]"; } // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e396ccb67672..7970e17eb960 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5426,7 +5426,7 @@ autogen: slice_backward.out # NB: This op exists to back the implementation of reverse view_funcs for various views (chunk, -# slice.Tensor, split_with_sizes, et. al.). Currently, these are only used during fake-ification +# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification # of PT2 graph input subclass instances that are views. This means: # * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it) # * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it) @@ -10323,14 +10323,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ autogen: _foreach_add.Scalar_out @@ -10338,14 +10338,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ autogen: _foreach_add.List_out @@ -10353,14 +10353,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow CUDA: foreach_tensor_add_scalarlist_kernel_cuda - func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on 
different devices variants: function dispatch: - CPU: foreach_tensor_add_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ autogen: _foreach_add.ScalarList_out @@ -10368,14 +10368,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow CUDA: foreach_tensor_add_tensor_kernel_cuda - func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_add_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ autogen: _foreach_add.Tensor_out @@ -10383,14 +10383,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow CUDA: foreach_tensor_sub_scalar_kernel_cuda - func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ autogen: _foreach_sub.Scalar_out @@ -10398,14 +10398,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow CUDA: foreach_tensor_sub_list_kernel_cuda - func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ autogen: _foreach_sub.List_out @@ -10413,14 +10413,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow CUDA: foreach_tensor_sub_scalarlist_kernel_cuda - func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sub_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ autogen: _foreach_sub.ScalarList_out @@ -10428,14 +10428,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow CUDA: foreach_tensor_mul_scalar_kernel_cuda - func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow 
path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ autogen: _foreach_mul.Scalar_out @@ -10443,14 +10443,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ autogen: _foreach_mul.List_out @@ -10458,14 +10458,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow CUDA: foreach_tensor_mul_scalarlist_kernel_cuda - func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ autogen: _foreach_mul.ScalarList_out @@ -10473,14 +10473,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_mul_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ autogen: _foreach_mul.Tensor_out @@ -10488,14 +10488,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow CUDA: foreach_tensor_div_scalar_kernel_cuda - func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ autogen: _foreach_div.Scalar_out @@ -10503,14 +10503,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow CUDA: foreach_tensor_div_list_kernel_cuda - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different 
devices variants: function dispatch: - CPU: foreach_tensor_div_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ autogen: _foreach_div.List_out @@ -10518,14 +10518,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow CUDA: foreach_tensor_div_scalarlist_kernel_cuda - func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ autogen: _foreach_div.ScalarList_out @@ -10533,14 +10533,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_tensor_kernel_slow + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow CUDA: foreach_tensor_div_tensor_kernel_cuda - func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_div_tensor_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_ CUDA: foreach_tensor_div_tensor_kernel_cuda_ autogen: _foreach_div.Tensor_out @@ -10548,14 +10548,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda - func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ autogen: _foreach_clamp_max.Scalar_out @@ -10563,14 +10563,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda - func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ autogen: _foreach_clamp_max.List_out @@ -10578,14 +10578,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda - func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] 
scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ autogen: _foreach_clamp_max.ScalarList_out @@ -10593,14 +10593,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda - func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ autogen: _foreach_clamp_min.Scalar_out @@ -10608,14 +10608,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda - func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ autogen: _foreach_clamp_min.List_out @@ -10623,14 +10623,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda - func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ autogen: _foreach_clamp_min.ScalarList_out @@ -10639,14 +10639,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda - func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ autogen: _foreach_maximum.Scalar_out @@ -10655,14 +10655,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: 
foreach_tensor_clamp_min_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda - func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ autogen: _foreach_maximum.List_out @@ -10671,14 +10671,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda - func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ autogen: _foreach_maximum.ScalarList_out @@ -10686,14 +10686,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda - func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ autogen: _foreach_minimum.Scalar_out @@ -10701,14 +10701,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda - func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ autogen: _foreach_minimum.List_out @@ -10716,14 +10716,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda - func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: 
foreach_tensor_clamp_max_scalarlist_kernel_cuda_ autogen: _foreach_minimum.ScalarList_out @@ -10731,28 +10731,28 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalar_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow CUDA: foreach_tensor_addcdiv_scalar_cuda - func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalarlist_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda - func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_tensor_slow + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow CUDA: foreach_tensor_addcdiv_tensor_cuda - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalar_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ autogen: _foreach_addcdiv.Scalar_out @@ -10760,7 +10760,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_scalarlist_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ autogen: _foreach_addcdiv.ScalarList_out @@ -10768,7 +10768,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcdiv_tensor_slow_ + CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow_ CUDA: foreach_tensor_addcdiv_tensor_cuda_ autogen: _foreach_addcdiv.Tensor_out @@ -10776,28 +10776,28 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalar_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalarlist_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda - func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_tensor_slow + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow CUDA: foreach_tensor_addcmul_tensor_cuda - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall 
back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalar_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ autogen: _foreach_addcmul.Scalar_out @@ -10805,7 +10805,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_scalarlist_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ autogen: _foreach_addcmul.ScalarList_out @@ -10813,7 +10813,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_addcmul_tensor_slow_ + CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow_ CUDA: foreach_tensor_addcmul_tensor_cuda_ autogen: _foreach_addcmul.Tensor_out @@ -10821,14 +10821,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_abs_slow + CompositeExplicitAutograd: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_abs_slow_ + CompositeExplicitAutograd: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ autogen: _foreach_abs.out @@ -10836,14 +10836,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_acos_slow + CompositeExplicitAutograd: foreach_tensor_acos_slow CUDA: foreach_tensor_acos_cuda - func: _foreach_acos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_acos_slow_ + CompositeExplicitAutograd: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ autogen: _foreach_acos.out @@ -10851,14 +10851,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_asin_slow + CompositeExplicitAutograd: foreach_tensor_asin_slow CUDA: foreach_tensor_asin_cuda - func: _foreach_asin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_asin_slow_ + CompositeExplicitAutograd: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ autogen: _foreach_asin.out @@ -10866,14 +10866,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_atan_slow + CompositeExplicitAutograd: foreach_tensor_atan_slow CUDA: foreach_tensor_atan_cuda - func: _foreach_atan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_atan_slow_ + CompositeExplicitAutograd: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ autogen: _foreach_atan.out @@ -10881,14 +10881,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_ceil_slow + 
CompositeExplicitAutograd: foreach_tensor_ceil_slow CUDA: foreach_tensor_ceil_cuda - func: _foreach_ceil_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_ceil_slow_ + CompositeExplicitAutograd: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ autogen: _foreach_ceil.out @@ -10896,14 +10896,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cos_slow + CompositeExplicitAutograd: foreach_tensor_cos_slow CUDA: foreach_tensor_cos_cuda - func: _foreach_cos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cos_slow_ + CompositeExplicitAutograd: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ autogen: _foreach_cos.out @@ -10911,14 +10911,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cosh_slow + CompositeExplicitAutograd: foreach_tensor_cosh_slow CUDA: foreach_tensor_cosh_cuda - func: _foreach_cosh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_cosh_slow_ + CompositeExplicitAutograd: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ autogen: _foreach_cosh.out @@ -10926,14 +10926,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erf_slow + CompositeExplicitAutograd: foreach_tensor_erf_slow CUDA: foreach_tensor_erf_cuda - func: _foreach_erf_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erf_slow_ + CompositeExplicitAutograd: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ autogen: _foreach_erf.out @@ -10941,14 +10941,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erfc_slow + CompositeExplicitAutograd: foreach_tensor_erfc_slow CUDA: foreach_tensor_erfc_cuda - func: _foreach_erfc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_erfc_slow_ + CompositeExplicitAutograd: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ autogen: _foreach_erfc.out @@ -10956,14 +10956,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_exp_slow + CompositeExplicitAutograd: foreach_tensor_exp_slow CUDA: foreach_tensor_exp_cuda - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_exp_slow_ + CompositeExplicitAutograd: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ autogen: _foreach_exp.out @@ -10971,14 +10971,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_expm1_slow + 
CompositeExplicitAutograd: foreach_tensor_expm1_slow CUDA: foreach_tensor_expm1_cuda - func: _foreach_expm1_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_expm1_slow_ + CompositeExplicitAutograd: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ autogen: _foreach_expm1.out @@ -10986,14 +10986,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_floor_slow + CompositeExplicitAutograd: foreach_tensor_floor_slow CUDA: foreach_tensor_floor_cuda - func: _foreach_floor_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_floor_slow_ + CompositeExplicitAutograd: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ autogen: _foreach_floor.out @@ -11001,14 +11001,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_frac_slow + CompositeExplicitAutograd: foreach_tensor_frac_slow CUDA: foreach_tensor_frac_cuda - func: _foreach_frac_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_frac_slow_ + CompositeExplicitAutograd: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ autogen: _foreach_frac.out @@ -11016,7 +11016,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_ternary_lerp_slow + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow CUDA: foreach_tensor_lerp_ternary_cuda autogen: _foreach_lerp.List_out @@ -11024,7 +11024,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_ternary_lerp_slow_ + CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow_ CUDA: foreach_tensor_lerp_ternary_cuda_ autogen: _foreach_lerp.List_out @@ -11032,7 +11032,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_lerp_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow CUDA: foreach_tensor_lerp_list_cuda autogen: _foreach_lerp.Scalar_out @@ -11040,7 +11040,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices variants: function dispatch: - CPU: foreach_tensor_lerp_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow_ CUDA: foreach_tensor_lerp_list_cuda_ autogen: _foreach_lerp.Scalar_out @@ -11048,14 +11048,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_lgamma_slow + CompositeExplicitAutograd: foreach_tensor_lgamma_slow CUDA: foreach_tensor_lgamma_cuda - func: _foreach_lgamma_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_lgamma_slow_ + CompositeExplicitAutograd: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ autogen: _foreach_lgamma.out 
@@ -11063,14 +11063,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log_slow + CompositeExplicitAutograd: foreach_tensor_log_slow CUDA: foreach_tensor_log_cuda - func: _foreach_log_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log_slow_ + CompositeExplicitAutograd: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ autogen: _foreach_log.out @@ -11078,14 +11078,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log10_slow + CompositeExplicitAutograd: foreach_tensor_log10_slow CUDA: foreach_tensor_log10_cuda - func: _foreach_log10_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log10_slow_ + CompositeExplicitAutograd: foreach_tensor_log10_slow_ CUDA: foreach_tensor_log10_cuda_ autogen: _foreach_log10.out @@ -11093,14 +11093,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log1p_slow + CompositeExplicitAutograd: foreach_tensor_log1p_slow CUDA: foreach_tensor_log1p_cuda - func: _foreach_log1p_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log1p_slow_ + CompositeExplicitAutograd: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ autogen: _foreach_log1p.out @@ -11108,29 +11108,37 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log2_slow + CompositeExplicitAutograd: foreach_tensor_log2_slow CUDA: foreach_tensor_log2_cuda - func: _foreach_log2_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_log2_slow_ + CompositeExplicitAutograd: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ autogen: _foreach_log2.out +- func: _foreach_max(Tensor[] self) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_max_slow + CUDA: foreach_tensor_max_cuda + autogen: _foreach_max.out + - func: _foreach_neg(Tensor[] self) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_neg_slow + CompositeExplicitAutograd: foreach_tensor_neg_slow CUDA: foreach_tensor_neg_cuda - func: _foreach_neg_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_neg_slow_ + CompositeExplicitAutograd: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ autogen: _foreach_neg.out @@ -11138,7 +11146,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_norm_slow + CompositeExplicitAutograd: foreach_tensor_norm_slow CUDA: 
foreach_tensor_norm_cuda autogen: _foreach_norm.Scalar_out @@ -11146,35 +11154,35 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_list_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow CUDA: foreach_tensor_pow_list_kernel_cuda - func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_scalar_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow CUDA: foreach_tensor_pow_scalar_kernel_cuda - func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_pow_scalarlist_kernel_slow + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow CUDA: foreach_tensor_pow_scalarlist_kernel_cuda - func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_scalar_pow_list_kernel_slow + CompositeExplicitAutograd: foreach_scalar_pow_list_kernel_slow CUDA: foreach_scalar_pow_list_kernel_cuda - func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow_ CUDA: foreach_tensor_pow_list_kernel_cuda_ autogen: _foreach_pow.List_out @@ -11182,7 +11190,7 @@ device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_scalar_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow_ CUDA: foreach_tensor_pow_scalar_kernel_cuda_ autogen: _foreach_pow.Scalar_out @@ -11190,7 +11198,7 @@ device_check: NoCheck variants: function dispatch: - CPU: foreach_tensor_pow_scalarlist_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow_ CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ autogen: _foreach_pow.ScalarList_out @@ -11198,14 +11206,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_reciprocal_slow + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow CUDA: foreach_tensor_reciprocal_cuda - func: _foreach_reciprocal_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_reciprocal_slow_ + CompositeExplicitAutograd: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ autogen: _foreach_reciprocal.out @@ -11213,14 +11221,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_round_slow + CompositeExplicitAutograd: foreach_tensor_round_slow CUDA: foreach_tensor_round_cuda - func: _foreach_round_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_round_slow_ + CompositeExplicitAutograd: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ autogen: 
_foreach_round.out @@ -11228,14 +11236,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sigmoid_slow + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow CUDA: foreach_tensor_sigmoid_cuda - func: _foreach_sigmoid_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sigmoid_slow_ + CompositeExplicitAutograd: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ autogen: _foreach_sigmoid.out @@ -11243,14 +11251,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sign_slow + CompositeExplicitAutograd: foreach_tensor_sign_slow CUDA: foreach_tensor_sign_cuda - func: _foreach_sign_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sign_slow_ + CompositeExplicitAutograd: foreach_tensor_sign_slow_ CUDA: foreach_tensor_sign_cuda_ autogen: _foreach_sign.out @@ -11258,14 +11266,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sin_slow + CompositeExplicitAutograd: foreach_tensor_sin_slow CUDA: foreach_tensor_sin_cuda - func: _foreach_sin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sin_slow_ + CompositeExplicitAutograd: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ autogen: _foreach_sin.out @@ -11273,14 +11281,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sinh_slow + CompositeExplicitAutograd: foreach_tensor_sinh_slow CUDA: foreach_tensor_sinh_cuda - func: _foreach_sinh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sinh_slow_ + CompositeExplicitAutograd: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ autogen: _foreach_sinh.out @@ -11288,14 +11296,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sqrt_slow + CompositeExplicitAutograd: foreach_tensor_sqrt_slow CUDA: foreach_tensor_sqrt_cuda - func: _foreach_sqrt_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_sqrt_slow_ + CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ autogen: _foreach_sqrt.out @@ -11303,14 +11311,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tan_slow + CompositeExplicitAutograd: foreach_tensor_tan_slow CUDA: foreach_tensor_tan_cuda - func: _foreach_tan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tan_slow_ + CompositeExplicitAutograd: foreach_tensor_tan_slow_ CUDA: 
foreach_tensor_tan_cuda_ autogen: _foreach_tan.out @@ -11318,14 +11326,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tanh_slow + CompositeExplicitAutograd: foreach_tensor_tanh_slow CUDA: foreach_tensor_tanh_cuda - func: _foreach_tanh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_tanh_slow_ + CompositeExplicitAutograd: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ autogen: _foreach_tanh.out @@ -11333,14 +11341,14 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_trunc_slow + CompositeExplicitAutograd: foreach_tensor_trunc_slow CUDA: foreach_tensor_trunc_cuda - func: _foreach_trunc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_trunc_slow_ + CompositeExplicitAutograd: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ autogen: _foreach_trunc.out @@ -11348,7 +11356,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_zero_slow_ + CompositeExplicitAutograd: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ autogen: _foreach_zero, _foreach_zero.out @@ -11356,7 +11364,7 @@ device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function dispatch: - CPU: foreach_tensor_copy_list_kernel_slow_ + CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ autogen: _foreach_copy.out @@ -14636,6 +14644,16 @@ NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda autogen: to_padded_tensor.out +- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_jagged_to_padded_dense_forward + +- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + variants: function + dispatch: + CUDA: _fbgemm_dense_to_jagged_forward_symint + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout @@ -14710,12 +14728,12 @@ CUDA: _scaled_dot_product_efficient_attention_backward_cuda tags: nondeterministic_seeded -- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset) dispatch: CUDA: _scaled_dot_product_cudnn_attention_cuda tags: nondeterministic_seeded -- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor) +- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda tags: nondeterministic_seeded @@ -14733,13 +14751,13 @@ CUDA: _flash_attention_backward # Returns output, logsumexp if compute_logsumexp -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) variants: function dispatch: CUDA: _efficient_attention_forward tags: nondeterministic_seeded -- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None) -> (Tensor, Tensor, Tensor, Tensor) +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? 
window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) device_check: NoCheck variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 56cac2a89803..c425cf504dc9 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -1,3 +1,4 @@ +#include #include #include @@ -11,6 +12,7 @@ #include #include +#include #include #include @@ -462,5 +464,1084 @@ template void add_padding_kernelLauncher( const int batch_size, const int output_batch_size); +// NB: The following code covers jagged <-> padded dense conversions and was lifted +// from fbgemm_gpu. For more details, see +// https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/src/jagged_tensor_ops + +// Passing lambda exp argument by value instead of by reference to avoid +// "internal compiler error: in maybe_undo_parenthesized_ref" error for specific +// compiler version. +#define JAGGED_TENSOR_DISPATCH_DIMS() \ + AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \ + switch (num_jagged_dim) { \ + case 1: \ + INVOKE_KERNEL_WITH_DIM(1); \ + break; \ + case 2: \ + INVOKE_KERNEL_WITH_DIM(2); \ + break; \ + case 3: \ + INVOKE_KERNEL_WITH_DIM(3); \ + break; \ + case 4: \ + INVOKE_KERNEL_WITH_DIM(4); \ + break; \ + case 5: \ + INVOKE_KERNEL_WITH_DIM(5); \ + break; \ + default: \ + TORCH_CHECK( \ + false, "unsupported number of jagged dim ", num_jagged_dim); \ + } \ + }); + +inline std::string torch_tensor_device_name(const at::Tensor& ten) { + return c10::DeviceTypeName(ten.device().type()); +} + +inline std::string torch_tensor_device_name( + const c10::optional& ten) { + if (ten.has_value()) { + return torch_tensor_device_name(ten.value()); + } else { + return "N/A"; + } +} + +inline bool torch_tensor_on_cuda_gpu_check(const at::Tensor& ten) { + return ten.is_cuda(); +} + +inline bool torch_tensor_on_cuda_gpu_check( + const c10::optional& ten) { + return !ten.has_value() || torch_tensor_on_cuda_gpu_check(ten.value()); +} + +#define TENSOR_ON_CUDA_GPU(x) \ + TORCH_CHECK( \ + torch_tensor_on_cuda_gpu_check(x), \ + #x " must be a CUDA tensor; it is currently on device ", \ + torch_tensor_device_name(x)) + +// A wrapper class for passing dynamically sized dimension information (e.g. +// tensor.dims()) from the host to device. 
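To make the jagged <-> padded dense conversion implemented by the kernels below concrete, here is a small illustrative Python sketch of the forward mapping for the single-jagged-dimension case. It is a hedged reference description of the semantics only; the function and variable names (`jagged_to_padded_dense_1d`, `values`, `offsets`) are local to the sketch and are not part of the fbgemm_gpu CUDA implementation added in this diff.

```python
import torch

# Reference semantics for one jagged dimension: `values` holds all rows back to
# back, offsets[b]:offsets[b+1] delimits batch b's rows, and the output is a
# (B, max_length, D) dense tensor padded with `padding_value`.
def jagged_to_padded_dense_1d(values, offsets, max_length, padding_value=0.0):
    B = offsets.numel() - 1
    D = values.size(-1)
    out = values.new_full((B, max_length, D), padding_value)
    for b in range(B):
        start, end = offsets[b].item(), offsets[b + 1].item()
        length = min(end - start, max_length)  # rows longer than max_length are truncated
        out[b, :length] = values[start:start + length]
    return out

values = torch.arange(12.0).reshape(6, 2)   # 6 jagged rows, inner dim 2
offsets = torch.tensor([0, 2, 3, 6])        # batch lengths 2, 1, 3
padded = jagged_to_padded_dense_1d(values, offsets, max_length=3)
```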
+constexpr size_t kStackArrayMaxDims = 5; + +template +struct StackArray { + T vals[kStackArrayMaxDims]; + size_t ndim; +}; + +// Warp size +#ifdef USE_ROCM +static constexpr int32_t kWarpSize = 64; +#else +static constexpr int32_t kWarpSize = 32; +#endif +// Max thread num in one thread block +static constexpr int32_t kMaxThreads = 1024; + +#define DEVICE_INLINE __device__ C10_ALWAYS_INLINE + +__host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { + return (a + b - 1) / b; +} + +__host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) { + return a / b * b; +} + +inline std::tuple> check_shape_and_partition_( + const Tensor& values, + const std::vector& offsets, + const Tensor& dense_tensor) { + const int outer_dense_size = dense_tensor.size(0); + TORCH_CHECK( + outer_dense_size == offsets[0].numel() - 1, + "outer_dense_size, ", + outer_dense_size, + " != offsets[0].numel() - 1, ", + offsets[0].numel() - 1); + const int inner_dense_size = dense_tensor.size(-1); + TORCH_CHECK( + inner_dense_size == values.size(-1), + "inner_dense_size, ", + inner_dense_size, + " != values.size(-1), ", + values.size(-1)); + const int jagged_folded_size = + dense_tensor.numel() / (outer_dense_size * inner_dense_size); + + const int threads_x = + inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size; + const int threads_y = kMaxThreads / kWarpSize; + const dim3 blocks( + div_round_up(outer_dense_size * jagged_folded_size, threads_y)); + + StackArray jagged_dims_tensor; + const int num_jagged_dim = dense_tensor.dim() - 2; + TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + jagged_dims_tensor.ndim = num_jagged_dim; + std::memcpy( + &(jagged_dims_tensor.vals[0]), + dense_tensor.sizes().data() + 1, + num_jagged_dim * sizeof(int64_t)); + return {dim3(threads_x, threads_y), blocks, jagged_dims_tensor}; +} + +template +DEVICE_INLINE bool walk_down_tensor_storage_tree_( + int& offset, + const int flattened_jagged_idx, + const StackArray& jagged_dims, + const StackArray& x_offsets) { + // compute coorindates + int jagged_coords[NUM_JAGGED_DIM]; + int j_temp = flattened_jagged_idx; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + const int jagged_size = jagged_dims.vals[d]; + jagged_coords[d] = j_temp % jagged_size; + j_temp /= jagged_size; + } + + // walk down the tree + bool is_zero = false; +#pragma unroll + for (int d = 0; d < NUM_JAGGED_DIM; ++d) { + const int begin = x_offsets.vals[d][offset]; + const int end = x_offsets.vals[d][offset + 1]; + if (jagged_coords[d] >= end - begin) { + is_zero = true; + break; + } + offset = begin + jagged_coords[d]; + } + return is_zero; +} + +// output = f(x, y) where x is jagged, y is dense, and output is dense. +// A generic elementwise operation between a jagged tensor and a dense tensor +// This kernel assumes jagged dims are clustered together, preceded by outer +// dense dimensions and followed by inner dense dimensions. +// The outer/inner dense dimensions, and jagged dimensions in between are +// assumed to be folded so physically the dense tensor is 3D and the value of +// jagged tensor is 2D. +// To support arbitrary number of jagged dimensions, we pass a vector of +// pointers to offset tensors (this is ugly and probably we can use nested +// tensor here). +// This kernel parallelizes the (folded) inner dense dimension across +// blockDim.x so the inner dense dimension should be similar to or bigger than +// warp size. +// We rely on compiler unrolling the compiler time constant NUM_JAGGED_DIM. 
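The comment above describes the general pattern: each (outer, folded-jagged) coordinate of the dense output either resolves, through the per-dimension offsets, to a row of the jagged values, or falls outside that batch's extent and uses the padding value. A hedged Python sketch of that semantics for the one-jagged-dimension case (names are illustrative and not part of the kernel):

```python
import torch

# Dense output[b, j] = f(x[b, j], y[b, j]), where x[b, j] is values[offsets[b] + j]
# when j is within batch b's length, and padding_value otherwise.
def jagged_dense_elementwise_dense_output_1d(values, offsets, y, f, padding_value=0.0):
    B, J, D = y.shape
    out = torch.empty_like(y)
    for b in range(B):
        begin, end = offsets[b].item(), offsets[b + 1].item()
        for j in range(J):
            if j < end - begin:
                out[b, j] = f(values[begin + j], y[b, j])
            else:
                out[b, j] = f(torch.full((D,), padding_value, dtype=y.dtype), y[b, j])
    return out
```

The CUDA kernel that follows implements the same mapping for multiple folded jagged dimensions, walking the offset tree per coordinate and parallelizing the inner dense dimension across blockDim.x as the comment notes.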
+template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + const at::PackedTensorAccessor32 y, + at::PackedTensorAccessor32 output, + StackArray jagged_dims, + F f, + const scalar_t padding_value) { + const int outer_dense_size = y.size(0); + const int jagged_folded_size = y.size(1); + const int inner_dense_size = y.size(2); + + const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int outer_stride = gridDim.x * blockDim.y; + for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size; + outer += outer_stride) { + const int oidx = outer / jagged_folded_size; + const int jidx = outer % jagged_folded_size; + + int offset = oidx; + const bool is_zero = walk_down_tensor_storage_tree_( + offset, jidx, jagged_dims, x_offsets); + + if (is_zero) { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(padding_value, y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(padding_value, y[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + output[oidx][jidx][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], y[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output[oidx][jidx][2 * iidx] = + f(x_values[offset][2 * iidx], y[oidx][jidx][2 * iidx]); + } + } + } +} + +template +void jagged_dense_elementwise_dense_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output, + F f, + const scalar_t padding_value = static_cast(0)) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim ", + num_jagged_dim); + + if (y.numel() == 0) { + return; + } + + dim3 threads, blocks; + StackArray jagged_dims_tensor; + std::tie(threads, blocks, jagged_dims_tensor) = + check_shape_and_partition_(x_values, x_offsets, y); + + // Canonicalize y and output to 3D, collapsing jagged dimensions. 
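  // Editorial note (not part of the patch), as a concrete instance of the
  // canonicalization below: with an outer batch of B, two jagged dimensions
  // folded together and an inner dense dimension D, the views map
  //   y      : (B, J1, J2, D) -> y_reshaped      : (B, J1 * J2, D)
  //   output : (B, J1, J2, D) -> output_reshaped : (B, J1 * J2, D)
  // so the kernel only ever has to index a 3D dense tensor.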
+ const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + Tensor output_reshaped = output.view(y_reshaped.sizes()); + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + } \ + jagged_dense_elementwise_dense_output_kernel_ \ + <<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + y_reshaped \ + .packed_accessor32(), \ + output_reshaped \ + .packed_accessor32(), \ + jagged_dims_tensor, \ + f, \ + padding_value); \ + } + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +#undef INVOKE_KERNEL_WITH_DIM +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +template +__global__ +__launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_( + const at::PackedTensorAccessor32 + x_values, + StackArray x_offsets, + StackArray x_offsets_sizes, + const at::PackedTensorAccessor32 y_0, + const at::PackedTensorAccessor32 y_1, + at::PackedTensorAccessor32 + output_values, + StackArray jagged_dims, + F f) { + const int outer_dense_size = y_0.size(0); + const int inner_dense_size = y_0.size(2); + const int nnz = x_values.size(0); + + const int offset_begin = blockIdx.x * blockDim.y + threadIdx.y; + const int offset_stride = gridDim.x * blockDim.y; + for (int offset = offset_begin; offset < nnz; offset += offset_stride) { + int offset_temp = offset; + int jidx = 0; + bool truncated = false; + int dim_prod = 1; +#pragma unroll + for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) { + // Binary search the first that is bigger than offset + int count = x_offsets_sizes.vals[d] - 1; + int first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (x_offsets.vals[d][idx] <= offset_temp) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + + --first; + int coord = offset_temp - x_offsets.vals[d][first]; + if (coord >= jagged_dims.vals[d]) { + truncated = true; + break; + } + jidx += coord * dim_prod; + dim_prod *= jagged_dims.vals[d]; + offset_temp = first; + } + + if (offset_temp >= outer_dense_size) { + // This can happen when values have more elements than the last element of + // offset + 
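      // Editorial example (not part of the patch): with a single jagged dim,
      // offsets = {0, 2, 5} describe outer_dense_size == 2 rows and 5 valid
      // values; if x_values carries 7 rows, rows 5 and 6 binary-search past
      // the last offset, land on offset_temp == 2 >= outer_dense_size, and
      // fall through to the truncated path below, where f is applied with
      // zeros in place of the dense operands.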
truncated = true; + } + if (!truncated) { + const int oidx = offset_temp; + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], + y_0[oidx][jidx][2 * iidx + 1], + y_1[oidx][jidx][2 * iidx + 1]); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = + f(x_values[offset][2 * iidx], + y_0[oidx][jidx][2 * iidx], + y_1[oidx][jidx][2 * iidx]); + } + } else { + int iidx; + for (iidx = threadIdx.x; iidx * 2 + 1 < inner_dense_size; + iidx += blockDim.x) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + output_values[offset][2 * iidx + 1] = + f(x_values[offset][2 * iidx + 1], 0, 0); + } + if (iidx * 2 + 1 == inner_dense_size) { + output_values[offset][2 * iidx] = f(x_values[offset][2 * iidx], 0, 0); + } + } + } +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. + const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +#undef INVOKE_KERNEL_WITH_DIM + +template +struct SharedMemory; + +template <> +struct SharedMemory { + __device__ int64_t* getPointer() { + extern __shared__ int64_t s_int64_t[]; + return s_int64_t; + } +}; + +template <> +struct SharedMemory { + __device__ int32_t* getPointer() { + extern __shared__ int32_t s_int32_t[]; + return s_int32_t; + } +}; + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_( + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 rows, + at::PackedTensorAccessor32 cols, + int nnz, + int B) { + struct SharedMemory smem; + index_t* offsets_sh = smem.getPointer(); + + for (int i = threadIdx.x; i < B + 1; i += blockDim.x) { + offsets_sh[i] = offsets[i]; + } + __syncthreads(); + int row = threadIdx.x + blockIdx.x * blockDim.x; + if (row >= nnz) + return; + int first = -1; + int count = B - 1; + first = 1; + while (count > 0) { + int idx = first; + int step = count / 2; + idx += step; + if (offsets_sh[idx] <= row) { + first = ++idx; + count -= step + 1; + } else { + count = step; + } + } + --first; + + int dense_row = first; + int offset = offsets_sh[dense_row]; + int dense_col = row - offset; + rows[row] = dense_row; + cols[row] = dense_col; +} + +struct VecType128 { + typedef float4 TType; // Transaction Type + typedef struct __align__(16) { + __half a, b, c, d, w, x, y, z; + } + half8; + + union Data { + half8 val; + TType mask; + } data; + + __device__ VecType128() { + data.mask = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + } +}; + +struct VecType64 { + typedef float2 TType; // Transaction Type + typedef struct __align__(8) { + __half a, b, c, d; + } + half4; + + union Data { + half4 val; + TType mask; + } data; + + __device__ VecType64() { + 
data.mask = make_float2(0.0f, 0.0f); + } +}; + +struct VecType32 { + typedef float TType; // Transaction Type + + union Data { + __half2 val; + TType mask; + } data; + + __device__ VecType32() { + data.mask = 0.0f; + } +}; + +template +__device__ void f128( + VecType128& v_out, + const VecType128& x, + const VecType128& y0, + const VecType128& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); + v_out.data.val.w = f(x.data.val.w, y0.data.val.w, y1.data.val.w); + v_out.data.val.x = f(x.data.val.x, y0.data.val.x, y1.data.val.x); + v_out.data.val.y = f(x.data.val.y, y0.data.val.y, y1.data.val.y); + v_out.data.val.z = f(x.data.val.z, y0.data.val.z, y1.data.val.z); +} + +template +__device__ void f64( + VecType64& v_out, + const VecType64& x, + const VecType64& y0, + const VecType64& y1, + F f) { + v_out.data.val.a = f(x.data.val.a, y0.data.val.a, y1.data.val.a); + v_out.data.val.b = f(x.data.val.b, y0.data.val.b, y1.data.val.b); + v_out.data.val.c = f(x.data.val.c, y0.data.val.c, y1.data.val.c); + v_out.data.val.d = f(x.data.val.d, y0.data.val.d, y1.data.val.d); +} + +template +__device__ void f32( + VecType32& v_out, + const VecType32& x, + const VecType32& y0, + const VecType32& y1, + F f) { + v_out.data.val = __halves2half2( + f(__low2half(x.data.val), + __low2half(y0.data.val), + __low2half(y1.data.val)), + f(__high2half(x.data.val), + __high2half(y0.data.val), + __high2half(y1.data.val))); +} + +template +__device__ void +fh(__half& v_out, const __half& x, const __half& y0, const __half& y1, F f) { + v_out = f(x, y0, y1); +} + +template +__global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_( + at::PackedTensorAccessor32 values, + const at::PackedTensorAccessor32 + x_values, + const at::PackedTensorAccessor32 y0, + const at::PackedTensorAccessor32 y1, + const at::PackedTensorAccessor32 rows, + const at::PackedTensorAccessor32 cols, + const int nnz, + const int E, + F f) { + int values_row = threadIdx.y + blockIdx.y * blockDim.y; + if (values_row >= nnz) + return; + for (int real_row = values_row; real_row < nnz; + real_row += blockDim.y * gridDim.y) { + int dense_row = rows[real_row]; + int dense_col = cols[real_row]; + __half* values_ptr = reinterpret_cast<__half*>(&values[real_row][0]); + const __half* x_ptr = + reinterpret_cast(&x_values[real_row][0]); + const __half* y0_ptr = + reinterpret_cast(&y0[dense_row][dense_col][0]); + const __half* y1_ptr = + reinterpret_cast(&y1[dense_row][dense_col][0]); + if ((dense_col < y0.size(1)) && (dense_row < y0.size(0)) && + (dense_col < y1.size(1)) && (dense_row < y1.size(0)) && + (dense_col >= 0) && (dense_row >= 0)) { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + 
(reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + v_y0.data.mask = + (reinterpret_cast(y0_ptr))[tid]; + v_y1.data.mask = + (reinterpret_cast(y1_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + v_y0 = static_cast<__half>(y0_ptr[tid]); + v_y1 = static_cast<__half>(y1_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } else { + for (int tid = threadIdx.x; tid < E / 8; tid += blockDim.x) { + VecType128 v_x, v_out, v_y0, v_y1; + v_x.data.mask = + (reinterpret_cast(x_ptr))[tid]; + f128(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 8) * 8; tid < E / 4; + tid += blockDim.x) { + VecType64 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f64(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 4) * 4; tid < E / 2; + tid += blockDim.x) { + VecType32 v_x, v_out, v_y0, v_y1; + v_x.data.mask = (reinterpret_cast(x_ptr))[tid]; + f32(v_out, v_x, v_y0, v_y1, f); + (reinterpret_cast(values_ptr))[tid] = + v_out.data.mask; + } + for (int tid = threadIdx.x + (E / 2) * 2; tid < E; tid += blockDim.x) { + __half v_x, v_out, v_y0, v_y1; + v_x = static_cast<__half>(x_ptr[tid]); + fh(v_out, v_x, v_y0, v_y1, f); + values_ptr[tid] = v_out; + } + } + } +} + +// Check to see if the inputs to the op are amenable to the fast path +inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( + const int& num_jagged_dim, + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y_0_reshaped, + const Tensor& y_1_reshaped, + const Tensor& output_values) { + bool matches = true; + matches &= (num_jagged_dim == 1); + + // Unit stride embedding dim + matches &= (x_values.stride(-1) == 1); + matches &= (output_values.stride(-1) == 1); + matches &= (y_0_reshaped.stride(-1) == 1); + matches &= (y_1_reshaped.stride(-1) == 1); + + // Each row is aligned to 128-bit + matches &= (x_values.stride(-2) % 8 == 0); + matches &= (output_values.stride(-2) % 8 == 0); + matches &= (y_0_reshaped.stride(-2) % 8 == 0); + matches &= (y_1_reshaped.stride(-2) % 8 == 0); + + // Base addresses aligned to 128-bit + matches &= (reinterpret_cast(x_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(output_values.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_0_reshaped.data_ptr()) % 16 == 0); + matches &= (reinterpret_cast(y_1_reshaped.data_ptr()) % 16 == 0); + + // Rows and col fit into int32_t + matches &= (y_0_reshaped.size(0) < INT_MAX); + matches &= (y_0_reshaped.size(1) < INT_MAX); + + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + y_0_reshaped.get_device())); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. 
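  // Editorial worked example (not part of the patch), assuming a hypothetical
  // 96 KB opt-in shared-memory limit:
  //   shared_kb         = 96
  //   used_shared_kb    = round_down(96 * 2 / 3, 16) = 64
  //   used_shared_bytes = 64 << 10 = 65536
  // With int64_t offsets the check below then requires (B + 1) * 8 < 65536,
  // i.e. the fast path is taken only for batch sizes up to 8190.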
+ int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "check_shared_memory", [&] { + auto B = y_0_reshaped.size(0); + // the default shared memory on V100/A100/H100 is 48 KB from + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x + if ((B + 1) * sizeof(index_t) >= used_shared_bytes) { + matches = false; + } + }); + return matches; +} + +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + x_values.packed_accessor32(), \ + x_offset_ptrs, \ + x_offset_sizes, \ + y_reshaped.packed_accessor32(), \ + y_reshaped.packed_accessor32(), \ + output_values.packed_accessor32(), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ + } + +inline int calc_used_shared_bytes(const int device) { + int max_shared_bytes; +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device)); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif + int shared_kb = max_shared_bytes >> 10; +#ifndef USE_ROCM + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. + int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif + int used_shared_bytes = used_shared_kb << 10; + return used_shared_bytes; +} + +template +inline void set_max_dynamic_shared_mem_size_for_opt_search_kernel(const int used_shared_bytes) { +#ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB +#endif +} + +///@addtogroup jagged-tensor-ops-cuda +template +void jagged_dense_elementwise_jagged_output_opt_( + const Tensor& x_values, + const std::vector& x_offsets, + const Tensor& y, + const Tensor& output_values, + F f) { + TENSOR_ON_CUDA_GPU(x_values); + for (auto& x_offset : x_offsets) { + TENSOR_ON_CUDA_GPU(x_offset); + } + + const int num_jagged_dim = y.dim() - 2; + TORCH_CHECK( + x_offsets.size() == static_cast(num_jagged_dim), + "x_offsets.size(), ", + x_offsets.size(), + " != num_jagged_dim, ", + num_jagged_dim); + + if (y.numel() == 0 || x_values.numel() == 0) { + return; + } + + // Canonicalize y to 3D, collapsing jagged dimensions. 
+ const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); + if (jagged_dense_dense_elementwise_jagged_output_matches_opt( + num_jagged_dim, + x_values, + x_offsets, + y_reshaped, + y_reshaped, + output_values)) { + AT_DISPATCH_INDEX_TYPES( + x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] { + auto nnz = output_values.size(0); + auto B = y_reshaped.size(0); + auto E = y_reshaped.size(2); + Tensor t_rows_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + Tensor t_cols_after_bs = at::empty( + {nnz}, + at::TensorOptions().dtype(at::kInt).device( + at::kCUDA, at::cuda::current_device())); + + // Binary search + size_t dynamic_smem_size = (B + 1) * sizeof(index_t); + auto cur_max_shared_bytes = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; + if (dynamic_smem_size > cur_max_shared_bytes) { + int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device()); + set_max_dynamic_shared_mem_size_for_opt_search_kernel(used_shared_bytes); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + } + dim3 threads_bs = dim3(1024, 1, 1); + dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); + jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< + index_t> + <<>>( + x_offsets[0] + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + B); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Gather kernel + dim3 threads = dim3(16, 16, 1); + dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1); + if (blocks.y > 65535) { + blocks.y = 65535; + } + jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< + index_t> + <<>>( + output_values + .packed_accessor32(), + x_values + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + y_reshaped + .packed_accessor32(), + t_rows_after_bs + .packed_accessor32(), + t_cols_after_bs + .packed_accessor32(), + nnz, + E, + [f] __device__(__half x, __half y0, __half) -> __half { + // NB: added the static_casts here + return static_cast<__half>( + f(static_cast(x), static_cast(y0)) + ); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); // AT_DISPATCH + } else { + JAGGED_TENSOR_DISPATCH_DIMS(); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +at::Tensor _fbgemm_jagged_to_padded_dense_forward( + const Tensor& values, + TensorList offsets, + c10::IntArrayRef max_lengths, + const double padding_value) { + const size_t num_jagged_dim = offsets.size(); + TORCH_CHECK( + max_lengths.size() == num_jagged_dim, + "max_lengths.size(), ", + max_lengths.size(), + " != num_jagged_dim, ", + num_jagged_dim); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + const Tensor values_canonicalized = values.view( + {values.size(0), + std::accumulate( + values.sizes().begin() + 1, + values.sizes().end(), + 1, + std::multiplies())}); + at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)}); + padded_values_shape.insert( + padded_values_shape.end(), max_lengths.begin(), max_lengths.end()); + + // Canonicalize padded_values by unsqueeze the last dim if the inner dense + // dimension is 1 and folded. + const bool D_folded = values.dim() == 1; + if (!D_folded) { + padded_values_shape.push_back(values.size(-1)); + } + Tensor padded_values = + at::empty_symint(padded_values_shape, values.options()); + Tensor padded_values_view = + D_folded ? 
padded_values.unsqueeze(-1) : padded_values; + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + values.scalar_type(), + "jagged_to_padded_dense", + [&] { + jagged_dense_elementwise_dense_output_( + values_canonicalized, + offsets.vec(), + padded_values_view, // dummy not used in the lambda function + padded_values_view, + [] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t { + return x; + }, + static_cast(padding_value)); + }); + + return padded_values; +} + +#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \ + AT_DISPATCH_CASE(TYPE, [&] { \ + jagged_dense_elementwise_jagged_output_opt_( \ + values, \ + offsets.vec(), \ + dense, \ + output, \ + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \ + return y; \ + }); \ + }) + +Tensor _fbgemm_dense_to_jagged_forward_symint( + const Tensor& dense, + TensorList offsets, + c10::optional total_L) { + // D is the embedding dimension + auto D = dense.size(-1); + + // If total_L is not given then compute it + at::SymInt total_L_computed; + if (total_L.has_value()) { + total_L_computed = total_L.value(); + } else { + total_L_computed = (int64_t)offsets.back().max().item(); + } + auto values = at::empty_symint({total_L_computed, D}, dense.options()); + auto output = at::empty_like(values); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(dense.get_device()); + + // clang-format off + AT_DISPATCH_SWITCH( + values.scalar_type(), + "dense_to_jagged_gpu_op_forward", + DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half) + // NB: removed this to build + // DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int) + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( + at::ScalarType::Long, + at::ScalarType::BFloat16, + [&] { + jagged_dense_elementwise_jagged_output_( + values, + offsets.vec(), + dense, + output, + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { + return y; + }); // device lambda + } // lambda + ) // CASE_FLOATING_TYPES_AND + ); // SWITCH + // clang-format on + +#undef DISPATCH_DENSE_TO_JAGGED_CASE + + return output; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp index 8b5fb286ec61..be39a7db2cfa 100644 --- a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -500,7 +501,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { } // namespace -static Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){ +Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){ return qadd(std::move(qa), std::move(qb), scale, zero_point); } diff --git a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp index 113c57f2cc35..5d471d235275 100644 --- a/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/ReduceOps.cpp @@ -172,16 +172,6 @@ Tensor mean_quantized_cpu( return result; } -static Tensor& mean_out_quantized_cpu( - Tensor& result, - const Tensor& self, - DimnameList dim, - bool keepdim, - std::optional opt_dtype) { - return mean_out_quantized_cpu( - self, dimnames_to_positions(self, dim), keepdim, opt_dtype, result); -} - // qstd inline bool is_std_inner_dim_fast_path( const Tensor& self, @@ -237,24 +227,5 @@ Tensor std_quantized_cpu( return result; } -static Tensor std_quantized_cpu( - const Tensor& self, - DimnameList dim, - 
const std::optional& correction, - bool keepdim) { - return std_quantized_cpu( - self, dimnames_to_positions(self, dim), correction, keepdim); -} - -static Tensor& std_out_quantized_cpu( - Tensor& result, - const Tensor& self, - DimnameList dim, - const std::optional& correction, - bool keepdim) { - return std_out_quantized_cpu( - self, dimnames_to_positions(self, dim), correction, keepdim, result); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp index d4dfa7ff08c9..947f9f1696dd 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp @@ -216,20 +216,6 @@ Tensor upsample_bilinear2d_quantized_cpu( } } -using at::native::upsample::compute_output_size; -using at::native::upsample::get_scale_value; - -static Tensor upsample_bilinear2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - bool align_corners, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return upsample_bilinear2d_quantized_cpu(input, osize, align_corners, scale_h, scale_w); -} - DEFINE_DISPATCH(qupsample_bilinear2d_nhwc_stub); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp index 191407bed66a..03cbb080d558 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp @@ -218,25 +218,5 @@ Tensor _upsample_nearest_exact2d_quantized_cpu( return _upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w); } -static Tensor upsample_nearest2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w); -} - -static Tensor _upsample_nearest_exact2d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - std::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return _upsample_nearest_exact2d_quantized_cpu(input, osize, scale_h, scale_w); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 6b9cbc4a92c1..25b9b2b4e92c 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +#include #include #include @@ -35,7 +36,6 @@ #endif #include -#include namespace { // To have a sanity check for maximum matrix size. 
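// Editorial sketch (not part of the patch): the QConvInt8ForBC hunk below stops
// building the deprecation warning with operator+ on std::string and instead
// passes the message pieces as separate arguments, relying on TORCH_WARN_ONCE
// to stringify and concatenate its variadic arguments. A minimal stand-in for
// that concatenation behaviour, using only the standard library (the helper
// name concat_message is illustrative):
#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string concat_message(const Args&... args) {
  std::ostringstream oss;
  (oss << ... << args);  // C++17 fold expression, similar in spirit to c10::str
  return oss.str();
}

int main() {
  const int kSpatialDim = 2;
  std::cout << concat_message(
                   "Arguments [stride, padding, dilation, groups] in ops.quantized.conv",
                   std::to_string(kSpatialDim),
                   "d_relu, have been removed, please update your model to remove these arguments.")
            << "\n";
  return 0;
}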
@@ -1848,15 +1848,15 @@ class QConvInt8ForBC final { int64_t output_zero_point) { if (kReluFused) { TORCH_WARN_ONCE( - "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" - + c10::to_string(kSpatialDim) + "d_relu, " + - "have been removed, please update your model to remove these arguments."); + "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" + + std::to_string(kSpatialDim), + "d_relu, have been removed, please update your model to remove these arguments."); return packed_weight->apply_relu(act, output_scale, output_zero_point); } else { TORCH_WARN_ONCE( - "Arguments [stride, padding, dilation, groups] in ops.quantized.conv" - + c10::to_string(kSpatialDim) + "d, " + - "have been removed, please update your model to remove these arguments."); + "Arguments [stride, padding, dilation, groups] in ops.quantized.conv", + std::to_string(kSpatialDim), + "d, have been removed, please update your model to remove these arguments."); return packed_weight->apply(act, output_scale, output_zero_point); } } diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 9cfbce72e31d..4b9c8ea2bdc9 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -342,7 +342,10 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { output_shape[cols_dim] = output_columns; at::SymDimVector output_shape_vec(output_shape); - return at::empty_symint(output_shape_vec, weight.options().dtype(weight.scalar_type()), weight.suggest_memory_format()); + return at::empty_symint( + output_shape_vec, + weight.options().dtype(weight.scalar_type()), + weight.suggest_memory_format()); } namespace { @@ -373,9 +376,10 @@ Tensor _qembeddingbag_nbit_prepack_helper( int NUM_ELEM_PER_BYTE = 8 / bit_width; TORCH_CHECK( weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0, - "qembeddingbag_" + c10::to_string(bit_width) + - "bit_prepack only works for the number of columns a multiple of " + - c10::to_string(NUM_ELEM_PER_BYTE)); + "qembeddingbag_", + std::to_string(bit_width), + "bit_prepack only works for the number of columns a multiple of ", + std::to_string(NUM_ELEM_PER_BYTE)); // The "fused" representation stores the scale and bias with the // row-wise quantized data in one tensor. 
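// Editorial sketch (not part of the patch): one way to picture the "fused"
// row-wise layout mentioned above, shown for the 4-bit case. Each row stores
// its packed quantized payload followed by that row's scale and bias. The
// float storage for scale/bias, the nibble order, and the helper name are all
// illustrative assumptions; the real prepack kernels choose their own
// precision and padding for the trailing metadata.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Packs one non-empty row of floats into [packed 4-bit data | scale | bias].
std::vector<uint8_t> pack_row_4bit(const std::vector<float>& row) {
  const float mn = *std::min_element(row.begin(), row.end());
  const float mx = *std::max_element(row.begin(), row.end());
  const float scale = (mx - mn) / 15.0f;  // 4-bit quantization range is 0..15
  const float bias = mn;
  const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;

  std::vector<uint8_t> fused((row.size() + 1) / 2, 0);
  for (size_t i = 0; i < row.size(); ++i) {
    const int q =
        static_cast<int>(std::lround((row[i] - bias) * inv_scale)) & 0x0F;
    fused[i / 2] |= (i % 2 == 0) ? q : (q << 4);  // two 4-bit values per byte
  }

  // Append the per-row scale and bias bytes after the packed payload.
  for (float v : {scale, bias}) {
    uint8_t bytes[sizeof(float)];
    std::memcpy(bytes, &v, sizeof(float));
    fused.insert(fused.end(), bytes, bytes + sizeof(float));
  }
  return fused;
}
// Usage note: pack_row_4bit({1.f, 2.f, 3.f}) yields 2 payload bytes plus
// 8 metadata bytes, matching the "columns + trailing scale/bias" picture.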
@@ -551,11 +555,9 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { TORCH_FN(QEmbeddingPackWeights::run)); } - TORCH_LIBRARY_IMPL(quantized, Meta, m) { m.impl( - "quantized::embedding_bag_byte_prepack", - qembeddingbag_byte_prepack_meta); + "quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta); } } // namespace diff --git a/aten/src/ATen/native/sparse/SoftMax.cpp b/aten/src/ATen/native/sparse/SoftMax.cpp index 179db48beacc..33ac3d176e6c 100644 --- a/aten/src/ATen/native/sparse/SoftMax.cpp +++ b/aten/src/ATen/native/sparse/SoftMax.cpp @@ -606,15 +606,6 @@ Tensor log_softmax_backward_sparse_cpu( return grad_input; } -static Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_sparse_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; @@ -633,15 +624,6 @@ Tensor _sparse_softmax(const Tensor& self, Dimname dim, optional dty return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype); } -static Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_) { - auto result = [&]() { - NoNamesGuard guard; - return at::_sparse_log_softmax(input_, dim_, false); - }(); - namedinference::propagate_names(result, input_); - return result; -} - Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp index fd67f0694f2d..c2e8c4439ab9 100644 --- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp @@ -410,6 +410,9 @@ void addmv_out_sparse_csr( const Tensor& result) { #if !AT_USE_MKL_SPARSE() TORCH_CHECK(mat.layout() == kSparseBsr || mat.layout() == kSparseCsr, "Unexpected layout", mat.layout()); + if (beta.toComplexDouble() == 0.) 
{ + result.zero_(); + } AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( result.scalar_type(), "addmv_out_sparse_csr_impl_reference", [&] { if (mat.crow_indices().scalar_type() == kLong) { diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index f058c68579f8..fff755c7b418 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -270,10 +270,6 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value) { return div_out_sparse_zerodim(self, value, self); } -static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseTensor& r) { - return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); -} - Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional rounding_mode) { auto commonDtype = at::result_type(self, value); if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) { @@ -287,10 +283,6 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional rounding_mode, SparseTensor& r) { - return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), std::move(rounding_mode), r); -} - // -------------------------------------------------------------------- // floor_divide(SparseTensor, Scalar) // -------------------------------------------------------------------- @@ -350,10 +342,6 @@ Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) { return floor_divide_out_sparse_zerodim(self, value, self); } -static SparseTensor& floor_divide_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) { - return floor_divide_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r); -} - // -------------------------------------------------------------------- // norm(SparseTensor, Scalar) // -------------------------------------------------------------------- diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index f3aabc63e2a2..d5f654097677 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -204,7 +204,7 @@ void csrmm2( T alpha, T *csrvala, int *csrrowptra, int *csrcolinda, T *b, int64_t ldb, T beta, T *c, int64_t ldc) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); + static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); } template<> void csrmm2( @@ -381,7 +381,7 @@ void csrmm2( T alpha, T *csrvala, int *csrrowptra, int *csrcolinda, T *b, int64_t ldb, T beta, T *c, int64_t ldc) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); + static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble."); } template<> void csrmm2( diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 20af0ee866a5..88c3ee05ab53 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -387,7 +387,7 @@ struct CusparseMatrixMultiplyOp { Tensor &output_values, Tensor &output_indices) { - TORCH_INTERNAL_ASSERT(false, "cusparse csr sparse-sparse MM only supports data type of float and double."); + static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and 
double."); } }; diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index 727534c0d347..c31721729036 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -46,16 +46,6 @@ desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. - The default for custom ops (i.e. not torch._library.utils.is_builtin) - is that they do need a fixed stride order; add `does_not_need_fixed_stride_order` - to change the behavior. - The default for builtin ops is that they do not need a fixed stride order; - add `needs_fixed_stride_order` to change the behavior. -- tag: does_not_need_fixed_stride_order - desc: | - This tag indicates that the operator doesn't need to be passed Tensors following - the same stride permutation as observed in eager when compiled in inductor. - See `needs_fixed_stride_order` for more details. # NOTE [Core ATen Ops] - tag: core diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 50b47e5b1731..6a83175a15fb 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -666,7 +666,7 @@ Tensor scaled_dot_product_attention( case sdp::SDPBackend::cudnn_attention: { bool compute_logsumexp = should_compute_logsumexp(query_, key, value); auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention( - query_, key, value, dropout_p, is_causal, compute_logsumexp, scale); + query_, key, value, compute_logsumexp, dropout_p, is_causal, scale); return std::get<0>(out_lse_softmax); } case sdp::SDPBackend::flash_attention: { diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index a1cdb47c12b4..3e307b29512f 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -72,10 +72,17 @@ #include #endif #ifdef USE_MEM_EFF_ATTENTION -// MemoryEfficient Attention Specific Imports +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA #include #include #include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif #endif namespace at { @@ -728,14 +735,27 @@ std::tuple _scaled_dot_product_cudnn_attention_cuda( +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) { + if (arg.captured_) { + *seed_ptr = static_cast(*arg.seed_.ptr); + *offset_ptr = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + *seed_ptr = static_cast(arg.seed_.val); + *offset_ptr = static_cast(arg.offset_.val); + } +} + +std::tuple _scaled_dot_product_cudnn_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, + bool compute_logsumexp, double dropout_p, bool is_causal, - bool training, - std::optional scale) { + c10::optional scale) { // Used for tracking usage statistics C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn"); // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) @@ -744,8 +764,8 @@ std::tuple( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + // if using dropout, we produce 1 random number for each element of the + // attention tensor + // TODO(eqy): should state be advanced per 
thread (local) amount or per call/launch (global) amount + philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k); + unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_state, static_cast(cudnn_seed.data_ptr()), static_cast(cudnn_offset.data_ptr())); + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + Tensor debugmask; run_cudnn_SDP_fprop(batch_size/*int64_t b*/, num_heads/*int64_t h*/, max_seqlen_batch_q/*int64_t s_q*/, max_seqlen_batch_k/*int64_t s_kv*/, - head_dim/*int64_t d*/, + head_dim_qk/*int64_t d_qk*/, + head_dim_v/*int64_t d_v*/, softmax_scale/*float scaling_factor*/, - training/* bool */, + compute_logsumexp/* bool */, is_causal/* bool */, dropout_p/*double dropout_probability*/, query/* Tensor q*/, @@ -775,7 +820,7 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( @@ -959,7 +1004,6 @@ std::tuple _efficient_ int64_t custom_mask_type, bool compute_logsumexp, std::optional scale, - const std::optional& causal_diagonal, const std::optional& seqlen_k, const std::optional window_size) { #if defined(USE_MEM_EFF_ATTENTION) @@ -1063,6 +1107,64 @@ std::tuple _efficient_ offset_t = at::empty({}, at::dtype(at::kLong).device(device)); } +#ifdef USE_ROCM + // ROCM Implementation + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + } + + // AOTriton may accept aligned on logsumexp tensor in the future for better + // performance, but for now it requires compact logsumexp tensor, even if + // compute_logsumexp is false + constexpr int kAlignLSE = 1; + res = at::empty({B, M, num_heads, Kv}, query.options()); + logsumexp = at::empty( + { B, num_heads, max_seqlen_q }, + query.options().dtype(at::ScalarType::Float)); + at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + at::Tensor q_t = query.transpose(1, 2); + at::Tensor k_t = key.transpose(1, 2); + at::Tensor v_t = value.transpose(1, 2); + at::Tensor output_t = res.transpose(1, 2); + bool is_causal; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } + + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + + using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); + at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); + hipError_t err; // TODO: Error handling + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + use_dropout ? *seed_t.data_ptr() : 0, + use_dropout ? 
*offset_t.data_ptr() : 0, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + if (!compute_logsumexp) { + // Set the tensor to empty when compute_logsumexp is false + logsumexp = at::empty( + { B * num_heads, max_seqlen_q, 0 }, + query.options().dtype(at::ScalarType::Float)); + } +#else + // CUDA Implementation cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); const int computeCapability = p->major * 10 + p->minor; @@ -1147,12 +1249,6 @@ std::tuple _efficient_ p.num_keys = max_seqlen_k; p.num_batches = seqstart_q.has_value() ? seqstart_q->size(0) - 1 : B; p.custom_mask_type = custom_mask_type; - p.causal_diagonal_ptr = nullptr; - if (causal_diagonal.has_value()) { - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(causal_diagonal.value()); - TORCH_CHECK(causal_diagonal->scalar_type() == at::ScalarType::Int); - p.causal_diagonal_ptr = (const int32_t*)causal_diagonal->const_data_ptr(); - } p.seqlen_k_ptr = nullptr; if (seqlen_k.has_value()) { @@ -1222,8 +1318,13 @@ std::tuple _efficient_ " kb)"); AT_CUDA_CHECK(err); } + auto blocks = p.getBlocksGrid(); + if (blocks.x * blocks.y * blocks.z == 0 || key.size(1) == 0) { + res.zero_(); + return; + } Kernel::check_supported(p); - kernel_fn<<>>(p); + kernel_fn<<>>(p); }; // Dispatch to the right kernel @@ -1233,6 +1334,7 @@ std::tuple _efficient_ TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!"); AT_CUDA_CHECK(cudaGetLastError()); +#endif // USE_ROCM return std::make_tuple( std::move(res), std::move(logsumexp), @@ -1253,7 +1355,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda); -#ifdef USE_MEM_EFF_ATTENTION +#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) namespace { /** * simple kernel that populates a tensor with rand uniform values. @@ -1303,7 +1405,7 @@ __global__ void rand_uniform_kernel( } } } // namespace -#endif +#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM) /** * fill tensor with random uniform values. 
only used for testing, not much * attention is paid to performance @@ -1321,6 +1423,17 @@ at::Tensor& _fill_mem_eff_dropout_mask_( const int64_t n_keys = self.size(3); #if defined(USE_MEM_EFF_ATTENTION) +#ifdef USE_ROCM + using aotriton::v2::flash::debug_fill_dropout_rng; + using sdp::aotriton_adapter::mk_aotensor; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + hipError_t err; // TODO: Error handling + + err = debug_fill_dropout_rng(mk_aotensor(self, "r"), + static_cast(seed), + static_cast(offset), + stream); +#else at::PhiloxCudaState rng_engine_inputs; rng_engine_inputs = at::PhiloxCudaState(seed, offset); at::cuda::CUDAGuard device_guard(self.device()); @@ -1334,6 +1447,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_( rng_engine_inputs, reinterpret_cast(self.data_ptr()), self.numel()); +#endif return self; #endif diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 690f433aa5f2..14d389bf8653 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -36,11 +36,18 @@ #include #endif #ifdef USE_MEM_EFF_ATTENTION -// MemoryEfficient Attention Specific Imports +#ifndef USE_ROCM +// MemoryEfficient Attention Specific Imports for CUDA #include #include #include #include +#else +// MemoryEfficient Attention Specific Imports for ROCM +#include +#include +#include +#endif #endif #ifdef __HIP_PLATFORM_AMD__ @@ -164,21 +171,34 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ const Tensor& value, const Tensor& out, const Tensor& logsumexp, - const Tensor& cumulative_sequence_length_q, - const Tensor& cumulative_sequence_length_k, - const int64_t max_seqlen_batch_q, - const int64_t max_seqlen_batch_k, - double dropout_p, - bool is_causal, const Tensor& philox_seed, const Tensor& philox_offset, - std::optional scale) { +// const Tensor& cumulative_sequence_length_q, +// const Tensor& cumulative_sequence_length_k, +// const int64_t max_seqlen_batch_q, +// const int64_t max_seqlen_batch_k, + double dropout_p, + bool is_causal, + c10::optional scale) { + + + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "cuDNN Attention defaults to a non-deterministic algorithm. 
", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } + } + + const int64_t batch_size = query.size(0); const int64_t num_heads = query.size(1); - const int64_t head_dim = query.size(3); - + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_batch_q = query.size(1); + const int64_t max_seqlen_batch_k = key.size(1); const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); - auto dq = at::empty_like(query); auto dk = at::empty_like(key); auto dv = at::empty_like(value); @@ -186,7 +206,8 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ num_heads /*int64_t h*/, max_seqlen_batch_q /*int64_t s_q*/, max_seqlen_batch_k /*int64_t s_kv*/, - head_dim /*int64_t d*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, softmax_scale /*float scaling_factor*/, is_causal /*bool is_causal*/, dropout_p /*float dropout_probability*/, @@ -230,7 +251,8 @@ _efficient_attention_backward( const bool bias_requires_grad, const std::optional scale, std::optional num_splits_key, - const std::optional window_size) { + const std::optional window_size, + const bool shared_storage_dqdkdv) { #if defined(USE_MEM_EFF_ATTENTION) if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); @@ -310,9 +332,33 @@ _efficient_attention_backward( int64_t Kv = value.size(3); at::Tensor grad_q, grad_k, grad_v, grad_bias; - grad_q = at::empty(query.sizes(), query.options()); - grad_k = at::empty(key.sizes(), key.options()); - grad_v = at::empty(value.sizes(), value.options()); + if (shared_storage_dqdkdv) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. 
+ // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + TORCH_CHECK( + query.size(1) == key.size(1), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same sequence length: got ", query.size(1), + " query tokens and ", key.size(1), " key/value tokens" + ); + TORCH_CHECK( + query.size(3) == key.size(3), + "`shared_storage_dqdkdv` is only supported when Q/K/V " + "have the same embed dim: got ", query.size(3), + " for Q, and ", key.size(3), " for K" + ); + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty(query.sizes(), query.options()); + grad_k = at::empty(key.sizes(), key.options()); + grad_v = at::empty(value.sizes(), value.options()); + } if (bias_requires_grad) { // force alignment for the last dim @@ -323,7 +369,6 @@ _efficient_attention_backward( grad_bias = at::empty(sz, bias->options()) .slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim); } - at::Tensor workspace; const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO; @@ -343,6 +388,62 @@ _efficient_attention_backward( } } +#ifdef USE_ROCM + // ROCM Implementation + TORCH_CHECK(!num_splits_key.has_value(), + "ROCM does not support num_split_keys in _efficient_attention_forward"); + TORCH_CHECK(!window_size.has_value(), + "ROCM does not support window_size in _efficient_attention_forward"); + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + bool is_causal; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + is_causal = true; + } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + is_causal = false; + } else { + TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now"); + } + at::Tensor q_t = query.permute({0,2,1,3}); + at::Tensor k_t = key.permute({0,2,1,3}); + at::Tensor v_t = value.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = grad_q.permute({0,2,1,3}); + at::Tensor dk_t = grad_k.permute({0,2,1,3}); + at::Tensor dv_t = grad_v.permute({0,2,1,3}); + at::Tensor dout_t = grad_out.permute({0,2,1,3}); + at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + + hipError_t err; + using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + rng_engine_inputs.seed_.val, + rng_engine_inputs.offset_.val, + is_causal, + stream); +#else + at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); const int computeCapability = p->major * 10 + p->minor; @@ -439,8 +540,7 @@ _efficient_attention_backward( ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - // We removed the chunk/cat optimization and the multiplier is always 1 - p.gQKV_strideM_multiplier = 1; + p.gQKV_strideM_multiplier = shared_storage_dqdkdv ? 3 : 1; TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); @@ -503,8 +603,12 @@ _efficient_attention_backward( auto parallelism_without_split_key = p.getBlocksGrid().x * p.getBlocksGrid().y * p.getBlocksGrid().z; p.num_splits_key = cutlass::ceil_div(p.num_keys, Kernel::kBlockSizeJ); - if (num_splits_key.has_value()) { // Skip heuristic, if user provided an explicit value - p.num_splits_key = std::max(p.num_splits_key, num_splits_key.value()); + if (num_splits_key.has_value()) { + p.num_splits_key = + std::min(p.num_splits_key, num_splits_key.value()); + } else { + // Keys splitting heuristic + // If we already have enough parallelism, split-keys can help // better use L2 cache. // This is negligible when the seqlen is too small tho @@ -545,6 +649,15 @@ _efficient_attention_backward( workspace.zero_(); } } + + // Handle the edge-cases where some tensors are empty + if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 || + p.num_heads == 0) { + grad_k.zero_(); + grad_v.zero_(); + grad_q.zero_(); + return; + } Kernel::check_supported(p); if (smem_bytes > 0xc000) { @@ -587,8 +700,9 @@ _efficient_attention_backward( })); TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!"); AT_CUDA_CHECK(cudaGetLastError()); +#endif // USE_ROCM return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), std::move(grad_bias)); - #endif + #endif // defined(USE_MEM_EFF_ATTENTION) TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.") return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); } diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 9eb3958bf569..24ba7e1343b1 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -602,7 +602,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q cu_seqlens_q_d = nullptr; } - const int total_q = q.sizes()[0]; + const int total_q = temp_q.sizes()[0]; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h index db817a0657ff..5089fb2e294f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h @@ -540,7 +540,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const 
Params ¶ms, const in : pytorch_flash::convert_type_relu(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h index 0386a07cc64f..9d97abb5eb90 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h @@ -339,7 +339,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -402,7 +402,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } @@ -895,7 +895,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = pytorch_flash::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); @@ -957,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = pytorch_flash::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index 564e3f2f3522..05fa314a2bf6 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -607,7 +607,10 @@ struct AttentionBackwardKernel { using AccumTileGmem = GmemTile; }; - static constexpr bool kEnableSplitKeys = true; + // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52) + // when there is an attention bias. Let's just disable it for now. + static constexpr auto kMinSm = ArchTag::kMinComputeCapability; + static constexpr bool kEnableSplitKeys = kMinSm >= 70; static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || !cutlass::platform::is_same::value; @@ -720,11 +723,19 @@ struct AttentionBackwardKernel { int64_t gV_strideH = 0; int64_t gB_strideH = 0; - CUTLASS_DEVICE int16_t num_splits_key_device() const { + CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const { +#ifdef __CUDA_ARCH__ return kEnableSplitKeys ? gridDim.x : 1; +#else + return num_splits_key; // for host-side tests +#endif } - CUTLASS_DEVICE int16_t split_key_device() const { + CUTLASS_HOST_DEVICE int16_t split_key_device() const { +#ifdef __CUDA_ARCH__ return kEnableSplitKeys ? blockIdx.x : 0; +#else + return 0; // for host-side tests +#endif } CUTLASS_DEVICE bool advance_to_block() { @@ -846,14 +857,14 @@ struct AttentionBackwardKernel { if (!kNeedsAccumGradK) { return 0; } - return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + return num_splits_key * kBlockSizeJ * align_up(head_dim, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { if (!kNeedsAccumGradV) { return 0; } - return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + return num_splits_key * kBlockSizeJ * align_up(head_dim_value, (int32_t)kBlockSizeI); } CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { @@ -877,7 +888,7 @@ struct AttentionBackwardKernel { return num_batches * num_heads * workspace_strideBH() * sizeof(float); } CUTLASS_HOST_DEVICE bool should_zero_workspace() const { - return num_splits_key > 1; + return num_splits_key > 1 || window_size > 0; } }; @@ -1174,8 +1185,12 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); - TORCH_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); - TORCH_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + TORCH_CHECK( + p.num_heads <= 1 || p.lse_strideH % 8 == 0, + "LSE is not correctly aligned (strideH)"); + TORCH_CHECK( + p.num_batches <= 1 || p.lse_strideB % 8 == 0, + "LSE is not correctly aligned (strideB)"); TORCH_CHECK( p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned (strideH)"); @@ -1187,7 +1202,7 @@ struct AttentionBackwardKernel { "value is not correctly aligned (strideH)"); TORCH_CHECK( p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, - "query is not correctly aligned (strideB)."); + "query is not correctly aligned (strideB)"); 
TORCH_CHECK( p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, "key is not correctly aligned (strideB)"); @@ -1268,15 +1283,18 @@ struct AttentionBackwardKernel { } TORCH_CHECK( kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); - TORCH_CHECK(p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + TORCH_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); TORCH_CHECK( p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), "Invalid `num_splits_key` (", p.num_splits_key, ") - too large for `num_keys` = ", p.num_keys); - if (p.window_size > 0) { - TORCH_CHECK(p.custom_mask_type == CausalFromTopLeft); + if (p.window_size != 0) { + TORCH_CHECK( + p.custom_mask_type != NoCustomMask, + "LocalAttention only supported in causal mode"); } return true; } @@ -1338,15 +1356,15 @@ struct AttentionBackwardKernel { std::get<1>(seeds) + p.dropout_batch_head_rng_offset, &rng_state_init); } + CUTLASS_PRAGMA_UNROLL for (; key_start < p.num_keys; key_start += p.num_splits_key_device() * kBlockSizeJ) { output_frags.clear(); - CUTLASS_PRAGMA_UNROLL - for (int32_t query_start_shifted = getQueryStart(p, key_start); - query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); - query_start_shifted += kBlockSizeI) { + int32_t next_key = key_start; + int32_t query_start = getQueryStart(p, key_start); + while (next_key == key_start && query_start < p.num_queries) { // This line here // vvvvvvvvvvvvvv warp_id = warp_uniform(warp_id); @@ -1357,11 +1375,6 @@ struct AttentionBackwardKernel { // from the previous iteration, which prevents MASSIVE // register spilling. - int32_t query_start = query_start_shifted; - if (query_start >= p.num_queries) { - query_start = query_start % getQueryEnd(p); - } - processBlockIJ( shared_storage, output_frags, @@ -1371,6 +1384,10 @@ struct AttentionBackwardKernel { rng_state_init, warp_id, lane_id); + + int32_t next_query; + incrIteration(p, query_start, key_start, next_query, next_key); + query_start = next_query; } if (kOutputInRF) { writeFragsToGmem( @@ -1466,13 +1483,7 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : warp_uniform(cutlass::fast_min( (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); - if (p.window_size > 0) { - if (p.custom_mask_type == CausalFromTopLeft && - key_start + num_keys_in_block <= - int32_t(query_start) - p.window_size) { - return; - } - } + auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( {int32_t(p.gO_strideM)}, @@ -2119,14 +2130,20 @@ struct AttentionBackwardKernel { p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); - bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + // if `direct_store` is True, we store to gmem (`*gmem = accum`) + // otherwise, we accumulate in gmem (`*gmem = *gmem + accum`) + // If we know ahead of time when we will write for the first time + // we can: + // (1) Avoid an additional memory read + // (2) Avoid the cost of initializing memory to 0 + bool direct_store = kNeedsAccumGradQ || key_start == 0 || (p.num_splits_key_device() > 1); accumulateInGmem( isLastColumn ? 
shared_storage.gradQ_epilogue_lastIter() : shared_storage.gradQ_epilogue(), accum, output_it, - storage_contains_zeros, + direct_store, warp_id, lane_id); } @@ -2237,12 +2254,13 @@ struct AttentionBackwardKernel { isFirstQuery || kNeedsAccumGradK, warp_id, lane_id); + __syncthreads(); } } } } - static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + static CUTLASS_HOST_DEVICE int32_t getQueryStartShift(Params const& p) { if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); } @@ -2250,55 +2268,70 @@ struct AttentionBackwardKernel { } // Iteration order logic - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getQueryStart(Params const& p, int32_t key_start) { return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); }; - static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + static CUTLASS_HOST_DEVICE int32_t getQueryEnd(Params const& p) { return align_up(p.num_queries, kBlockSizeI); }; - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getSmallestQueryForKey(Params const& p, int32_t key_start) { - if (p.custom_mask_type == CausalFromTopLeft) { - return (key_start / kBlockSizeI) * kBlockSizeI; - } else if (p.custom_mask_type == CausalFromBottomRight) { - int first_query = - cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); - return (first_query / kBlockSizeI) * kBlockSizeI; + if (p.custom_mask_type == NoCustomMask) { + return 0; } - return 0; + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t window_size = + p.window_size == 0 ? p.num_queries + p.num_keys : p.window_size; + + auto last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + int first_query = key_start - shift; + int last_query = last_key_for_block - shift + window_size - 1; + if (last_query < 0 || first_query >= p.num_queries) { + return getQueryEnd(p); // nothing to compute in this column + } + first_query = cutlass::fast_max(0, first_query); + return (first_query / kBlockSizeI) * kBlockSizeI; }; // Returns how many kernel blocks will write to a given block in `grad_query` // This is usually equal to the number of key splits, but can be different // for instance in the causal case, or varying seqlen - static CUTLASS_DEVICE int32_t + static CUTLASS_HOST_DEVICE int32_t getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); - if (p.custom_mask_type == CausalFromTopLeft) { - int32_t last_key_for_block = query_start + kBlockSizeI - 1; - last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + if (p.custom_mask_type != NoCustomMask) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + int32_t last_query_for_block = + cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1; + int32_t last_key_for_block = + cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1); + int32_t first_key_for_block = p.window_size == 0 + ? 
0 + : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0); + if (p.window_size == 0) { - num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + num_key_blocks = last_key_for_block / kBlockSizeJ + 1; } else { - int32_t first_key_for_block = - cutlass::fast_max(query_start - p.window_size + 1, 0); - int32_t first_key_block = first_key_for_block / kBlockSizeJ; - int32_t last_key_block = last_key_for_block / kBlockSizeJ; - num_key_blocks = last_key_block - first_key_block + 1; + num_key_blocks = (last_key_for_block / kBlockSizeJ) - + (first_key_for_block / kBlockSizeJ) + 1; + } + + if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) { + num_key_blocks = 0; } - } else if (p.custom_mask_type == CausalFromBottomRight) { - int32_t last_key_for_block = - query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); - last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); - num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); } return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); }; // Returns the next block to process - static CUTLASS_DEVICE void incrIteration( + static CUTLASS_HOST_DEVICE void incrIteration( Params const& p, int32_t query_start, int32_t key_start, @@ -2318,14 +2351,19 @@ struct AttentionBackwardKernel { return; } } else { - if (p.window_size == 0 && next_query < p.num_queries) { - return; - } else if (p.window_size > 0) { - if (next_query < - cutlass::fast_min( - key_start + kBlockSizeJ + p.window_size, p.num_queries)) { + if (p.window_size > 0) { + int32_t shift = p.custom_mask_type == CausalFromBottomRight + ? p.num_keys - p.num_queries + : 0; + // last key that is not masked out + int last_key_for_block = + cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1; + int last_query = last_key_for_block - shift + p.window_size - 1; + if (next_query <= last_query && next_query < p.num_queries) { return; } + } else if (next_query < p.num_queries) { + return; } // jump to next key } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 74330ecd242a..a10e5a9c44a0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -138,7 +138,6 @@ struct AttentionKernel { const int32_t* seqstart_q_ptr = nullptr; const int32_t* seqstart_k_ptr = nullptr; - const int32_t* causal_diagonal_ptr = nullptr; const int32_t* seqlen_k_ptr = nullptr; uint32_t causal_diagonal_offset = 0; @@ -153,46 +152,46 @@ struct AttentionKernel { int32_t window_size = 0; // Scale - accum_t scale; + accum_t scale = 0.0; // Dimensions/strides - int32_t head_dim; - int32_t head_dim_value; - int32_t num_queries; - int32_t num_keys; - int32_t num_keys_absolute; + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; uint8_t custom_mask_type = NoCustomMask; - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; int32_t bias_strideM = 0; int32_t o_strideM = 0; // Everything below is only used in `advance_to_block` // and shouldn't use registers - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; int64_t bias_strideH = 0; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; 
+ int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; int64_t bias_strideB = 0; - int32_t num_batches; - int32_t num_heads; + int32_t num_batches = 0; + int32_t num_heads = 0; // dropout - bool use_dropout; - unsigned long long dropout_batch_head_rng_offset; - float dropout_prob; - at::PhiloxCudaState rng_engine_inputs; - int64_t* extragraph_offset; - int64_t* seed; + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); + int64_t* extragraph_offset = nullptr; + int64_t* seed = nullptr; // Moves pointers to what we should process // Returns "false" if there is no work to do @@ -209,7 +208,7 @@ struct AttentionKernel { head_id * num_queries * num_keys; } - int64_t q_start, k_start; + int64_t q_start = 0, k_start = 0; // Advance to current batch - in case of different sequence lengths if (seqstart_q_ptr != nullptr) { assert(seqstart_k_ptr != nullptr); @@ -274,11 +273,8 @@ struct AttentionKernel { } // Custom masking - if (causal_diagonal_ptr) { - causal_diagonal_offset = causal_diagonal_ptr[batch_id]; - } if (custom_mask_type == CausalFromBottomRight) { - causal_diagonal_offset += num_keys - num_queries; + causal_diagonal_offset = num_keys - num_queries; } // We use num_keys_absolute to index into the rng_state // We need this index to match between forward and backwards @@ -302,7 +298,7 @@ struct AttentionKernel { // - we only launch kernels for head_id % kQueriesPerBlock == 0 // - we iterate over heads instead of queries (strideM = strideH) if (num_queries == 1 && k_strideH == 0 && v_strideH == 0 && - logsumexp_ptr == nullptr) { + logsumexp_ptr == nullptr && window_size == 0) { if (head_id % kQueriesPerBlock != 0) { return false; } @@ -318,6 +314,7 @@ struct AttentionKernel { // Make sure the compiler knows these variables are the same on all // the threads of the warp. + // Only worth doing if they could have been modified above. 
query_ptr = warp_uniform(query_ptr); key_ptr = warp_uniform(key_ptr); value_ptr = warp_uniform(value_ptr); @@ -330,8 +327,6 @@ struct AttentionKernel { num_queries = warp_uniform(num_queries); num_keys = warp_uniform(num_keys); num_heads = warp_uniform(num_heads); - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); o_strideM = warp_uniform(o_strideM); custom_mask_type = warp_uniform(custom_mask_type); return true; @@ -614,16 +609,14 @@ struct AttentionKernel { TORCH_CHECK( p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, "value is not correctly aligned (strideH)"); - TORCH_CHECK( - p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask, - "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal"); TORCH_CHECK( p.custom_mask_type < NumCustomMaskTypes, "invalid value for `custom_mask_type`"); if (p.window_size > 0) { TORCH_CHECK( p.custom_mask_type == CausalFromTopLeft || - p.custom_mask_type == CausalFromBottomRight); + p.custom_mask_type == CausalFromBottomRight, + "custom_mask_type not supported"); } return true; } diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 372377e1eca6..389c08b152ba 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -44,14 +45,28 @@ namespace sdp { namespace { + +// TODO(eqy): more benchmarking to determine whether this should include sm86/89 +// Needs to be kept in-sync with test_fused_chocie in test_transformers.py +bool check_prefer_cudnn_attention() { + auto dprops = at::cuda::getCurrentDeviceProperties(); + return dprops->major >= 9; +} + // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { constexpr std::array default_order{ + SDPBackend::flash_attention, + SDPBackend::cudnn_attention, + SDPBackend::efficient_attention, + SDPBackend::math}; + constexpr std::array cudnn_order{ SDPBackend::cudnn_attention, SDPBackend::flash_attention, SDPBackend::efficient_attention, SDPBackend::math}; - return default_order; + static const bool prefer_cudnn = check_prefer_cudnn_attention(); + return prefer_cudnn ? cudnn_order : default_order; } bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) { @@ -215,6 +230,17 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; +#if USE_ROCM + auto stream = at::cuda::getCurrentCUDAStream().stream(); + if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (debug) { + TORCH_WARN( + "Mem Efficient attention was not compiled for current AMD GPU architecture. 
Attempting to run on architecture ", dprops->gcnArchName); + } + return false; + } +#else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { @@ -227,6 +253,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } +#endif return true; } @@ -439,17 +466,6 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) { return true; } -bool check_is_causal(sdp_params const& params, bool debug) { - // Check that the input is causal - if (!params.is_causal) { - if (debug) { - TORCH_WARN("CuDNN requires is_causal=True."); - } - return false; - } - return true; -} - bool check_for_nested_inputs(sdp_params const& params, bool debug) { // Check that the input is nested if (has_for_nested_inputs(params)) { @@ -473,22 +489,6 @@ bool check_dtypes_low_precision(sdp_params const& params, bool debug) { } } -bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) { - static c10::once_flag supported_flag; - static bool supported = false; - c10::call_once(supported_flag, []() { - supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true); - }); - if (!supported) { - if (debug) { - TORCH_WARN( - "The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1`"); - } - return false; - } - return true; -} - bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // We check the global context to see if user has explicitly turned of cudnn // sdp kernels @@ -501,13 +501,15 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { return true; } -bool check_cudnn_requires_grad(sdp_params const& params, bool debug) { - // Check that the input is causal - if (input_requires_grad(params)) { - if (debug) { - TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True."); +bool check_cudnn_deterministic(const sdp_params& params, bool debug) { + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (!ctx.deterministicAlgorithmsWarnOnly()) { + if (debug) { + TORCH_WARN("cuDNN SDPA is not deterministic."); + } + return false; } - return false; } return true; } @@ -515,21 +517,29 @@ bool check_cudnn_requires_grad(sdp_params const& params, bool debug) { } // namespace bool can_use_cudnn_attention(const sdp_params& params, bool debug) { - +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) + TORCH_WARN_ONCE(!debug, "Torch was not compiled with cuDNN attention."); + return false; +#endif // Define gate functions that determine if a flash kernel can be ran // Replace with std::to_array when we migrate to c++20 constexpr auto general_constraints = array_of( - check_runtime_enabled_cudnn, - check_runtime_disabled_cudnn, - check_cudnn_hardware_support, + check_for_nested_inputs, + check_nonzero_sequence_lengths_dense, + check_last_dim_stride_equals_1_dense*/>, check_all_tensors_on_device, + check_tensor_shapes, check_cudnn_tensor_shapes, - check_cudnn_layout, + check_runtime_disabled_cudnn, + check_cudnn_deterministic, + // check_cudnn_layout, // check_is_causal, - check_for_nested_inputs, - check_cudnn_requires_grad, - check_dtypes_low_precision); + check_dtypes_low_precision, + check_for_attn_mask_cudnn, + check_cudnn_hardware_support + ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; @@ -597,6 +607,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { 
array_of(at::kHalf, at::kFloat, at::kBFloat16); constexpr auto less_than_sm80_mem_efficient_dtypes = array_of(at::kHalf, at::kFloat); +#ifdef USE_ROCM + constexpr auto aotriton_mem_efficient_dtypes = + array_of(at::kHalf, at::kFloat, at::kBFloat16); +#endif // Define gate functions that determine if a mem efficient kernel can be ran constexpr auto general_constraints = array_of( @@ -612,6 +626,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } if (has_for_nested_inputs(params)) { +#ifdef USE_ROCM + TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention."); + return false; +#endif constexpr auto nested_constraints = array_of( check_requires_grad_and_nested, check_batch_size_nested, @@ -634,10 +652,14 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } } +#ifdef USE_ROCM + return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug); +#else auto dprop = at::cuda::getCurrentDeviceProperties(); if (dprop->major >= 8) { return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug); } +#endif return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); } diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h new file mode 100644 index 000000000000..1c238c751a05 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -0,0 +1,130 @@ +#pragma once + +#ifdef USE_ROCM + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h +//////////////////////////////////////////////////////////////////////////////// + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace sdp { + +namespace aotriton_adapter { + +inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) +{ +#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname + CAST_TYPE(kByte, kUInt8); + CAST_TYPE(kUInt16, kUInt16); + CAST_TYPE(kUInt32, kUInt32); + CAST_TYPE(kUInt64, kUInt64); + CAST_TYPE(kChar, kInt8); + CAST_TYPE(kShort, kInt16); + CAST_TYPE(kInt, kInt32); + CAST_TYPE(kLong, kInt64); + CAST_TYPE(kHalf, kFloat16); + CAST_TYPE(kFloat, kFloat32); + CAST_TYPE(kBFloat16, kBFloat16); + return aotriton::DType::kUnknown; +#undef CAST_TYPE +} + +template +struct IntArrayRefCaster { + // std::array cast(IntArrayRef); +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ static_cast(ref.at(0)) }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)) + }}; + 
} +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)), + static_cast(ref.at(3)) + }}; + } +}; + + +template +aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) +{ + const auto strides = q.strides(); + int real_rank = strides.size(); + if (real_rank != Rank) { // Lazy convertion of tensor_name + TORCH_CHECK(false, + std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + + " but is " + std::to_string(real_rank)); + } + return aotriton::TensorView(reinterpret_cast(q.data_ptr()), + IntArrayRefCaster::cast(q.sizes()), + IntArrayRefCaster::cast(strides), + cast_dtype(q.dtype())); +} + +} // namespace aotriton_adapter + +} // namespace sdp + +namespace at::native { + +inline int64_t ceil_div(int64_t numerator, int64_t denominator) { + return (numerator + (denominator - 1)) / denominator; +} + +} + +#endif // USE_ROCM diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index e110e4ae1c64..7af480a7ae49 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -54,16 +54,15 @@ #include #endif +#include #include #include #include // AOTriton headers -#include #include #include -#include namespace pytorch_flash { @@ -73,90 +72,10 @@ void check_gpu_arch(hipStream_t stream) { auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") } } -aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) -{ -#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname - CAST_TYPE(kByte, kUInt8); - CAST_TYPE(kUInt16, kUInt16); - CAST_TYPE(kUInt32, kUInt32); - CAST_TYPE(kUInt64, kUInt64); - CAST_TYPE(kChar, kInt8); - CAST_TYPE(kShort, kInt16); - CAST_TYPE(kInt, kInt32); - CAST_TYPE(kLong, kInt64); - CAST_TYPE(kHalf, kFloat16); - CAST_TYPE(kFloat, kFloat32); - CAST_TYPE(kBFloat16, kBFloat16); - return aotriton::DType::kUnknown; -#undef CAST_TYPE -} - -template -struct IntArrayRefCaster { - // std::array cast(IntArrayRef); -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ static_cast(ref.at(0)) }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)) - }}; - } -}; - -template -struct IntArrayRefCaster { - static auto cast(at::IntArrayRef ref) { - return std::array{{ - static_cast(ref.at(0)), - static_cast(ref.at(1)), - static_cast(ref.at(2)), - static_cast(ref.at(3)) - }}; - } -}; - - -template -aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) -{ - const auto strides = q.strides(); - int real_rank = strides.size(); - 
if (real_rank != Rank) { // Lazy convertion of tensor_name - TORCH_CHECK(false, - std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) - + " but is " + std::to_string(real_rank)); - } - return aotriton::TensorView(reinterpret_cast(q.data_ptr()), - IntArrayRefCaster::cast(q.sizes()), - IntArrayRefCaster::cast(strides), - cast_dtype(q.dtype())); -} - } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") @@ -300,9 +219,13 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), + empty_bias, softmax_scale, mk_aotensor<2>(M, "M"), mk_aotensor(output_t, "Out"), @@ -495,15 +418,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si hipError_t err; // TODO: Error handling { using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), + empty_bias, softmax_scale, mk_aotensor(out_t, "out"), mk_aotensor(dout_t, "dout"), mk_aotensor(dq_t, "dq"), mk_aotensor(dk_t, "dk"), mk_aotensor(dv_t, "dv"), + empty_bias, mk_aotensor<2>(softmax_lse_cont, "L"), mk_aotensor<2>(delta, "delta"), p_dropout, diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 7c56a1f617db..70d9be903ce9 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -266,7 +266,18 @@ inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) inline bool check_for_attn_mask(sdp_params const& params, bool debug) { if (params.attn_mask.has_value()) { if (debug) { - TORCH_WARN("Flash Attention does not support non-null attn_mask."); + TORCH_WARN("Flash Attention do not support non-null attn_mask."); + } + return false; + } + return true; +} + +// TODO(eqy): remove this once support is added +inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) { + if (params.attn_mask.has_value()) { + if (debug) { + TORCH_WARN("cuDNN Attention does not support non-null attn_mask."); } return false; } @@ -313,7 +324,7 @@ inline bool check_tensor_shapes(sdp_params const& params, bool debug) { (query_dim == 4))) { if (debug) { TORCH_WARN( - "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", + "All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ", query_dim, ", Key dim: ", params.key.dim(), @@ -425,7 +436,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool if (zero_seq_len_q || zero_seq_len_k) { if (debug) { TORCH_WARN( - "Both fused kernels do not support zero seq_len_q or seq_len_kv."); + "All fused kernels do not support zero seq_len_q or seq_len_kv."); } return false; } @@ -460,7 +471,7 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool } epilogue_message << " instead."; TORCH_WARN( - "Both fused kernels require the last dimension of the input to have stride 1. 
", + "All fused kernels require the last dimension of the input to have stride 1. ", "Got Query.stride(-1): ", params.query.sym_stride(-1), ", Key.stride(-1): ", diff --git a/aten/src/ATen/native/vol2col.h b/aten/src/ATen/native/vol2col.h index ccbfc69ce3c6..fa5c46b8c52e 100644 --- a/aten/src/ATen/native/vol2col.h +++ b/aten/src/ATen/native/vol2col.h @@ -5,7 +5,7 @@ namespace at::native { template -static void vol2col( +void vol2col( const T* data_vol, const int64_t channels, const int64_t depth, @@ -56,7 +56,7 @@ static void vol2col( } template -static void col2vol( +void col2vol( const T* data_col, const int64_t channels, const int64_t depth, diff --git a/aten/src/ATen/test/pow_test.cpp b/aten/src/ATen/test/pow_test.cpp index fb3b073f29f3..95bb48b341f5 100644 --- a/aten/src/ATen/test/pow_test.cpp +++ b/aten/src/ATen/test/pow_test.cpp @@ -10,12 +10,6 @@ #include #include -#ifdef _WIN32 -#define DISABLED_ON_WINDOWS(x) DISABLED_##x -#else -#define DISABLED_ON_WINDOWS(x) x -#endif - using namespace at; namespace { @@ -204,7 +198,7 @@ void tensor_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, Pows pows, c std::cout.precision(dbl::max_digits10); const auto vals_tensor = torch::tensor(vals, vals_dtype); - for (const auto shift : c10::irange(pows.size())) { + for ([[maybe_unused]] const auto shirt : c10::irange(pows.size())) { const auto pows_tensor = torch::tensor(pows, pows_dtype); const auto actual_pow = vals_tensor.pow(pows_tensor); diff --git a/aten/src/ATen/test/reduce_ops_test.cpp b/aten/src/ATen/test/reduce_ops_test.cpp index bcae3fdc51f9..a9ce7e4cf8f4 100644 --- a/aten/src/ATen/test/reduce_ops_test.cpp +++ b/aten/src/ATen/test/reduce_ops_test.cpp @@ -9,9 +9,8 @@ TEST(ReduceOpsTest, MaxValuesAndMinValues) { const int W = 10; const int H = 10; if (hasCUDA()) { - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - for (const auto dtype : {kHalf, kFloat, kDouble, kShort, kInt, kLong}) { - auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(at::kHalf)); + for (const auto dtype : {kHalf, kFloat, kDouble}) { + auto a = at::rand({H, W}, TensorOptions(kCUDA).dtype(dtype)); ASSERT_FLOAT_EQ( a.amax(c10::IntArrayRef{0, 1}).item(), a.max().item() diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index c10e8386d683..0d7b62b44d21 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -82,7 +82,6 @@ TEST(TestScalar, TestScalar) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_NO_THROW(gen.set_current_seed(std::random_device()())); } - auto&& C = at::globalContext(); if (at::hasCUDA()) { auto t2 = zeros({4, 4}, at::kCUDA); cout << &t2 << "\n"; diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 4c7e3e5b2b02..f9a0557f8bdf 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -978,26 +978,6 @@ namespace { b[i] = b[i - 1] + (T)(1.0); } } - template<> - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - void blend_init, 4>(Complex(&a)[4], Complex(&b)[4]) { - auto add = Complex(1., 100.); - a[0] = Complex(1., 100.); - b[0] = Complex(5., 1000.); - for (const auto i : c10::irange(1, 4)) { - a[i] = a[i - 1] + add; - b[i] = b[i - 1] + add; - } - } - template<> - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - void blend_init, 2>(Complex(&a)[2], Complex(&b)[2]) { - auto add = Complex(1.0, 100.0); - a[0] = Complex(1.0, 100.0); - b[0] = 
Complex(3.0, 1000.0); - a[1] = a[0] + add; - b[1] = b[0] + add; - } TYPED_TEST(BitwiseFloatsAdditional, Blendv) { using vec = TypeParam; using VT = ValueType; diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.h b/aten/src/ATen/xpu/XPUGeneratorImpl.h index ce77d2e444e6..a1f264382a36 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.h +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.h @@ -4,7 +4,7 @@ namespace at { -struct TORCH_API XPUGeneratorImpl : public GeneratorImpl { +struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { // Constructors XPUGeneratorImpl(DeviceIndex device_index = -1); ~XPUGeneratorImpl() override = default; diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp index 22f4ff22b4bb..61bc19faa95e 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.cpp +++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp @@ -25,7 +25,13 @@ std::string XPUHooks::showConfig() const { int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const { TORCH_CHECK(device.is_xpu(), "Only the XPU device type is expected."); +#ifdef _WIN32 + TORCH_CHECK( + false, + "Default context is not supported on XPU on Windows. So we can NOT find its global index of the ATen device."); +#else return at::xpu::getGlobalIdxFromDevice(device.index()); +#endif } Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const { @@ -38,7 +44,13 @@ const Generator& XPUHooks::getDefaultXPUGenerator( } Device XPUHooks::getDeviceFromPtr(void* data) const { +#ifdef _WIN32 + TORCH_CHECK( + false, + "Default context is not supported on XPU on Windows. So we can NOT find the ATen device of a pointer."); +#else return at::xpu::getDeviceFromPtr(data); +#endif } c10::DeviceIndex XPUHooks::getNumGPUs() const { diff --git a/benchmarks/distributed/pipeline/benchmark_dataset.py b/benchmarks/distributed/pipeline/benchmark_dataset.py deleted file mode 100644 index 3cd22e9a468d..000000000000 --- a/benchmarks/distributed/pipeline/benchmark_dataset.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from torch.utils.data import Dataset - - -def collate_sentences_lm(samples): - if len(samples) == 0: - return {} - - id = torch.LongTensor([s["id"] for s in samples]) - src_tokens = torch.stack([s["source"] for s in samples], 0) - tgt_tokens = torch.stack([s["target"] for s in samples], 0) - ntokens = len(samples) * len(samples[0]["target"]) - src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples)) - - batch = { - "id": id, - "nsentences": len(samples), - "ntokens": ntokens, - "input": src_tokens, - "target": tgt_tokens, - } - return batch - - -class BenchmarkLMDataset(Dataset): - """ - Dataset to benchmark a translation like seq2seq task. - Args: - vocab_size (int, optional): size of the vocabulary (default 10000). - max_source_positions (int, optional): max number of tokens in the - source sentence (default: 1024). - total_samples (int, optional): the total number of rows in the - dataset (default: 10000). 
- """ - - def __init__( - self, - vocab_size=10000, - max_source_positions=1024, - total_samples=10000, - ): - self.vocab_size = vocab_size - self.max_source_positions = max_source_positions - self.total_samples = total_samples - self.sizes = [self.max_source_positions] * self.total_samples - - def __getitem__(self, index): - length = self.sizes[index] - source = torch.randint(1, self.vocab_size, (length,)) - target = source.clone() - return { - "id": index, - "source": source, - "target": target, - } - - def __len__(self): - return self.total_samples diff --git a/benchmarks/distributed/pipeline/pipe.py b/benchmarks/distributed/pipeline/pipe.py deleted file mode 100644 index c465c2488565..000000000000 --- a/benchmarks/distributed/pipeline/pipe.py +++ /dev/null @@ -1,296 +0,0 @@ -import argparse -import math -import os -import time - -from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm - -import torch -import torch.nn as nn -from torch.distributed import rpc - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.utils import partition_model -from torch.optim import Adam -from torch.utils.data import DataLoader - - -def sizeof_fmt(num, suffix="B"): - for unit in ["", "Ki", "Mi", "Gi", "Ti"]: - if abs(num) < 1024.0: - return f"{num:3.2f}{unit}B" - num /= 1024.0 - - -def init_random_seed(seed: int): - import numpy - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - numpy.random.seed(seed) - - -iteration_count = 0 - - -class EmbeddingLayer(nn.Embedding): - def __init__(self, ntoken, ninp, initrange): - super().__init__(ntoken, ninp) - self.ninp = ninp - nn.init.uniform_(self.weight, -initrange, initrange) - - def forward(self, src): - return super().forward(src) * math.sqrt(self.ninp) - - -class PositionalEncodingLayer(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) - - -class TransformerDecoderLayer(nn.TransformerEncoderLayer): - """Though this class inherits from torch.nn.TransformerEncoderLayer, - it functions as a decoder in this model""" - - def __init__(self, ninp, nhead, nhid, droupout): - super().__init__(ninp, nhead, nhid, droupout) - self.src_mask = None - - def forward(self, src): - global iteration_count - iteration_count += 1 - - if self.src_mask is None or self.src_mask.size(0) != len(src): - device = src.device - mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device) - self.src_mask = mask - - return super().forward(src, self.src_mask) - - -class LinearLayer(nn.Linear): - def __init__(self, ninp, ntoken, initrange): - super().__init__(ninp, ntoken) - nn.init.zeros_(self.bias) - nn.init.uniform_(self.weight, -initrange, initrange) - - -class TransformerLMSequential(nn.Sequential): - """A small language model based on the design of GPT-2 using nn.Sequential - for compatibility with Pipe""" - - def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder): - layers = [ - EmbeddingLayer(ntokens, ninp, initrange), - PositionalEncodingLayer(ninp, dropout), - 
] - for _ in range(ndecoder): - layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout)) - - layers.append(LinearLayer(ninp, ntokens, initrange)) - super().__init__(*layers) - - -def make_model(args, device, ntokens): - ninp = 2048 # embedding dimension - nhid = ( - 2048 # the dimension of the feedforward network model in nn.TransformerEncoder - ) - nhead = 32 # the number of heads in the multiheadattention models - dropout = 0 - initrange = 0.1 - ndecoder = args.num_decoder_layers - - model = TransformerLMSequential( - ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder - ).to(device) - - criterion = nn.CrossEntropyLoss() - lr = 0.01 # learning rate - - def make_adam(model): - return Adam(model.parameters(), lr=lr) - - optimizer = make_adam - - return model, criterion, optimizer - - -def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): - model.train() - - vocab_size = 10000 - total_loss = 0.0 - start_time = time.time() - word_counter = 0 - - optimizer = optimizer(model) - - def get_first_device(model): - if model.devices: - return model.devices[0] - else: - return torch.cuda.current_device() - - def get_last_device(model): - if model.devices: - return model.devices[-1] - else: - return torch.cuda.current_device() - - print( - f"Number of parameters for model: {sum(p.numel() for p in model.parameters())}" - ) - for i, batch in enumerate(lm_dataloader): - bi = batch["input"] - if args.max_batch and i > args.max_batch: - break - optimizer.zero_grad() - try: - tmp = batch["input"].to(get_first_device(model)) - output = model(tmp).local_value() - except Exception as e: - raise RuntimeError( - f"training failed on {torch.distributed.get_rank()}" - ) from e - - target = batch["target"].to(get_last_device(model)) - output = output.to(target.device) - - loss = criterion(output.view(-1, vocab_size), target.view(-1)) - loss.backward() - del target - del output - - torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) - optimizer.step() - - total_loss += loss.item() - log_interval = 1 - word_counter += batch["ntokens"] - if i % log_interval == 0 and i > 0: - cur_loss = total_loss / log_interval - elapsed = time.time() - start_time - print( - f"| batch {i:5d} | wps {word_counter / elapsed:5.2f} | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}" - ) - word_counter = 0 - total_loss = 0 - start_time = time.time() - - print("Peak memory usage for GPUs: ", end="") - for i in range(len(model.devices)): - print( - f"cuda:{i}: {sizeof_fmt(torch.cuda.memory_stats(i)['allocated_bytes.all.peak'])}, ", - end="", - ) - print() - - -def generate_balance(num_devices, num_layers): - balance = [] - layers_assigned = 0 - for i in range(num_devices): - x = (num_layers - layers_assigned) / (num_devices - i) - if x.is_integer(): - balance.append(int(x)) - layers_assigned += x - else: - balance.append(math.ceil(x)) - layers_assigned += math.ceil(x) - return balance - - -def make_model_and_data(args, device): - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - vocab_size = 10000 - model, criterion, optimizer = make_model(args, device, vocab_size) - lm_dataset = BenchmarkLMDataset() - lm_dataloader = DataLoader( - lm_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=0, - collate_fn=collate_sentences_lm, - ) - return { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "data": lm_dataloader, - "vocab_size": vocab_size, - } - - -def bench_single_process(args): - os.environ.update({"MASTER_ADDR": args.host}) - 
os.environ.update({"MASTER_PORT": "10638"}) - - rpc.init_rpc( - "worker", - rank=0, - world_size=1, - ) - - num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 - num_devices = min(args.num_devices, num_devices) - assert num_devices > 0 - init_random_seed(0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - - blob = make_model_and_data(args, None) - model = blob["model"] - - balance = generate_balance(num_devices, len(model)) - model = partition_model(model, balance) - p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint) - del model - del blob["model"] - - train( - blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args - ) - - -parser = argparse.ArgumentParser(description="benchmark") -parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname") -parser.add_argument( - "--chunks", type=int, default=4, help="number of microbatches per batch" -) -parser.add_argument("--batch-size", type=int, default=8, help="size of a batch") -parser.add_argument("--max-batch", type=int, default=10, help="Max number of batches") -parser.add_argument( - "--num-decoder-layers", - type=int, - default=10, - help="Number of decoder layers in the model", -) -parser.add_argument( - "--checkpoint", - default="except_last", - choices=["always", "except_last", "never"], - help="Checkpointing strategy for pipe", -) -parser.add_argument( - "--num-devices", type=int, default=4, help="Number of GPU devices to use" -) - -if __name__ == "__main__": - args = parser.parse_args() - print(f"Running benchmark with args: {args}") - bench_single_process(args) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv +++ 
b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 20fb340690ac..9863aa7da6a2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 +hf_BigBird,pass,43 @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 5131c2e9ade4..82048af8775a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass, 52 +hf_BigBird,pass,49 @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,fail_accuracy,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 40382a4f277c..e29c62dd5b71 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,fail_to_run,0 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + 
+RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv new file mode 100644 index 000000000000..3af215541c1d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -0,0 +1,341 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,pass,42 + + + +detectron2_fasterrcnn_r_101_dc5,pass,42 + + + +detectron2_fasterrcnn_r_101_fpn,pass,46 + + + +detectron2_fasterrcnn_r_50_c4,pass,42 + + + +detectron2_fasterrcnn_r_50_dc5,pass,42 + + + +detectron2_fasterrcnn_r_50_fpn,pass,46 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +detectron2_maskrcnn_r_101_c4,pass,57 + + + 
+detectron2_maskrcnn_r_101_fpn,fail_accuracy,64 + + + +detectron2_maskrcnn_r_50_c4,fail_accuracy,57 + + + +detectron2_maskrcnn_r_50_fpn,pass,64 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index fcd87f4d2454..a497fb45d7d4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -338,4 +338,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + 
+ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + 
+ + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv new file mode 100644 index 000000000000..1def1d99bd53 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,fail_accuracy,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + 
+ + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv new file mode 100644 index 000000000000..20fb340690ac --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + 
+ + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv new file mode 100644 index 000000000000..5131c2e9ade4 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass, 52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..784d3788e335 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,fail_to_run,0 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + 
+BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,0 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,0 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv new file mode 100644 index 000000000000..c7e86a6d317e --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,fail_to_run,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + 
+ +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..e29c62dd5b71 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -0,0 +1,353 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,fail_to_run,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,fail_to_run,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,fail_to_run,0 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,0 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,fail_to_run,0 + + + +doctr_reco_predictor,fail_to_run,0 + + + +drq,fail_to_run,0 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_to_run,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,fail_to_run,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,fail_to_run,0 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,fail_to_run,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,fail_to_run,0 + + + +sam_fast,fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,fail_to_run,0 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,fail_to_run,0 + + + +vgg16,pass,0 + + + +vision_maskrcnn,fail_to_run,0 + + + +yolov3,pass,0 diff --git 
a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + 
+poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..fcd87f4d2454 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/cpu_inductor_torchbench_inference.csv @@ -0,0 +1,341 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,pass,42 + + + +detectron2_fasterrcnn_r_101_dc5,pass,42 + + + +detectron2_fasterrcnn_r_101_fpn,pass,46 + + + +detectron2_fasterrcnn_r_50_c4,pass,42 + + + +detectron2_fasterrcnn_r_50_dc5,pass,42 + + + +detectron2_fasterrcnn_r_50_fpn,pass,46 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +detectron2_maskrcnn_r_101_c4,fail_accuracy,57 + + + +detectron2_maskrcnn_r_101_fpn,pass,64 + + + +detectron2_maskrcnn_r_50_c4,pass,57 + + + +detectron2_maskrcnn_r_50_fpn,pass,64 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + 
+timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv new file mode 100644 index 000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + 
+ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv new file mode 100644 index 000000000000..1def1d99bd53 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + 
+beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,fail_accuracy,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv new file mode 100644 index 000000000000..bcdf06917b64 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_inference.csv @@ -0,0 +1,377 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + 
+hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv new file mode 100644 index 000000000000..1e1a4be4149e --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv @@ -0,0 +1,285 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,fail_to_run,3 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + 
+pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + 
+convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..ce271939b18c --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_cpu_inductor_torchbench_inference.csv @@ -0,0 +1,301 @@ +name,accuracy,graph_breaks + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,model_fail_to_load,0 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fcos_r_50_fpn,pass,23 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5_base,pass,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,pass,2 + + + +mobilenet_v3_large,pass,0 + + + +moco,model_fail_to_load,0 + + + +moondream,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + 
+resnet50,pass,0 + + + +resnet50_quantized_qat,pass,2 + + + +resnext50_32x4d,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientdet,model_fail_to_load,0 + + + +timm_efficientnet,pass,0 + + + +timm_nfnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,3 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,28 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv new file mode 100644 index 000000000000..08dad9b4a06a --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,6 + + + +BartForConditionalGeneration,pass,8 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,6 + + + 
+BlenderbotSmallForConditionalGeneration,pass,8 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,6 + + + +MBartForConditionalGeneration,pass,8 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,6 + + + +PLBartForCausalLM,pass,6 + + + +PLBartForConditionalGeneration,pass,8 + + + +PegasusForCausalLM,pass,6 + + + +PegasusForConditionalGeneration,pass,7 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,6 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,6 + + + +XGLMForCausalLM,pass,6 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git 
a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv new file mode 100644 index 000000000000..ae860db793c9 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,fail_accuracy,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv new file mode 100644 index 000000000000..3f60be5afd97 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_inference.csv @@ -0,0 +1,377 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + 
+detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_accuracy,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv new file mode 100644 index 000000000000..ee58808c0bb0 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -0,0 +1,285 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,fail_to_run,3 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,fail_to_run,4 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 
+ + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + + + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv new file mode 100644 index 
000000000000..a5e00513153d --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,12 + + + +BartForConditionalGeneration,pass,24 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,12 + + + +BlenderbotSmallForConditionalGeneration,pass,24 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,12 + + + +MBartForConditionalGeneration,pass,24 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,12 + + + +PLBartForCausalLM,pass,12 + + + +PLBartForConditionalGeneration,pass,29 + + + +PegasusForCausalLM,pass,12 + + + +PegasusForConditionalGeneration,pass,23 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,12 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,12 + + + +XGLMForCausalLM,pass,12 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + +mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + 
+ + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv new file mode 100644 index 000000000000..e5464160d32f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,pass,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv new file mode 100644 index 000000000000..20fb340690ac --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + 
+ +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,pass,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv new file mode 100644 index 000000000000..cfc524426644 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + +DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + 
+functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv new file mode 100644 index 000000000000..349239b058a7 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_inference.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,0 + + + +AlbertForQuestionAnswering,pass,0 + + + +AllenaiLongformerBase,pass,4 + + + +BartForCausalLM,pass,0 + + + +BartForConditionalGeneration,pass,0 + + + +BertForMaskedLM,pass,0 + + + +BertForQuestionAnswering,pass,0 + + + +BlenderbotForCausalLM,pass_due_to_skip,0 + + + +BlenderbotSmallForCausalLM,pass,0 + + + +BlenderbotSmallForConditionalGeneration,pass,0 + + + +CamemBert,pass,0 + + + +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,pass,0 + + + +DistilBertForMaskedLM,pass,0 + + + +DistilBertForQuestionAnswering,pass,0 + + + +DistillGPT2,pass,0 + + + +ElectraForCausalLM,pass,0 + + + +ElectraForQuestionAnswering,pass,0 + + + +GPT2ForSequenceClassification,pass,2 + + + +GoogleFnet,pass,0 + + + +LayoutLMForMaskedLM,pass,0 + + + +LayoutLMForSequenceClassification,pass,2 + + + +M2M100ForConditionalGeneration,pass,0 + + + +MBartForCausalLM,pass,0 + + + +MBartForConditionalGeneration,pass,0 + + + +MT5ForConditionalGeneration,pass,0 + + + +MegatronBertForCausalLM,pass,0 + + + +MegatronBertForQuestionAnswering,pass,0 + + + +MobileBertForMaskedLM,pass,0 + + + +MobileBertForQuestionAnswering,pass,0 + + + +OPTForCausalLM,pass,0 + + + +PLBartForCausalLM,pass,0 + + + +PLBartForConditionalGeneration,pass,0 + + + +PegasusForCausalLM,pass,0 + + + +PegasusForConditionalGeneration,pass,0 + + + +RobertaForCausalLM,pass,0 + 
+ + +RobertaForQuestionAnswering,pass,0 + + + +Speech2Text2ForCausalLM,pass,0 + + + +T5ForConditionalGeneration,pass,0 + + + +T5Small,pass,0 + + + +TrOCRForCausalLM,pass,0 + + + +XGLMForCausalLM,pass,0 + + + +XLNetLMHeadModel,pass,0 + + + +YituTechConvBert,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv new file mode 100644 index 000000000000..08dad9b4a06a --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_huggingface_training.csv @@ -0,0 +1,185 @@ +name,accuracy,graph_breaks + + + +AlbertForMaskedLM,pass,4 + + + +AlbertForQuestionAnswering,pass,5 + + + +AllenaiLongformerBase,pass,9 + + + +BartForCausalLM,pass,6 + + + +BartForConditionalGeneration,pass,8 + + + +BertForMaskedLM,pass,5 + + + +BertForQuestionAnswering,pass,5 + + + +BlenderbotForCausalLM,eager_fail_to_run,0 + + + +BlenderbotSmallForCausalLM,pass,6 + + + +BlenderbotSmallForConditionalGeneration,pass,8 + + + +CamemBert,pass,5 + + + +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +DebertaV2ForMaskedLM,pass_due_to_skip,0 + + + +DebertaV2ForQuestionAnswering,eager_1st_run_OOM,0 + + + +DistilBertForMaskedLM,pass,5 + + + +DistilBertForQuestionAnswering,pass,5 + + + +DistillGPT2,pass,5 + + + +ElectraForCausalLM,pass,4 + + + +ElectraForQuestionAnswering,pass,5 + + + +GPT2ForSequenceClassification,pass,7 + + + +GoogleFnet,pass,5 + + + +LayoutLMForMaskedLM,pass,5 + + + +LayoutLMForSequenceClassification,pass,7 + + + +M2M100ForConditionalGeneration,pass,4 + + + +MBartForCausalLM,pass,6 + + + +MBartForConditionalGeneration,pass,8 + + + +MT5ForConditionalGeneration,pass,5 + + + +MegatronBertForCausalLM,pass,5 + + + +MegatronBertForQuestionAnswering,pass,5 + + + +MobileBertForMaskedLM,pass,3 + + + +MobileBertForQuestionAnswering,pass,3 + + + +OPTForCausalLM,pass,6 + + + +PLBartForCausalLM,pass,6 + + + +PLBartForConditionalGeneration,pass,8 + + + +PegasusForCausalLM,pass,6 + + + +PegasusForConditionalGeneration,pass,7 + + + +RobertaForCausalLM,pass,5 + + + +RobertaForQuestionAnswering,pass,5 + + + +Speech2Text2ForCausalLM,pass,6 + + + +T5ForConditionalGeneration,pass,5 + + + +T5Small,pass,5 + + + +TrOCRForCausalLM,pass,6 + + + +XGLMForCausalLM,pass,6 + + + +XLNetLMHeadModel,pass,5 + + + +YituTechConvBert,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv new file mode 100644 index 000000000000..c889ba0e8d2f --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_inference.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,0 + + + +beit_base_patch16_224,pass,0 + + + +botnet26t_256,pass,0 + + + +cait_m36_384,pass,0 + + + +coat_lite_mini,pass,0 + + + +convit_base,pass,0 + + + +convmixer_768_32,pass,0 + + + +convnext_base,pass,0 + + + +crossvit_9_240,pass,0 + + + +cspdarknet53,pass,0 + + + +deit_base_distilled_patch16_224,pass,0 + + + +dla102,pass,0 + + + +dm_nfnet_f0,pass,0 + + + +dpn107,pass,0 + + + +eca_botnext26ts_256,pass,0 + + + +eca_halonext26ts,pass,0 + + + +ese_vovnet19b_dw,pass,0 + + + +fbnetc_100,pass,0 + + + +fbnetv3_b,pass,0 + + + +gernet_l,pass,0 + + + +ghostnet_100,pass,0 + + + +gluon_inception_v3,pass,0 + + + +gmixer_24_224,pass,0 + + + +gmlp_s16_224,pass,0 + + + +hrnet_w18,pass,0 + + + +inception_v3,pass,0 + + + +jx_nest_base,pass,0 + + + +lcnet_050,pass,0 + + + +levit_128,pass,0 + + + 
+mixer_b16_224,pass,0 + + + +mixnet_l,pass,0 + + + +mnasnet_100,pass,0 + + + +mobilenetv2_100,pass,0 + + + +mobilenetv3_large_100,pass,0 + + + +mobilevit_s,pass,0 + + + +nfnet_l0,pass,0 + + + +pit_b_224,pass,0 + + + +pnasnet5large,pass,0 + + + +poolformer_m36,pass,0 + + + +regnety_002,pass,0 + + + +repvgg_a2,pass,0 + + + +res2net101_26w_4s,pass,0 + + + +res2net50_14w_8s,pass,0 + + + +res2next50,pass,0 + + + +resmlp_12_224,pass,0 + + + +resnest101e,pass,0 + + + +rexnet_100,pass,0 + + + +sebotnet33ts_256,pass,0 + + + +selecsls42b,pass,0 + + + +spnasnet_100,pass,0 + + + +swin_base_patch4_window7_224,pass,0 + + + +swsl_resnext101_32x16d,pass,0 + + + +tf_efficientnet_b0,pass,0 + + + +tf_mixnet_l,pass,0 + + + +tinynet_a,pass,0 + + + +tnt_s_patch16_224,pass,0 + + + +twins_pcpvt_base,pass,0 + + + +visformer_small,pass,0 + + + +vit_base_patch16_224,pass,0 + + + +volo_d1_224,pass,0 + + + +xcit_large_24_p8_224,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv new file mode 100644 index 000000000000..ae860db793c9 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_timm_training.csv @@ -0,0 +1,245 @@ +name,accuracy,graph_breaks + + + +adv_inception_v3,pass,6 + + + +beit_base_patch16_224,fail_accuracy,7 + + + +botnet26t_256,pass,6 + + + +cait_m36_384,eager_fail_to_run,0 + + + +coat_lite_mini,pass,6 + + + +convit_base,pass,7 + + + +convmixer_768_32,pass,5 + + + +convnext_base,pass,7 + + + +crossvit_9_240,pass,7 + + + +cspdarknet53,pass,7 + + + +deit_base_distilled_patch16_224,pass,7 + + + +dla102,pass,7 + + + +dm_nfnet_f0,pass,6 + + + +dpn107,pass,6 + + + +eca_botnext26ts_256,pass,7 + + + +eca_halonext26ts,pass,7 + + + +ese_vovnet19b_dw,pass,7 + + + +fbnetc_100,pass,7 + + + +fbnetv3_b,pass,6 + + + +gernet_l,pass,6 + + + +ghostnet_100,pass,6 + + + +gluon_inception_v3,pass,7 + + + +gmixer_24_224,pass,6 + + + +gmlp_s16_224,pass,7 + + + +hrnet_w18,pass,5 + + + +inception_v3,pass,6 + + + +jx_nest_base,pass,7 + + + +lcnet_050,pass,6 + + + +levit_128,pass,7 + + + +mixer_b16_224,pass,7 + + + +mixnet_l,pass,6 + + + +mnasnet_100,pass,7 + + + +mobilenetv2_100,pass,7 + + + +mobilenetv3_large_100,pass,7 + + + +mobilevit_s,pass,6 + + + +nfnet_l0,pass,7 + + + +pit_b_224,pass,6 + + + +pnasnet5large,pass,5 + + + +poolformer_m36,pass,6 + + + +regnety_002,pass,6 + + + +repvgg_a2,pass,7 + + + +res2net101_26w_4s,pass,6 + + + +res2net50_14w_8s,pass,6 + + + +res2next50,pass,6 + + + +resmlp_12_224,pass,6 + + + +resnest101e,pass,6 + + + +rexnet_100,pass,7 + + + +sebotnet33ts_256,pass,6 + + + +selecsls42b,pass,6 + + + +spnasnet_100,pass,7 + + + +swin_base_patch4_window7_224,pass,7 + + + +swsl_resnext101_32x16d,pass,6 + + + +tf_efficientnet_b0,pass,6 + + + +tf_mixnet_l,pass,6 + + + +tinynet_a,pass,6 + + + +tnt_s_patch16_224,pass,7 + + + +twins_pcpvt_base,pass,7 + + + +visformer_small,pass,7 + + + +vit_base_patch16_224,pass,7 + + + +volo_d1_224,pass,7 + + + +xcit_large_24_p8_224,pass_due_to_skip,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv new file mode 100644 index 000000000000..108bc6543aa9 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_inference.csv @@ -0,0 +1,381 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,eager_fail_to_run,0 + + + +BERT_pytorch,pass,0 + + + +Background_Matting,pass_due_to_skip,0 + + + 
+DALLE2_pytorch,pass,12 + + + +LearningToPaint,pass,0 + + + +Super_SloMo,pass,0 + + + +alexnet,pass,0 + + + +basic_gnn_edgecnn,pass,0 + + + +basic_gnn_gcn,pass,6 + + + +basic_gnn_gin,pass,0 + + + +basic_gnn_sage,pass,0 + + + +cm3leon_generate,pass,4 + + + +dcgan,pass,0 + + + +demucs,pass,3 + + + +densenet121,pass,0 + + + +detectron2_fasterrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_dc5,eager_fail_to_run,0 + + + +detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 + + + +detectron2_fcos_r_50_fpn,pass,21 + + + +detectron2_maskrcnn_r_101_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_101_fpn,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +detectron2_maskrcnn_r_50_fpn,eager_fail_to_run,0 + + + +dlrm,pass,0 + + + +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 + + + +drq,pass,0 + + + +fastNLP_Bert,pass,4 + + + +functorch_dp_cifar10,pass,0 + + + +functorch_maml_omniglot,pass,0 + + + +hf_Albert,pass,0 + + + +hf_Bart,pass,0 + + + +hf_Bert,pass,0 + + + +hf_Bert_large,pass,0 + + + +hf_BigBird,fail_accuracy,46 + + + +hf_DistilBert,pass,0 + + + +hf_GPT2,pass,0 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,5 + + + +hf_T5,pass,0 + + + +hf_T5_base,eager_fail_to_run,0 + + + +hf_T5_generate,pass,5 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,0 + + + +hf_distil_whisper,pass,0 + + + +lennard_jones,pass,0 + + + +llama,pass,0 + + + +llama_v2_7b_16h,model_fail_to_load,0 + + + +llava,model_fail_to_load,0 + + + +maml,pass_due_to_skip,0 + + + +maml_omniglot,pass,0 + + + +mnasnet1_0,pass,0 + + + +mobilenet_v2,pass,0 + + + +mobilenet_v2_quantized_qat,model_fail_to_load,0 + + + +mobilenet_v3_large,pass,0 + + + +moco,pass,5 + + + +moondream,model_fail_to_load,0 + + + +nanogpt,pass,0 + + + +nvidia_deeprecommender,pass,0 + + + +opacus_cifar10,pass,0 + + + +phlippe_densenet,pass,0 + + + +phlippe_resnet,pass,0 + + + +pyhpc_equation_of_state,pass,0 + + + +pyhpc_isoneutral_mixing,pass,0 + + + +pyhpc_turbulent_kinetic_energy,pass,0 + + + +pytorch_CycleGAN_and_pix2pix,pass,0 + + + +pytorch_stargan,pass,0 + + + +pytorch_unet,pass,0 + + + +resnet152,pass,0 + + + +resnet18,pass,0 + + + +resnet50,pass,0 + + + +resnet50_quantized_qat,model_fail_to_load,0 + + + +resnext50_32x4d,pass,0 + + + +sam,pass,0 + + + +sam_fast,pass,0 + + + +shufflenet_v2_x1_0,pass,0 + + + +soft_actor_critic,pass,0 + + + +speech_transformer,pass,10 + + + +squeezenet1_1,pass,0 + + + +stable_diffusion_text_encoder,pass,0 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,0 + + + +timm_regnet,pass,0 + + + +timm_resnest,pass,0 + + + +timm_vision_transformer,pass,0 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,0 + + + +torch_multimodal_clip,pass,0 + + + +tts_angular,pass,2 + + + +vgg16,pass,0 + + + +vision_maskrcnn,pass,17 + + + +yolov3,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv new file mode 100644 index 000000000000..cfc524426644 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -0,0 +1,289 @@ +name,accuracy,graph_breaks + + + +torchrec_dlrm,pass,6 + + + +BERT_pytorch,pass,6 + + + +Background_Matting,pass_due_to_skip,0 + + + 
+DALLE2_pytorch,eager_fail_to_run,0 + + + +LearningToPaint,pass,6 + + + +Super_SloMo,pass,7 + + + +alexnet,pass,6 + + + +basic_gnn_edgecnn,pass,22 + + + +basic_gnn_gcn,pass,13 + + + +basic_gnn_gin,pass,7 + + + +basic_gnn_sage,pass,7 + + + +dcgan,pass,6 + + + +demucs,pass,9 + + + +densenet121,pass,6 + + + +detectron2_maskrcnn_r_50_c4,eager_fail_to_run,0 + + + +dlrm,pass,6 + + + +drq,pass,6 + + + +fastNLP_Bert,pass,10 + + + +functorch_dp_cifar10,pass,7 + + + +functorch_maml_omniglot,pass,7 + + + +hf_Albert,pass,6 + + + +hf_Bart,pass,6 + + + +hf_Bert,pass,6 + + + +hf_Bert_large,pass,6 + + + +hf_BigBird,pass,52 + + + +hf_DistilBert,pass,6 + + + +hf_GPT2,pass,6 + + + +hf_GPT2_large,pass_due_to_skip,0 + + + +hf_Reformer,pass,26 + + + +hf_T5_base,eager_2nd_run_OOM,0 + + + +hf_T5_large,pass_due_to_skip,0 + + + +hf_Whisper,pass,6 + + + +hf_distil_whisper,model_fail_to_load,0 + + + +lennard_jones,pass,7 + + + +llava,model_fail_to_load,0 + + + +maml_omniglot,pass,7 + + + +mnasnet1_0,pass,7 + + + +mobilenet_v2,pass,6 + + + +mobilenet_v2_quantized_qat,eager_fail_to_run,0 + + + +mobilenet_v3_large,pass,7 + + + +moco,pass,11 + + + +nanogpt,pass,7 + + + +nvidia_deeprecommender,pass,7 + + + +opacus_cifar10,eager_fail_to_run,0 + + + +phlippe_densenet,pass,6 + + + +phlippe_resnet,pass,6 + + + +pytorch_CycleGAN_and_pix2pix,pass,6 + + + +pytorch_stargan,pass,6 + + + +pytorch_unet,pass_due_to_skip,7 + + + +resnet152,pass,7 + + + +resnet18,pass,6 + + + +resnet50,pass,6 + + + +resnet50_quantized_qat,eager_fail_to_run,0 + + + +resnext50_32x4d,pass,7 + + + +sam,eager_fail_to_run,0 + + + +shufflenet_v2_x1_0,pass,6 + + + +soft_actor_critic,pass,6 + + + +speech_transformer,pass,16 + + + +squeezenet1_1,pass,6 + + + +stable_diffusion_text_encoder,pass,5 + + + +stable_diffusion_unet,pass_due_to_skip,0 + + + +timm_efficientnet,pass,7 + + + +timm_regnet,pass,6 + + + +timm_resnest,pass,7 + + + +timm_vision_transformer,pass,6 + + + +timm_vision_transformer_large,pass_due_to_skip,0 + + + +timm_vovnet,pass,6 + + + +torch_multimodal_clip,pass,7 + + + +tts_angular,pass,9 + + + +vgg16,pass,6 + + + +vision_maskrcnn,pass,34 + + + +yolov3,pass,9 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py new file mode 100644 index 000000000000..5d73cf658c17 --- /dev/null +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/update_expected.py @@ -0,0 +1,172 @@ +""" +Update commited CSV files used as reference points by dynamo/inductor CI. + +Currently only cares about graph breaks, so only saves those columns. + +Hardcodes a list of job names and artifacts per job, but builds the lookup +by querying github sha and finding associated github actions workflow ID and CI jobs, +downloading artifact zips, extracting CSVs and filtering them. 
+ +Usage: + +python benchmarks/dynamo/ci_expected_accuracy.py + +Known limitations: +- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one +""" + +import argparse +import json +import os +import pathlib +import subprocess +import sys +import urllib +from io import BytesIO +from itertools import product +from urllib.request import urlopen +from zipfile import ZipFile + +import pandas as pd +import requests + +# Note: the public query url targets this rockset lambda: +# https://console.rockset.com/lambdas/details/commons.artifacts +ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35" +CSV_LINTER = str( + pathlib.Path(__file__).absolute().parent.parent.parent.parent + / "tools/linter/adapters/no_merge_conflict_csv_linter.py" +) + + +def query_job_sha(repo, sha): + params = { + "parameters": [ + {"name": "sha", "type": "string", "value": sha}, + {"name": "repo", "type": "string", "value": repo}, + ] + } + + r = requests.post(url=ARTIFACTS_QUERY_URL, json=params) + data = r.json() + return data["results"] + + +def parse_job_name(job_str): + return (part.strip() for part in job_str.split("/")) + + +def parse_test_str(test_str): + return (part.strip() for part in test_str[6:].strip(")").split(",")) + + +S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com" + + +def get_artifacts_urls(results, suites): + urls = {} + for r in results: + if ( + r["workflowName"] in ("inductor", "inductor-periodic") + and "test" in r["jobName"] + ): + config_str, test_str = parse_job_name(r["jobName"]) + suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str) + workflowId = r["workflowId"] + id = r["id"] + runAttempt = r["runAttempt"] + + if suite in suites: + artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip" + s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}" + urls[(suite, int(shard_id))] = s3_url + print(f"{suite} {shard_id}, {num_shards}: {s3_url}") + return urls + + +def normalize_suite_filename(suite_name): + strs = suite_name.split("_") + subsuite = strs[-1] + if "timm" in subsuite: + subsuite = subsuite.replace("timm", "timm_models") + + return subsuite + + +def download_artifacts_and_extract_csvs(urls): + dataframes = {} + for (suite, shard), url in urls.items(): + try: + resp = urlopen(url) + subsuite = normalize_suite_filename(suite) + artifact = ZipFile(BytesIO(resp.read())) + for phase in ("training", "inference"): + name = f"test/test-reports/{phase}_{subsuite}.csv" + try: + df = pd.read_csv(artifact.open(name)) + df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int) + prev_df = dataframes.get((suite, phase), None) + dataframes[(suite, phase)] = ( + pd.concat([prev_df, df]) if prev_df is not None else df + ) + except KeyError: + print( + f"Warning: Unable to find {name} in artifacts file from {url}, continuing" + ) + except urllib.error.HTTPError: + print(f"Unable to download {url}, perhaps the CI job isn't finished?") + + return dataframes + + +def write_filtered_csvs(root_path, dataframes): + for (suite, phase), df in dataframes.items(): + out_fn = os.path.join(root_path, f"{suite}_{phase}.csv") + df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"]) + apply_lints(out_fn) + + +def apply_lints(filename): + patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename])) + if patch.get("replacement"): + with open(filename) as fd: + data = 
fd.read().replace(patch["original"], patch["replacement"]) + with open(filename, "w") as fd: + fd.write(data) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument("sha") + args = parser.parse_args() + + repo = "pytorch/pytorch" + + suites = { + f"{a}_{b}" + for a, b in product( + [ + "aot_eager", + "aot_inductor", + "cpu_inductor", + "dynamic_aot_eager", + "dynamic_cpu_inductor", + "dynamic_inductor", + "dynamo_eager", + "inductor", + ], + ["huggingface", "timm", "torchbench"], + ) + } + + root_path = "benchmarks/dynamo/ci_expected_accuracy/" + assert os.path.exists(root_path), f"cd and ensure {root_path} exists" + + results = query_job_sha(repo, args.sha) + urls = get_artifacts_urls(results, suites) + dataframes = download_artifacts_and_extract_csvs(urls) + write_filtered_csvs(root_path, dataframes) + print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.") diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv index 1def1d99bd53..fe7efa082cea 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_timm_training.csv @@ -218,7 +218,7 @@ tf_mixnet_l,pass,6 -tinynet_a,pass,6 +tinynet_a,fail_accuracy,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 431a91d10669..3aecea06b530 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 
+hf_BigBird,pass,43 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 1e1a4be4149e..c87a07a8c294 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,fail_accuracy,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index ce271939b18c..5ffc870a8dec 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -298,4 +298,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv index e5464160d32f..ae860db793c9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_timm_training.csv @@ -6,7 +6,7 @@ adv_inception_v3,pass,6 -beit_base_patch16_224,pass,7 +beit_base_patch16_224,fail_accuracy,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index f652e5ffa91a..c167ea680d2c 100644 --- 
a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,46 +hf_BigBird,fail_accuracy,43 @@ -178,7 +178,7 @@ hf_T5_base,eager_fail_to_run,0 -hf_T5_generate,fail_to_run,5 +hf_T5_generate,pass,5 @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index ee58808c0bb0..c25fa9471337 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 20fb340690ac..9863aa7da6a2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,46 +hf_BigBird,pass,43 @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index cfc524426644..4055eda462c5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 
+hf_BigBird,pass,49 @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index a5e00513153d..08dad9b4a06a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,12 +BartForCausalLM,pass,6 -BartForConditionalGeneration,pass,24 +BartForConditionalGeneration,pass,8 @@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 -BlenderbotSmallForCausalLM,pass,12 +BlenderbotSmallForCausalLM,pass,6 -BlenderbotSmallForConditionalGeneration,pass,24 +BlenderbotSmallForConditionalGeneration,pass,8 @@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,12 +MBartForCausalLM,pass,6 -MBartForConditionalGeneration,pass,24 +MBartForConditionalGeneration,pass,8 @@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 -OPTForCausalLM,pass,12 +OPTForCausalLM,pass,6 -PLBartForCausalLM,pass,12 +PLBartForCausalLM,pass,6 -PLBartForConditionalGeneration,pass,29 +PLBartForConditionalGeneration,pass,8 -PegasusForCausalLM,pass,12 +PegasusForCausalLM,pass,6 -PegasusForConditionalGeneration,pass,23 +PegasusForConditionalGeneration,pass,7 @@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 -Speech2Text2ForCausalLM,pass,12 +Speech2Text2ForCausalLM,pass,6 @@ -170,11 +170,11 @@ T5Small,pass,5 -TrOCRForCausalLM,pass,12 +TrOCRForCausalLM,pass,6 -XGLMForCausalLM,pass,12 +XGLMForCausalLM,pass,6 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv index e5464160d32f..ae860db793c9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv @@ -6,7 +6,7 @@ adv_inception_v3,pass,6 -beit_base_patch16_224,pass,7 +beit_base_patch16_224,fail_accuracy,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 108bc6543aa9..74549205d747 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,46 +hf_BigBird,fail_accuracy,43 @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index cfc524426644..4055eda462c5 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 466e6b30d0b1..154651d4fbb7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -143,6 +143,7 @@ class CI(NamedTuple): "pyhpc_equation_of_state", "pyhpc_turbulent_kinetic_energy", "detectron2_fcos_r_50_fpn", + "hf_T5_generate", } # These models currently fail accuracy with eager Adam optimizer 
@@ -1183,12 +1184,14 @@ def load(cls, model, example_inputs, device): else: _register_dataclass_output_as_pytree(example_outputs) - gm = torch.export._trace._export( + # TODO(angelayi): change this to predispatch + # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing + # to predispatch to avoid performance regressions + gm = torch.export._trace._export_to_torch_ir( model, example_args, example_kwargs, - pre_dispatch=True, - ).module() + ) with torch.no_grad(): so_path = torch._inductor.aot_compile( gm, example_args, example_kwargs @@ -3974,9 +3977,12 @@ def run(runner, args, original_dir=None): assert "cuda" in args.devices, "Quantization requires CUDA device." assert args.bfloat16, "Quantization requires dtype bfloat16." try: - from .torchao_backend import setup_baseline, torchao_optimize_ctx - except ImportError: from torchao_backend import setup_baseline, torchao_optimize_ctx + except ImportError: + from userbenchmark.dynamo.dynamobench.torchao_backend import ( + setup_baseline, + torchao_optimize_ctx, + ) setup_baseline() baseline_ctx = functools.partial( diff --git a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv index f2f8c1b26176..e26d3b97864f 100644 --- a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv +++ b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv @@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316 basic_gnn_gcn,float32,dynamic,default,1.074576405 llama_v2_7b_16h,float32,dynamic,default,1.211740245 resnet50,float32,dynamic,default,1.65984261 -timm_efficientnet,float32,static,cpp,2.271561735 +#timm_efficientnet,float32,static,cpp,2.1938112 mobilenet_v3_large,float32,static,cpp,2.63375628 timm_resnest,float32,dynamic,cpp,1.67998548 pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463 -#hf_GPT2,float32,dynamic,cpp, -hf_GPT2,float32,dynamic,cpp,1.379885175 +#hf_GPT2,float32,dynamic,cpp,1.292704418 resnext50_32x4d,amp,static,default,1.461687045 vgg16,amp,static,default,1.267194285 hf_Longformer,amp,dynamic,default,0.997006035 @@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146 llama,amp,static,default,1.32950568 timm_regnet,amp,static,cpp,1.157188305 lennard_jones,amp,static,cpp,2.240104485 -hf_T5_generate,amp,dynamic,cpp,1.447656135 +#hf_T5_generate,amp,dynamic,cpp,1.29339502 timm_vovnet,amp,dynamic,cpp,1.07856471 mobilenet_v2,amp,dynamic,cpp,2.27774577 diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 5e139783c196..dca2915a07b2 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -7,7 +7,10 @@ import sys import warnings -from common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state +try: + from .common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state import torch diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 54cff5658257..bc42b6566706 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -776,12 +776,18 @@ def extract_df(self, metric, testing): if not perf_row.empty: if acc_row.empty: perf_row[compiler] = 0.0 + elif acc_row[compiler].iloc[0] in ( + "model_fail_to_load", + "eager_fail_to_run", + ): + perf_row = pd.DataFrame() elif acc_row[compiler].iloc[0] not in ( "pass", "pass_due_to_skip", ): 
perf_row[compiler] = 0.0 - perf_rows.append(perf_row) + if not perf_row.empty: + perf_rows.append(perf_row) df = pd.concat(perf_rows) df = df.sort_values(by=list(reversed(self.compilers)), ascending=False) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index db29a9bf365a..60a7cc81c06f 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -7,7 +7,10 @@ import sys import warnings -from common import BenchmarkRunner, download_retry_decorator, main +try: + from .common import BenchmarkRunner, download_retry_decorator, main +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main import torch @@ -71,8 +74,10 @@ def pip_install(package): "hrnet_w18", "inception_v3", "mixer_b16_224", + "mobilenetv3_large_100", "sebotnet33ts_256", "selecsls42b", + "cspdarknet53", } REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = { diff --git a/benchmarks/framework_overhead_benchmark/C2Module.py b/benchmarks/framework_overhead_benchmark/C2Module.py deleted file mode 100644 index 0b93836e5940..000000000000 --- a/benchmarks/framework_overhead_benchmark/C2Module.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np - -from utils import NUM_LOOP_ITERS - -from caffe2.python import core, workspace - -workspace.GlobalInit(["caffe2"]) - - -def add_blob(ws, blob_name, tensor_size): - blob_tensor = np.random.randn(*tensor_size).astype(np.float32) - ws.FeedBlob(blob_name, blob_tensor) - - -class C2SimpleNet: - """ - This module constructs a net with 'op_name' operator. The net consist - a series of such operator. - It initializes the workspace with input blob equal to the number of parameters - needed for the op. - Provides forward method to run the net niter times. - """ - - def __init__(self, op_name, num_inputs=1, debug=False): - self.input_names = [] - self.net = core.Net("framework_benchmark_net") - self.input_names = [f"in_{i}" for i in range(num_inputs)] - for i in range(num_inputs): - add_blob(workspace, self.input_names[i], [1]) - self.net.AddExternalInputs(self.input_names) - op_constructor = getattr(self.net, op_name) - op_constructor(self.input_names) - self.output_name = self.net._net.op[-1].output - print(f"Benchmarking op {op_name}:") - for _ in range(NUM_LOOP_ITERS): - output_name = self.net._net.op[-1].output - self.input_names[-1] = output_name[0] - assert len(self.input_names) == num_inputs - op_constructor(self.input_names) - workspace.CreateNet(self.net) - if debug: - print(self.net._net) - - def forward(self, niters): - workspace.RunNet(self.net, niters, False) diff --git a/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py b/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py index 8d1b52738522..826c4d283ee8 100644 --- a/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py +++ b/benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py @@ -1,6 +1,5 @@ import argparse -from C2Module import C2SimpleNet from pt_wrapper_module import WrapperModule from SimpleAddModule import add_tensors_loop, SimpleAddModule @@ -19,9 +18,6 @@ --add-op --graph-mode --eager-mode (Runs both graph mode and eager mode) buck run @mode/opt :framework_overhead_benchmark -- --add-op --graph-mode (Runs only graph mode) -To run C2 benchmark: -buck run @mode/opt :framework_overhead_benchmark -- - --add-op --benchmark-c2-net """ SUPPORTED_OPS = {"add_op"} @@ -49,39 +45,22 @@ def benchmark_simple_fn(args, config, module_config, module_type, result): module_type: Type of the 
module to be wrapped. e.g. SimpleAddModule for add op. result: dictionary instance to be populated with the benchmark result (latency per iter). """ - benchmark_c2_net = args.benchmark_c2_net print(f"Benchmarking {module_type.__name__}") - if benchmark_c2_net: - op_name = module_config.c2_op - num_inputs = module_config.num_params - module = C2SimpleNet(op_name, num_inputs=num_inputs, debug=args.debug) - latency_per_iter_ms = benchmark_module(config, module) - result[op_name] = latency_per_iter_ms - else: - f_name = ( - module_config.pt_fn.__name__ - + ":Num Operands=" - + str(module_config.num_params) - ) - graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode) - result_key = ",".join((f_name, graph_mode_str)) - module = WrapperModule(module_type, module_config, args.debug, args.save) - latency_per_iter_ms = benchmark_module( - config, module, args.use_throughput_benchmark - ) - result[result_key] = latency_per_iter_ms + f_name = ( + module_config.pt_fn.__name__ + ":Num Operands=" + str(module_config.num_params) + ) + graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode) + result_key = ",".join((f_name, graph_mode_str)) + module = WrapperModule(module_type, module_config, args.debug, args.save) + latency_per_iter_ms = benchmark_module( + config, module, args.use_throughput_benchmark + ) + result[result_key] = latency_per_iter_ms def main(): parser = argparse.ArgumentParser() parser.add_argument("--op", default="add_op", dest="op", type=str) - parser.add_argument( - "--benchmark-c2-net", - "--benchmark_c2_net", - default=False, - dest="benchmark_c2_net", - action="store_true", - ) parser.add_argument( "--use-throughput-benchmark", "--use_throughput_benchmark", @@ -107,10 +86,6 @@ def main(): if args.op not in SUPPORTED_OPS: print(f"Op {args.op} is not supported: Supported ops are:{SUPPORTED_OPS}") return - assert not ( - args.benchmark_c2_net and args.use_throughput_benchmark - ), "Benchmarking of C2 net via throughput benchmarking is not yet supported" - num_warmup_iters = args.num_warmup_iters num_iters = args.num_iters config = BenchmarkConfig(num_warmup_iters, num_iters) @@ -120,10 +95,7 @@ def main(): result = {} if args.op == "add_op": num_params = 2 - if args.benchmark_c2_net: - module_config = ModuleConfig(None, "Sum", num_params, None) - else: - module_config = ModuleConfig(add_tensors_loop, None, num_params, graph_mode) + module_config = ModuleConfig(add_tensors_loop, None, num_params, graph_mode) benchmark_simple_fn(args, config, module_config, SimpleAddModule, result) print_results(result) diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index e63db1d7cc02..04dc8969e329 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -219,7 +219,7 @@ def forward(self, x, output_lengths): class Lookahead(nn.Module): - # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks + # Wang et al., 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks # input shape - sequence, batch, feature - TxNxH # output shape - same as input def __init__(self, n_features, context): diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 083c98e4a92b..6e335ee31292 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -1,257 +1,74 @@ import argparse import csv import dataclasses 
-import itertools import os import time -from typing import Optional, Tuple -from mixtral_moe_model import Transformer as MixtralMoE -from mixtral_moe_quantize import ( - WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, -) -from model import Transformer as LLaMA -from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler +from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8 import torch -import torch._inductor.config - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future -torch._inductor.config.assert_indirect_indexing = False +import torch.nn as nn @dataclasses.dataclass class Experiment: name: str - module: type - mode: Optional[str] - quantizer: type - token_per_sec: float - memory_bandwidth: float - - -# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. -all_experiments = { - "llama-7b-fp16": Experiment( - "Llama-2-7b-chat-hf", - LLaMA, - "bfloat16", - LLaMAWeightOnlyInt8QuantHandler, - 94, - 1253, - ), - "llama-7b-int8": Experiment( - "Llama-2-7b-chat-hf", - LLaMA, - "int8", - LLaMAWeightOnlyInt8QuantHandler, - 144, - 957, - ), - "mixtral-int8": Experiment( # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation. - "Mixtral-8x7B-v0.1", - MixtralMoE, - "int8", - MixtralMoEWeightOnlyInt8QuantHandler, - 175, - 4129, - ), -} - -DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv" - - -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif "cpu" in device: - pass - else: - print(f"device={device} is not yet suppported") - - -def multinomial_sample_one_no_sync( - probs_sort, -): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - - -@torch.compile(fullgraph=True) -def prefill( - model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - - -@torch.compile(fullgraph=True, mode="reduce-overhead") -def decode_one_token( - model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - - -def decode_n_tokens( - model: torch.nn.Module, - cur_token: torch.Tensor, - input_pos: torch.Tensor, - num_new_tokens: int, - **sampling_kwargs, -): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.nn.attention.sdpa_kernel( - 
torch.nn.attention.SDPBackend.MATH - ): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) - - return new_tokens, new_probs + metric: str + target: float + actual: float -@torch.no_grad() -def generate( - model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs -) -> torch.Tensor: - device, dtype = prompt.device, prompt.dtype - T = prompt.size(0) - T_new = T + max_new_tokens - max_seq_length = min(T_new, model.config.block_size) - - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token +def do_inference(mod, x, num_samples: int = 5): + total_time = 0 + start = -1 - input_pos = torch.tensor([T], device=device, dtype=torch.int) + for i in range(start, num_samples): + torch.cuda.synchronize("cuda") - generated_tokens, _ = decode_n_tokens( - model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs - ) - seq[T + 1 :] = torch.cat(generated_tokens) - return seq + t0 = time.perf_counter() + mod(x) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue -def _load_model(x: Experiment, device="cuda", precision=torch.bfloat16): - with torch.device("meta"): - model = x.module.from_name(x.name) - model = model.to(dtype=precision) + torch.cuda.synchronize("cuda") + total_time += time.perf_counter() - t0 - if x.mode == "int8": - print("Using int8 weight-only quantization!") - model = x.quantizer(model).convert_for_runtime() + total_time = total_time / num_samples - state_dict = model.state_dict() - for k, v in state_dict.items(): - state_dict[k] = torch.nn.Parameter( - torch.randn(v.shape, device=device).to(dtype=v.dtype), - requires_grad=v.requires_grad, - ) - model.load_state_dict(state_dict, assign=True) - return model.eval() + return total_time -def _get_model_size(model): - model_size = 0 - for name, child in model.named_children(): - if not isinstance(child, torch.nn.Embedding): - model_size += sum( +def run_multi_layer_norm(): + class MultiLayerNorm(nn.Module): + def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True): + super().__init__() + self.num_layers = num_layers + self.norm_layers = nn.ModuleList( [ - p.numel() * p.dtype.itemsize - for p in itertools.chain(child.parameters(), child.buffers()) + nn.LayerNorm(normalized_shape, eps=eps, bias=bias) + for _ in range(num_layers) ] ) - return model_size - - -def run_experiment( - x: Experiment, - num_samples: int = 5, - max_new_tokens: int = 200, - top_k: int = 200, - temperature: float = 0.8, -) -> None: - device = "cuda" - print(f"Loading model {x.name}") - t0 = time.time() - model = _load_model(x) - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - prompt = torch.tensor( - [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32 - ) - prompt_length = prompt.size(0) - torch.manual_seed(1234) - model_size = _get_model_size(model) + def forward(self, x): + for layer_norm in 
self.norm_layers: + x = layer_norm(x) + return x - aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} - start = -1 + mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda") + mod = torch.compile(mod) + input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda") + inference_time = do_inference(mod, input) - for i in range(start, num_samples): - device_sync(device=device) # MKG + memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9 - t0 = time.perf_counter() - y = generate( - model, prompt, max_new_tokens, temperature=temperature, top_k=top_k + return [ + Experiment( + "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}" ) - - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - - device_sync(device=device) # MKG - t = time.perf_counter() - t0 - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9) - - token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() - memory_bandwidth = torch.mean( - torch.tensor(aggregate_metrics["memory_bandwidth"]) - ).item() - print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") - print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") - print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - return token_per_sec, memory_bandwidth + ] def output_csv(output_file, headers, row): @@ -275,41 +92,27 @@ def output_csv(output_file, headers, row): writer.writerow(list(line) + ["0"] * (len(headers) - len(line))) -def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): +DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv" + +all_experiments = { + # A list of GPT models: LlaMa, Mixtral, etc. + run_llama2_7b_bf16, + run_llama2_7b_int8, + run_mixtral_8x7b_int8, + # A list of micro-benchmarks. 
+ run_multi_layer_norm, +} + + +def main(output_file=DEFAULT_OUTPUT_FILE): results = [] - if experiments is None: - experiments = all_experiments - else: - experiments = {k: v for k, v in all_experiments.items() if k in experiments} - - for x in experiments.values(): - actual_token_per_sec, actual_memory_bandwidth = run_experiment(x) - token_per_sec_pct = f"{actual_token_per_sec / x.token_per_sec * 100:.2f}%" - bandwidth_pct = f"{actual_memory_bandwidth / x.memory_bandwidth * 100:.2f}%" - results.append( - ( - x.name, - x.mode, - x.token_per_sec, - f"{actual_token_per_sec:.2f}", - token_per_sec_pct, - x.memory_bandwidth, - f"{actual_memory_bandwidth:.2f}", - bandwidth_pct, - ) - ) + for func in all_experiments: + lst = func() + for x in lst: + results.append(dataclasses.astuple(x)) - headers = [ - "name", - "mode", - "token_per_sec[target]", - "token_per_sec[actual]", - "token_per_sec[pct]", - "memory_bandwidth[target]", - "memory_bandwidth[actual]", - "memory_bandwidth[pct]", - ] + headers = [field.name for field in dataclasses.fields(Experiment)] for row in results: output_csv(output_file, headers, row) @@ -317,12 +120,6 @@ def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run experiments.") - parser.add_argument( - "--experiments", - nargs="*", - default=None, - help="Experiment names to run (default: all)", - ) parser.add_argument( "--output", default=DEFAULT_OUTPUT_FILE, @@ -330,4 +127,4 @@ def main(experiments=None, output_file=DEFAULT_OUTPUT_FILE): ) args = parser.parse_args() - main(experiments=args.experiments, output_file=args.output) + main(output_file=args.output) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py new file mode 100644 index 000000000000..a4e4b06c79d7 --- /dev/null +++ b/benchmarks/gpt_fast/generate.py @@ -0,0 +1,308 @@ +import dataclasses +import itertools +import time +from typing import Optional, Tuple + +from mixtral_moe_model import Transformer as MixtralMoE +from mixtral_moe_quantize import ( + WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, +) +from model import Transformer as LLaMA +from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler + +import torch +import torch._inductor.config + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._inductor.config.assert_indirect_indexing = False + + +@dataclasses.dataclass +class GPTModelConfig: + name: str + module: type + mode: Optional[str] + quantizer: type + token_per_sec: float + memory_bandwidth: float + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet suppported") + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = 
torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +@torch.compile(fullgraph=True) +def prefill( + model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +@torch.compile(fullgraph=True, mode="reduce-overhead") +def decode_one_token( + model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: torch.nn.Module, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +@torch.no_grad() +def generate( + model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs +) -> torch.Tensor: + device, dtype = prompt.device, prompt.dtype + T = prompt.size(0) + T_new = T + max_new_tokens + max_seq_length = min(T_new, model.config.block_size) + + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + + generated_tokens, _ = decode_n_tokens( + model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs + ) + seq[T + 1 :] = torch.cat(generated_tokens) + return seq + + +def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16): + with torch.device("meta"): + model = x.module.from_name(x.name) + model = model.to(dtype=precision) + + if x.mode == "int8": + print("Using int8 weight-only quantization!") + model = x.quantizer(model).convert_for_runtime() + + state_dict = model.state_dict() + for k, v in state_dict.items(): + state_dict[k] = torch.nn.Parameter( + torch.randn(v.shape, device=device).to(dtype=v.dtype), + requires_grad=v.requires_grad, + ) + model.load_state_dict(state_dict, assign=True) + return model.eval() + + +def _get_model_size(model): + model_size = 0 + for name, child in model.named_children(): + if not isinstance(child, torch.nn.Embedding): + model_size += sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(child.parameters(), child.buffers()) + ] + ) + return model_size + + +def run_experiment( + x: GPTModelConfig, + num_samples: int = 5, + max_new_tokens: int = 200, + top_k: int = 200, + temperature: float = 0.8, +) -> None: + device = "cuda" + print(f"Loading model 
{x.name}") + t0 = time.time() + model = _load_model(x) + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + prompt = torch.tensor( + [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32 + ) + prompt_length = prompt.size(0) + + torch.manual_seed(1234) + model_size = _get_model_size(model) + + aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} + start = -1 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + + t0 = time.perf_counter() + y = generate( + model, prompt, max_new_tokens, temperature=temperature, top_k=top_k + ) + + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9) + + token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + memory_bandwidth = torch.mean( + torch.tensor(aggregate_metrics["memory_bandwidth"]) + ).item() + print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") + print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + return token_per_sec, memory_bandwidth + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_llama2_7b_bf16(): + from benchmark import Experiment + + model = GPTModelConfig( + "Llama-2-7b-chat-hf", + LLaMA, + "bfloat16", + LLaMAWeightOnlyInt8QuantHandler, + 94, + 1253, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "llama2_7b_bf16", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "llama2_7b_bf16", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_llama2_7b_int8(): + from benchmark import Experiment + + model = GPTModelConfig( + "Llama-2-7b-chat-hf", + LLaMA, + "int8", + LLaMAWeightOnlyInt8QuantHandler, + 144, + 957, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "llama2_7b_int8", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "llama2_7b_int8", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] + + +# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. +def run_mixtral_8x7b_int8(): + from benchmark import Experiment + + # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation. 
+ model = GPTModelConfig( + "Mixtral-8x7B-v0.1", + MixtralMoE, + "int8", + MixtralMoEWeightOnlyInt8QuantHandler, + 175, + 4129, + ) + token_per_sec, memory_bandwidth = run_experiment(model) + return [ + Experiment( + "mixtral_8x7b_int8", + "token_per_sec", + model.token_per_sec, + f"{token_per_sec:.02f}", + ), + Experiment( + "mixtral_8x7b_int8", + "memory_bandwidth(GB/s)", + model.memory_bandwidth, + f"{memory_bandwidth:.02f}", + ), + ] diff --git a/benchmarks/gpt_fast/micro_benchmark.py b/benchmarks/gpt_fast/micro_benchmark.py deleted file mode 100644 index 3c8f0865a244..000000000000 --- a/benchmarks/gpt_fast/micro_benchmark.py +++ /dev/null @@ -1,103 +0,0 @@ -import argparse -import dataclasses -import time - -import torch -import torch.nn as nn - - -@dataclasses.dataclass -class Experiment: - name: str - metric: str - target: float - actual: float - - -DEFAULT_OUTPUT_FILE = "micro_benchmark.csv" - - -def do_inference(mod, x, num_samples: int = 5): - total_time = 0 - start = -1 - - for i in range(start, num_samples): - torch.cuda.synchronize("cuda") - - t0 = time.perf_counter() - mod(x) - - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - - torch.cuda.synchronize("cuda") - total_time += time.perf_counter() - t0 - - total_time = total_time / num_samples - - return total_time - - -def run_multi_layernorm(): - class MultiLayerNorm(nn.Module): - def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True): - super().__init__() - self.num_layers = num_layers - self.norm_layers = nn.ModuleList( - [ - nn.LayerNorm(normalized_shape, eps=eps, bias=bias) - for _ in range(num_layers) - ] - ) - - def forward(self, x): - for layer_norm in self.norm_layers: - x = layer_norm(x) - return x - - mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda") - mod = torch.compile(mod) - input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda") - inference_time = do_inference(mod, input) - - memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9 - - return [ - Experiment( - "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}" - ) - ] - - -all_experiments = { - run_multi_layernorm, -} - - -def main(output_file=DEFAULT_OUTPUT_FILE): - results = [] - - for func in all_experiments: - lst = func() - for x in lst: - results.append(dataclasses.astuple(x)) - - headers = [field.name for field in dataclasses.fields(Experiment)] - - from benchmark import output_csv - - for row in results: - output_csv(output_file, headers, row) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run experiments.") - parser.add_argument( - "--output", - default=DEFAULT_OUTPUT_FILE, - help="Set the output CSV file to save the benchmark results", - ) - args = parser.parse_args() - - main(output_file=args.output) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 549bb137a9d3..9bcfc5d03e19 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -1,6 +1,6 @@ -# PyTorch/Caffe2 Operator Micro-benchmarks +# PyTorch Operator Micro-benchmarks -This benchmark suite provides a systemic way to measure the performance of operators for a wide range of inputs. The generated benchmark data fully characterized the performance of an operator in terms of execution time and the efficiency of the PyTorch/Caffe2 frameworks used. 
+This benchmark suite provides a systemic way to measure the performance of operators for a wide range of inputs. The generated benchmark data fully characterized the performance of an operator in terms of execution time and the efficiency of the PyTorch frameworks used. ## Features @@ -8,7 +8,7 @@ Key Features: 1\. Language used: Python -2\. Supported Frameworks: PyTorch and Caffe2 +2\. Supported Frameworks: PyTorch 3\. Supported PyTorch mode: eager and JIT @@ -49,7 +49,7 @@ python -m benchmark_all_test ``` ## Code to support `torch.add` in the benchmark -The following example shows the code to support `torch.add` with 27 different tests. In the subpages of this wiki, we'll step through the complete flow of adding PyTorch and Caffe2 operators to the benchmark suite. Existing benchmarks for operators are in `pt` and `c2` directories and we highly recommend putting your new operators in those locations. +The following example shows the code to support `torch.add` with 27 different tests. In the subpages of this wiki, we'll step through the complete flow of adding PyTorch operators to the benchmark suite. Existing benchmarks for operators are in the `pt` directory and we highly recommend putting your new operators in those locations. ```python add_short_configs = op_bench.cross_product_configs( @@ -77,7 +77,7 @@ op_bench.generate_pt_test(add_short_configs, AddBenchmark) The output is intended to be a human readable format. Here is an example output for `torch.add`: ``` # ---------------------------------------- -# PyTorch/Caffe2 Operator Micro-benchmarks +# PyTorch Operator Micro-benchmarks # ---------------------------------------- # Tag : short @@ -146,7 +146,7 @@ python -m pt.add_test --tag-filter long ``` ## Adding New Operators to the Benchmark Suite -In the previous sections, we gave several examples to show how to run the already available operators in the benchmark suite. In the following sections, we'll step through the complete flow of adding PyTorch and Caffe2 operators to the benchmark suite. Existing benchmarks for operators are in `pt` and `c2` directories and we highly recommend putting your new operators in those directories as well. +In the previous sections, we gave several examples to show how to run the already available operators in the benchmark suite. In the following sections, we'll step through the complete flow of adding PyTorch operators to the benchmark suite. Existing benchmarks for operators are in the `pt` directory and we highly recommend putting your new operators in those directories as well. ### Add a New PyTorch Operator Let's say you want to measure the execution time of the following operator: @@ -260,55 +260,6 @@ if __name__ == "__main__": ``` That's it. You just added a new operator to the benchmark suite! - -### Add a New Caffe2 Operator -The steps to add a new Caffe2 operator is the same as that for a PyTorch operator. 
The code below shows how to add Caffe2 `Add` operator: -```python -import operator_benchmark as op_bench -from caffe2.python import core - -add_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2 ** x for x in range(0, 3)], - tags=["long"] -) - -add_short_configs = op_bench.config_list( - attrs=[ - [8, 16, 32], - [16, 16, 64], - [64, 64, 128], - ], - attr_names=["M", "N", "K"], - tags=["short"], -) - -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - - def init(self, M, N, K): - self.input_one = self.tensor(M, N, K) - self.input_two = self.tensor(M, N, K) - self.output = self.tensor(M, N, K) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - - return op - -op_bench.generate_c2_test(add_long_configs + add_short_configs, AddBenchmark) - -if __name__ == "__main__": - op_bench.benchmark_runner.main() -``` -There are two things worth mentioning in this code: -* `self.tensor` is a helper function which takes shapes and returns a Caffe2 blob. It is designed to make the tensor creation step easier compared to the standard Caffe2 way. -* `generate_c2_test` is used to register Caffe2 tests with the benchmark. - - ### Add a List of Operators In the previous sections, we introduced the steps required to add a single operator to the benchmark suite. There are scenarios where you want to extend the benchmark suite with a list of operators which can share the same inputs. For example, to benchmark `abs` and `acos` operators, you can use the same set of inputs for both. @@ -416,37 +367,3 @@ The example below shows the relevant code for that: self.input_one = torch.rand(M, N, K, requires_grad=True) generate_pt_gradient_test(long_configs + short_configs, TorchAddBenchmark) ``` -#### For Caffe2 Gradient Ops -To add Caffe2 gradient ops, we need to implement a new backward method in the benchmark class: -```python -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - - def init(self, M, N, K): - self.input_one = self.tensor(M, N, K) - self.input_two = self.tensor(M, N, K) - self.input_one_grad = self.tensor(M, N, K) - self.input_two_grad = self.tensor(M, N, K) - self.output = self.tensor(M, N, K) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - - return op - - def backward(self): - grad_op = core.CreateOperator( - "AddGradient", - [self.output, self.input_one, self.input_two], - [self.input_one_grad, self.input_two_grad], **self.args - ) - - return grad_op - -op_bench.generate_c2_gradient_test(long_configs + short_configs,AddBenchmark) -``` -After the class is implemented, we need to register the tests with `generate_c2_gradient_test` function. - -This concludes the overview of the operator benchmark suite. diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py deleted file mode 100644 index 2d238e593fc9..000000000000 --- a/benchmarks/operator_benchmark/benchmark_caffe2.py +++ /dev/null @@ -1,202 +0,0 @@ -from collections import namedtuple - -import benchmark_utils -from benchmark_test_generator import _register_test - -from caffe2.proto import caffe2_pb2 -from caffe2.python import core, workspace - -from .benchmark_core import TestConfig - -"""Caffe2 performance microbenchmarks. - -This module contains Caffe2-specific functionalities for performance -microbenchmarks. 
-""" - - -class Caffe2BenchmarkBase: - """This is a base class used to create Caffe2 operator benchmark""" - - tensor_index = 0 - test_index = 0 - - def __init__(self): - self.args = {} - self.user_provided_name = None - self._num_inputs_require_grads = 0 - self._pass_count = 0 - - def _set_backward_test(self, is_backward): - pass - - def _device_option(self, device): - """This method is used to set device option.""" - if device not in ["cuda", "cpu"]: - raise ValueError("Missing attrs in configs") - - if "cuda" in device: - self.dev = core.DeviceOption(caffe2_pb2.CUDA, 0) - else: - self.dev = core.DeviceOption(caffe2_pb2.CPU) - return self.dev - - def tensor(self, shapes, dtype="float32", device="cpu"): - """A wapper function to create C2 tensor filled with random data. - The name/label of the tensor is returned and it is available - throughout the benchmark execution phase. - Args: - shapes: int or a sequence of ints to defining the shapes of the tensor - dtype: use the dtypes from numpy - (https://docs.scipy.org/doc/numpy/user/basics.types.html) - Return: - C2 tensor of dtype - """ - return self.feed_tensor(benchmark_utils.numpy_random(dtype, *shapes), device) - - def feed_tensor(self, tensor, device="cpu"): - """Similar to tensor, but can supply any data compatible with FeedBlob""" - blob_name = "blob_" + str(Caffe2BenchmarkBase.tensor_index) - dev = self._device_option(device) - with core.DeviceScope(dev): - workspace.FeedBlob(blob_name, tensor) - Caffe2BenchmarkBase.tensor_index += 1 - return blob_name - - def module_name(self): - """this is used to label the operator being benchmarked""" - if self.user_provided_name: - return self.user_provided_name - return self.__class__.__name__ - - def set_module_name(self, name): - self.user_provided_name = name - - def _value_to_str(self, value): - """if value is bool, we will convert it to 0 and 1""" - ret = value - if type(value) == bool: - ret = int(value) - return str(ret) - - def test_name(self, name_type="long", **kargs): - """this is a globally unique name which can be used to - label a specific test - """ - if name_type == "long": - test_name_str = [] - for key in kargs: - value = kargs[key] - test_name_str.append(key + self._value_to_str(value)) - name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "") - elif name_type == "short": - # this is used to generate test name based on unique index - name = "_".join( - [self.module_name(), "test", str(Caffe2BenchmarkBase.test_index)] - ) - Caffe2BenchmarkBase.test_index += 1 - return name - - def extract_inputs_tuple(self): - # add a dummy function here to match the interface of TorchBenchmarkBase - pass - - -class Caffe2OperatorTestCase: - """This class includes all the information needed to benchmark an operator. - op_bench: it's a user-defined class (child of Caffe2BenchmarkBase) - which includes input and operator, .etc - test_config: a namedtuple includes test_name, input_shape, tag, run_backward. - When run_backward is false, the run_forward method will be executed, otherwise - run_backward method will be executed. 
- """ - - def __init__(self, op_bench, test_config): - self.op_bench = op_bench - self.test_config = test_config - self.framework = "Caffe2" - - def run_forward(self, num_runs, print_per_iter=False, cuda_sync=False): - """Run the forward path of an operator in a loop""" - with core.DeviceScope(self.op_bench.dev): - op = self.op_bench.forward() - if not workspace.RunOperatorMultiple(op, num_runs): - raise ValueError(f"Unable to run operator test case: {self.test_name}") - - def run_backward(self, num_runs, print_per_iter=False): - """Run the backward path of an operator in a loop""" - with core.DeviceScope(self.op_bench.dev): - op = self.op_bench.backward() - if not workspace.RunOperatorMultiple(op, num_runs): - raise ValueError( - f"Unable to run operator gradient test case: {self.test_name}" - ) - - def _print_per_iter(self): - pass - - -def create_caffe2_op_test_case(op_bench, test_config): - test_case = Caffe2OperatorTestCase(op_bench, test_config) - test_config = test_case.test_config - op = test_case.op_bench - func_name = f"{op.module_name()}{test_case.framework}{str(test_config)}" - return (func_name, test_case) - - -OpMeta = namedtuple( - "OpMeta", - "op_type num_inputs input_dims input_types \ - output_dims num_outputs args device", -) - - -def generate_c2_test_from_ops(ops_metadata, bench_op, tags): - """ - This function is used to generate Caffe2 tests based on the metadata - of operators. The metadata includes seven fields which are 1) op_type: - the name of the operator. 2) num_inputs: the number of input blobs. - 3) input_dims: a dictionary which includes the shapes of the input blobs. - 4) input_types: a list which includes the types of input blobs. 5) - output_dims: a dictionary which includes the shapes of output blobs. - 6) num_oupts: the number of output blobs. 7) args: a dictionary which - includes the args for th operator. 
- Here is an example to show the metadata for the WeighedSum operator - op_type : WeightedSum - num_inputs: 4 - input_dims: {'0': [256], '1': [1], '2': [256], '3': [1]} - input_types: ['float', 'float', 'float', 'float'] - output_dims: {'0': [256]} - num_outputs: 4 - args: {} - TODO(mingzhe0908): introduce device and add it to the benchmark name - """ - for op_metadata in ops_metadata: - tmp_attrs = OpMeta( - op_metadata.op_type, - op_metadata.num_inputs, - op_metadata.input_dims, - op_metadata.input_types, - op_metadata.output_dims, - op_metadata.num_outputs, - op_metadata.args, - op_metadata.device, - ) - test_attrs = tmp_attrs._asdict() - op = bench_op() - op.init(**test_attrs) - test_name = op.test_name("short") - input_config = f"Shapes: {op_metadata.input_dims}, Type: {op_metadata.input_types}, Args: {str(op_metadata.args)}" - test_config = TestConfig(test_name, input_config, tags, run_backward=False) - if op is not None: - create_caffe2_op_test_case(op, test_config) - - -def generate_c2_test(configs, c2_bench_op): - """This function creates Caffe2 op test based on the given operator""" - return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, False) - - -def generate_c2_gradient_test(configs, c2_bench_op): - """This function creates Caffe2 op test based on the given operator""" - return _register_test(configs, c2_bench_op, create_caffe2_op_test_case, True) diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index c315382d1538..239dddbf7231 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -13,6 +13,7 @@ # needs to be imported after torch import torch.utils.cpp_extension as cpp_extension # noqa: F401 + """Performance microbenchmarks. This module contains core functionalities for performance microbenchmark tests. @@ -50,7 +51,7 @@ def _create_test( """Create tests with the benchmark backend. Args: bench_op_obj: an object which instantiated from a subclass of - Caffe2BenchmarkBase/TorchBenchmarkBase which includes tensor + TorchBenchmarkBase which includes tensor creation and operator execution. orig_test_attrs: a dictionary includes test configs. tags: a attribute in test config to filter inputs @@ -75,7 +76,7 @@ def _build_test( """Generate PyTorch/Caffe2 tests of operators with different inputs. 
Args: configs: a dictionary that has the input shapes - bench_op: a subclass of Caffe2BenchmarkBase/TorchBenchmarkBase which includes tensor + bench_op: a subclass of TorchBenchmarkBase which includes tensor creation and operator execution OperatorTestCase: a named tuple to save the metadata of an test run_backward: a bool parameter indicating backward path @@ -233,9 +234,7 @@ def _print_perf_result(self, reported_run_time_us, test_case): ) ) else: - if test_case.framework == "PyTorch": - print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}") - + print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}") print( f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}" ) @@ -283,8 +282,7 @@ def _launch_backward(self, test_case, iters, print_per_iter=False): and the execution time is reported """ test_case.run_forward(num_runs=1, print_per_iter=False, cuda_sync=False) - if test_case.framework == "PyTorch": - test_case._output_mean() + test_case._output_mean() backward_time = timeit.timeit( functools.partial(test_case.run_backward, iters, print_per_iter), number=1 ) @@ -357,9 +355,6 @@ def _keep_test(self, test_case): # Currently, this is a sub-string matching. op_test_config = test_case.test_config - if self.args.framework: - frameworks = benchmark_utils.process_arg_list(self.args.framework) - operators = ( benchmark_utils.process_arg_list(self.args.operators) if self.args.operators @@ -370,7 +365,6 @@ def _keep_test(self, test_case): if ( self._check_keep(op_test_config.test_name, self.args.test_name) and self._check_keep_list(test_case.op_bench.module_name(), operators) - and self._check_keep_list(test_case.framework, frameworks) and self._check_operator_first_char( test_case.op_bench.module_name(), self.operator_range ) diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 7bb18f7d7708..6abbc566820b 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -92,7 +92,7 @@ def parse_args(): parser.add_argument( "--omp-num-threads", "--omp_num_threads", - help="Number of OpenMP threads used in PyTorch/Caffe2 runtime", + help="Number of OpenMP threads used in PyTorch runtime", default=None, type=int, ) @@ -100,7 +100,7 @@ def parse_args(): parser.add_argument( "--mkl-num-threads", "--mkl_num_threads", - help="Number of MKL threads used in PyTorch/Caffe2 runtime", + help="Number of MKL threads used in PyTorch runtime", default=None, type=int, ) @@ -135,12 +135,6 @@ def parse_args(): help="Only run the forward path of operators", ) - parser.add_argument( - "--framework", - help="Comma-delimited list of frameworks to test (Caffe2, PyTorch)", - default="Caffe2,PyTorch", - ) - parser.add_argument( "--device", help="Run tests on the provided architecture (cpu, cuda)", @@ -160,8 +154,7 @@ def parse_args(): # "Modifications to the environment variables after the program has started, # even if modified by the program itself, are ignored by the OpenMP implementation" benchmark_utils.set_omp_threads(args.omp_num_threads) - if benchmark_utils.is_pytorch_enabled(args.framework): - torch.set_num_threads(args.omp_num_threads) + torch.set_num_threads(args.omp_num_threads) if args.mkl_num_threads: benchmark_utils.set_mkl_threads(args.mkl_num_threads) diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py index d7e45b7c1685..be9c62cb3c28 100644 --- 
a/benchmarks/operator_benchmark/benchmark_utils.py +++ b/benchmarks/operator_benchmark/benchmark_utils.py @@ -319,14 +319,6 @@ def op_list(**configs): return generated_configs -def is_caffe2_enabled(framework_arg): - return "Caffe2" in framework_arg - - -def is_pytorch_enabled(framework_arg): - return "PyTorch" in framework_arg - - def get_operator_range(chars_range): """Generates the characters from chars_range inclusive.""" if chars_range == "None" or chars_range is None: diff --git a/benchmarks/operator_benchmark/c2/add_test.py b/benchmarks/operator_benchmark/c2/add_test.py deleted file mode 100644 index c3b71f3e8514..000000000000 --- a/benchmarks/operator_benchmark/c2/add_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch.""" - -# Configs for C2 add operator -add_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2**x for x in range(0, 3)], - dtype=["int", "float"], - tags=["long"], -) - - -add_short_configs = op_bench.config_list( - attrs=[ - [8, 16, 32, "int"], - [16, 16, 64, "float"], - [64, 64, 128, "int"], - ], - attr_names=["M", "N", "K", "dtype"], - tags=["short"], -) - - -class AddBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, dtype): - self.input_one = self.tensor([M, N, K], dtype) - self.input_two = self.tensor([M, N, K], dtype) - self.output = self.tensor([M, N, K], dtype) - self.set_module_name("add") - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - -op_bench_c2.generate_c2_test(add_long_configs + add_short_configs, AddBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/batch_box_cox_test.py b/benchmarks/operator_benchmark/c2/batch_box_cox_test.py deleted file mode 100644 index 7c40f513cd6e..000000000000 --- a/benchmarks/operator_benchmark/c2/batch_box_cox_test.py +++ /dev/null @@ -1,49 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for BatchBoxCox operator.""" - -# Configs for C2 BatchBoxCox operator -batch_box_cox_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -batch_box_cox_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class BatchBoxCoxBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.data = self.tensor([M, N], dtype) - self.lambda1 = self.tensor([N], dtype) - self.lambda2 = self.tensor([N], dtype) - self.output = self.tensor([1, 1], dtype) - self.set_module_name("batch_box_cox") - - def forward(self): - op = core.CreateOperator( - "BatchBoxCox", [self.data, self.lambda1, self.lambda2], self.output - ) - return op - - -op_bench_c2.generate_c2_test( - batch_box_cox_long_configs + batch_box_cox_short_configs, BatchBoxCoxBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/batch_gather_test.py 
b/benchmarks/operator_benchmark/c2/batch_gather_test.py deleted file mode 100644 index c0ff2c06f061..000000000000 --- a/benchmarks/operator_benchmark/c2/batch_gather_test.py +++ /dev/null @@ -1,58 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -import numpy -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise BatchGather operator.""" - -# Configs for C2 BatherGather operator -batch_gather_configs_short = op_bench.config_list( - attr_names=["M", "N", "K"], - attrs=[ - [8, 8, 1], - [256, 512, 1], - [512, 512, 1], - [8, 8, 2], - [256, 512, 2], - [512, 512, 2], - ], - cross_product_configs={ - "device": ["cpu", "cuda"], - }, - tags=["short"], -) - -batch_gather_configs_long = op_bench.cross_product_configs( - M=[128, 1024], N=[128, 1024], K=[1, 2], device=["cpu", "cuda"], tags=["long"] -) - - -class BatchGatherBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, device): - self.input_one = self.tensor([M, N, K], device=device) - max_val = N - numpy.random.seed((1 << 32) - 1) - index_dim = numpy.random.randint(0, N) - self.index = self.feed_tensor( - numpy.random.randint(0, max_val, index_dim), device=device - ) - self.output = self.tensor([M, index_dim, K], device=device) - self.set_module_name("batch_gather") - - def forward(self): - op = core.CreateOperator( - "BatchGather", [self.input_one, self.index], self.output - ) - return op - - -op_bench_c2.generate_c2_test( - batch_gather_configs_long + batch_gather_configs_short, BatchGatherBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/clip_ranges_test.py b/benchmarks/operator_benchmark/c2/clip_ranges_test.py deleted file mode 100644 index 57bcd9858a8f..000000000000 --- a/benchmarks/operator_benchmark/c2/clip_ranges_test.py +++ /dev/null @@ -1,54 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core, dyndep - -dyndep.InitOpsLibrary("@/caffe2/caffe2/fb/operators:clip_ranges_op") - -"""Microbenchmarks for ClipRanges operator.""" - -# Configs for C2 ClipRanges operator -clip_ranges_long_configs = op_bench.cross_product_configs( - LENGTH=range(1, 100), - M=[1], - N=[2], - MAX_LENGTH=range(1, 100), - dtype=["int32"], - tags=["long"], -) - - -clip_ranges_short_configs = op_bench.config_list( - attrs=[ - [6, 1, 2, 1, "int32"], - [7, 1, 2, 2, "int32"], - [8, 1, 2, 3, "int32"], - [9, 1, 2, 4, "int32"], - [10, 1, 2, 5, "int32"], - ], - attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], - tags=["short"], -) - - -class ClipRangesBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, LENGTH, M, N, MAX_LENGTH, dtype): - self.input = self.tensor([LENGTH, M, N], dtype) - self.max_length = MAX_LENGTH - self.set_module_name("clip_ranges") - - def forward(self): - op = core.CreateOperator( - "ClipRanges", self.input, self.input, max_length=self.max_length - ) - return op - - -op_bench_c2.generate_c2_test( - clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/concat_test.py b/benchmarks/operator_benchmark/c2/concat_test.py deleted file mode 100644 index 4e91c30f2a75..000000000000 --- a/benchmarks/operator_benchmark/c2/concat_test.py +++ /dev/null @@ -1,171 +0,0 @@ -import 
random - -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for Concat operator. Supports both Caffe2/PyTorch.""" - -cross_product_configs = { - "device": ["cpu", "cuda"], - "dtype": ["float"], - "add_axis": [0], -} - -# Configs for C2 concat operator -cat_configs_short = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [(1, 1, 1), 2, 0], # noqa: E241 - [(512, 512, 2), 2, 1], # noqa: E241 - [(128, 1024, 2), 2, 1], # noqa: E241 - ], - cross_product_configs=cross_product_configs, - tags=["short"], -) - -# Configs specific to static runtime feature - a fast runtime for pared down models -cat_configs_static_runtime = op_bench.config_list( - attr_names=["sizes", "N", "axis", "add_axis"], - attrs=[ - [(1, 40), 5, 1, 1], - [[(1, 160), (1, 14)], -1, 1, 0], - [[(1, 20, 40), (1, 4, 40), (1, 5, 40)], -1, 1, 0], - [[(1, 580), (1, 174)], -1, 1, 0], - [(20, 40), 5, 1, 1], - [[(20, 160), (20, 14)], -1, 1, 0], - [[(20, 20, 40), (20, 4, 40), (20, 5, 40)], -1, 1, 0], - [[(20, 580), (20, 174)], -1, 1, 0], - ], - cross_product_configs=cross_product_configs, - tags=["static_runtime"], -) - -cat_configs_long = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [(2**10, 2**10, 2), 2, 0], # noqa: E241 - [(2**10 + 1, 2**10 - 1, 2), 2, 1], # noqa: E226,E241 - [(2**10, 2**10, 2), 2, 2], # noqa: E241 - [ - [ - lambda: random.randint(2**6, 2**7), - 2**7 - 17, - 2**6 + 1, - ], # noqa: E201,E226,E241 - 5, - 0, - ], - [ - [ - 2**6 + 2**5, - lambda: random.randint(2**6, 2**7), - 2**6, - ], # noqa: E201,E226,E241,E272 - 5, - 1, - ], - [ - [ - 2**7, - 2**6, - lambda: random.randint(2**6, 2**7), - ], # noqa: E201,E241,E272 - 5, - 2, - ], - [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], 50, 0], # noqa: E241 - [ - [2**5, lambda: random.randint(2**5, 2**6), 2**6], # noqa: E241,E272 - 50, - 1, - ], - [ - [ - 2**5 + 1, - 2**6 + 1, - lambda: random.randint(2**5, 2**6), - ], # noqa: E226,E241,E272 - 50, - 2, - ], - ], - cross_product_configs=cross_product_configs, - tags=["long"], -) - -# There is a different codepath on CUDA for >4 dimensions -cat_configs_multidim = op_bench.config_list( - attr_names=["sizes", "N", "axis", "dtype"], - attrs=[ - [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2], # noqa: E241 - [(2**4, 2**5, 2**2, 2**4, 2**5), 8, 2], # noqa: E241 - [ - (2**3 + 1, 2**5 - 1, 2**2 + 1, 2**4 - 1, 2**5 + 1), - 17, - 4, - ], # noqa: E226,E241 - ], - cross_product_configs=cross_product_configs, - tags=["multidim"], -) - -cat_configs_manyinputs = op_bench.config_list( - attr_names=["sizes", "N", "axis"], - attrs=[ - [[lambda: random.randint(1, 10000)], 100, 0], - [[lambda: random.randint(1, 1000)], 1000, 0], - [[lambda: random.randint(1, 500)], 2000, 0], - [[lambda: random.randint(1, 300)], 3000, 0], - ], - cross_product_configs=cross_product_configs, - tags=["manyinputs"], -) - - -class ConcatBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, sizes, N, axis, add_axis, dtype, device): - random.seed(42) - self.inputs = [] - self.args = {"axis": axis, "add_axis": add_axis} - gen_sizes = [] - if type(sizes) == list and N == -1: - gen_sizes = sizes - else: - for i in range(N): - gen_sizes.append( - [ - old_size() if callable(old_size) else old_size - for old_size in sizes - ] - ) - - for s in gen_sizes: - self.inputs.append(self.tensor(s, dtype, device=device)) - - self.output = self.tensor(gen_sizes[0], dtype, device=device) - 
self.split_info = self.tensor(gen_sizes[0], "int") - self.set_module_name("concat") - - def forward(self): - op = core.CreateOperator( - "Concat", self.inputs, [self.output, self.split_info], **self.args - ) - return op - - -op_bench_c2.generate_c2_test( - cat_configs_short - + cat_configs_long - + cat_configs_multidim - + cat_configs_manyinputs - + cat_configs_static_runtime, - ConcatBenchmark, -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/matmul_test.py b/benchmarks/operator_benchmark/c2/matmul_test.py deleted file mode 100644 index 72bc4c78d710..000000000000 --- a/benchmarks/operator_benchmark/c2/matmul_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - -"""Microbenchmarks for MatMul operator""" - -# Configs for C2 Matmul operator -mm_long_configs = op_bench.cross_product_configs( - M=[8, 64, 128], - N=range(2, 10, 3), - K=[2**x for x in range(0, 3)], - trans_a=[True, False], - trans_b=[True, False], - tags=["long"], -) - - -mm_short_configs = op_bench.config_list( - attrs=[ - [128, 128, 128, False, True], - [1024, 1024, 256, True, False], - [8192, 8192, 1024, True, False], - ], - attr_names=["M", "N", "K", "trans_a", "trans_b"], - tags=["short"], -) - - -class MatMulBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, K, trans_a, trans_b): - self.input_one = self.tensor([N, M]) if trans_a else self.tensor([M, N]) - self.input_two = self.tensor([K, N]) if trans_b else self.tensor([N, K]) - self.args = {"trans_a": trans_a, "trans_b": trans_b} - self.output = self.tensor([M, K]) - self.set_module_name("matmul") - - def forward(self): - op = core.CreateOperator( - "MatMul", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - -op_bench_c2.generate_c2_test(mm_long_configs + mm_short_configs, MatMulBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/c2/quantile_op_test.py b/benchmarks/operator_benchmark/c2/quantile_op_test.py deleted file mode 100644 index 296b6bf189e3..000000000000 --- a/benchmarks/operator_benchmark/c2/quantile_op_test.py +++ /dev/null @@ -1,48 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for QuantileOp operator.""" - -# Configs for C2 QuantileOp operator -quantile_op_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -quantile_op_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class QuantileOpBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.data = [self.tensor([N], dtype) for _ in range(M)] - self.quantile = 0.3 - self.output = self.tensor([1], dtype) - self.set_module_name("quantile_op") - - def forward(self): - op = core.CreateOperator( - "Quantile", inputs=self.data, outputs=self.output, quantile=self.quantile - ) - return op - - -op_bench_c2.generate_c2_test( - quantile_op_long_configs + quantile_op_short_configs, QuantileOpBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() 
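Note on the deletions above: the c2/ microbenchmarks (BatchGather, ClipRanges, Concat, MatMul, QuantileOp) go away together with the rest of the Caffe2 Python stack; comparable coverage is expected to live on the PyTorch side of the same harness under benchmarks/operator_benchmark/pt/. The sketch below is not part of this diff. It is a minimal illustration that assumes the pt-style API (op_bench.TorchBenchmarkBase, op_bench.generate_pt_test, and the self.inputs dict convention) already used by the surviving pt/ tests in this PR, and it mirrors the deleted mm_short_configs.

import operator_benchmark as op_bench
import torch

# Shapes mirroring the deleted C2 mm_short_configs (trimmed for brevity).
mm_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K", "trans_a", "trans_b"],
    attrs=[
        [128, 128, 128, False, True],
        [1024, 1024, 256, True, False],
    ],
    cross_product_configs={"device": ["cpu"]},
    tags=["short"],
)


class MatMulBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, trans_a, trans_b, device):
        # Materialize the operands the way the C2 benchmark did, then apply the
        # transposes up front so forward() only receives plain tensors.
        input_one = torch.rand(N, M, device=device) if trans_a else torch.rand(M, N, device=device)
        input_two = torch.rand(K, N, device=device) if trans_b else torch.rand(N, K, device=device)
        self.inputs = {
            "input_one": input_one.t() if trans_a else input_one,
            "input_two": input_two.t() if trans_b else input_two,
        }
        self.set_module_name("matmul")

    def forward(self, input_one, input_two):
        return torch.matmul(input_one, input_two)


op_bench.generate_pt_test(mm_short_configs, MatMulBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()

As with the other pt/ tests shown later in this diff, such a file is driven on its own through op_bench.benchmark_runner rather than the removed Caffe2 workspace machinery.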
diff --git a/benchmarks/operator_benchmark/c2/replace_nan_test.py b/benchmarks/operator_benchmark/c2/replace_nan_test.py deleted file mode 100644 index c735a69b4ab4..000000000000 --- a/benchmarks/operator_benchmark/c2/replace_nan_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import benchmark_caffe2 as op_bench_c2 -from benchmark_caffe2 import Caffe2BenchmarkBase # noqa: F401 - -import operator_benchmark as op_bench -from caffe2.python import core - - -"""Microbenchmarks for element-wise ReplaceNaN operator.""" - -# Configs for C2 ReplaceNaN operator -replace_nan_long_configs = op_bench.cross_product_configs( - M=[32, 64, 128], N=range(32, 128, 32), dtype=["float", "double"], tags=["long"] -) - - -replace_nan_short_configs = op_bench.config_list( - attrs=[ - [16, 16, "float"], - [16, 16, "double"], - [64, 64, "float"], - [64, 64, "double"], - ], - attr_names=["M", "N", "dtype"], - tags=["short"], -) - - -class ReplaceNaNBenchmark(op_bench_c2.Caffe2BenchmarkBase): - def init(self, M, N, dtype): - self.input = self.tensor([M, N], dtype) - self.set_module_name("replace_nan") - - def forward(self): - op = core.CreateOperator("ReplaceNaN", self.input, self.input, value=1.0) - return op - - -op_bench_c2.generate_c2_test( - replace_nan_long_configs + replace_nan_short_configs, ReplaceNaNBenchmark -) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py b/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py deleted file mode 100644 index ff34a58533f9..000000000000 --- a/benchmarks/operator_benchmark/common/tests/c2_cpu_gpu_forward_backward_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import operator_benchmark as op_bench - -from caffe2.python import core - - -add_configs = op_bench.cross_product_configs( - M=[8], N=[8], K=[8], tags=["short"], device=["cuda", "cpu"] -) - - -class AddBenchmark(op_bench.Caffe2BenchmarkBase): - def init(self, M, N, K, device): - self.set_module_name("add") - self.input_one = self.tensor([M, N, K], device=device) - self.input_two = self.tensor([M, N, K], device=device) - self.input_one_grad = self.tensor([M, N, K], device=device) - self.input_two_grad = self.tensor([M, N, K], device=device) - self.output = self.tensor([M, N, K], device=device) - - def forward(self): - op = core.CreateOperator( - "Add", [self.input_one, self.input_two], self.output, **self.args - ) - return op - - def backward(self): - grad_op = core.CreateOperator( - "AddGradient", - [self.output, self.input_one, self.input_two], - [self.input_one_grad, self.input_two_grad], - **self.args, - ) - return grad_op - - -op_bench.generate_c2_test(add_configs, AddBenchmark) -op_bench.generate_c2_gradient_test(add_configs, AddBenchmark) - - -if __name__ == "__main__": - op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/configs.py b/benchmarks/operator_benchmark/pt/configs.py index 3add77db1a87..ccfa62cc364e 100644 --- a/benchmarks/operator_benchmark/pt/configs.py +++ b/benchmarks/operator_benchmark/pt/configs.py @@ -34,6 +34,17 @@ def remove_cuda(config_list): tags=["long"], ) +convtranspose_1d_configs_short = op_bench.config_list( + attr_names=["IC", "OC", "kernel", "stride", "N", "L"], + attrs=[ + [2016, 1026, 1024, 256, 1, 224], + ], + cross_product_configs={ + "device": ["cpu", "cuda"], + }, + tags=["short"], +) + # Configs for Conv2d and ConvTranspose1d conv_2d_configs_short = op_bench.config_list( attr_names=[ diff --git 
a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index e01473a04f5b..ad315d8a0bb8 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -37,7 +37,9 @@ def forward(self, input): configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark ) op_bench.generate_pt_test( - configs.conv_1d_configs_short + configs.conv_1d_configs_long, + configs.convtranspose_1d_configs_short + + configs.conv_1d_configs_short + + configs.conv_1d_configs_long, ConvTranspose1dBenchmark, ) diff --git a/benchmarks/record_function_benchmark/record_function_bench.py b/benchmarks/record_function_benchmark/record_function_bench.py index 348c1cae7650..f42f9b0d647f 100644 --- a/benchmarks/record_function_benchmark/record_function_bench.py +++ b/benchmarks/record_function_benchmark/record_function_bench.py @@ -1,18 +1,13 @@ import argparse import sys -import torch -import torch.utils.benchmark as benchmark_utils - - -try: - from benchmarks.fastrnns.factory import lstm_creator -except ImportError: - from caffe2.benchmarks.fastrnns.factory import lstm_creator - +from benchmarks.fastrnns.factory import lstm_creator from torchvision.models import resnet50 +import torch +import torch.utils.benchmark as benchmark_utils + def prepare_lstm_jit(bench_args): model_def = lstm_creator( diff --git a/benchmarks/static_runtime/test_generated_ops.cc b/benchmarks/static_runtime/test_generated_ops.cc index 415bf464fbd1..bdf0585404ed 100644 --- a/benchmarks/static_runtime/test_generated_ops.cc +++ b/benchmarks/static_runtime/test_generated_ops.cc @@ -272,6 +272,38 @@ TEST(StaticRuntime, autogen_addr) { /*check_resize=*/true); } +TEST(StaticRuntime, autogen__test_functorch_fallback) { + const std::string script = R"IR( + graph(%self: Tensor, %other: Tensor): + %bias: None = prim::Constant() + %ret = aten::_test_functorch_fallback(%self, %other) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + auto self0 = at::rand({6, 6, 6}); + auto other0 = at::rand({6, 6, 6}); + std::vector args{self0, other0}; + testStaticRuntime( + script, + args, + {}, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); + + auto self1 = at::rand({22, 22, 22}); + auto other1 = at::rand({22, 22, 22}); + std::vector args2{self1, other1}; + testStaticRuntime( + script, + args, + args2, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); +} + TEST(StaticRuntime, autogen_argmax) { const std::string script = R"IR( graph(%self: Tensor, %dim: int?, %keepdim: bool): @@ -4440,6 +4472,40 @@ TEST(StaticRuntime, autogen_masked_select) { /*check_resize=*/true); } +TEST(StaticRuntime, autogen_nonzero_static) { + const std::string script = R"IR( + graph(%self: Tensor, %size: int, %fill_value: int): + %bias: None = prim::Constant() + %ret = aten::nonzero_static(%self, %size, %fill_value) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + auto self0 = at::rand({6, 6, 6}); + auto size0 = 1; + auto fill_value0 = 1; + std::vector args{self0, size0, fill_value0}; + testStaticRuntime( + script, + args, + {}, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); + + auto self1 = at::rand({22, 22, 22}); + auto size1 = 1; + auto fill_value1 = 1; + std::vector args2{self1, size1, fill_value1}; + testStaticRuntime( + script, + args, + args2, + /*use_allclose=*/false, + /*use_equalnan=*/false, + /*check_resize=*/true); +} + TEST(StaticRuntime, autogen_gather) { const 
std::string script = R"IR( graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool): @@ -7106,222 +7172,6 @@ TEST(StaticRuntime, autogen_special_multigammaln) { /*check_resize=*/true); } -TEST(StaticRuntime, autogen_fft_fft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_fft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_ifft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_ifft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_rfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_rfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_irfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_irfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_hfft) { - const std::string script = R"IR( - 
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_hfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_fft_ihfft) { - const std::string script = R"IR( - graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): - %bias: None = prim::Constant() - %ret = aten::fft_ihfft(%self, %n, %dim, %norm) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6}); - auto n0 = 1; - auto dim0 = 1; - auto norm0 = "forward"; - std::vector args{self0, n0, dim0, norm0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto self1 = at::rand({22, 22, 22}); - auto n1 = 1; - auto dim1 = 1; - auto norm1 = "forward"; - std::vector args2{self1, n1, dim1, norm1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - TEST(StaticRuntime, autogen_linalg_cross) { const std::string script = R"IR( graph(%self: Tensor, %other: Tensor, %dim: int): diff --git a/binaries/bench_gen/bench_gen.py b/binaries/bench_gen/bench_gen.py deleted file mode 100755 index 7523e76f8b14..000000000000 --- a/binaries/bench_gen/bench_gen.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import ast - -from caffe2.python import brew, workspace - -from caffe2.python.model_helper import ModelHelper -from caffe2.python.predictor import mobile_exporter - - -def parse_kwarg(kwarg_str): - key, value = kwarg_str.split("=") - try: - value = ast.literal_eval(value) - except ValueError: - pass - return key, value - - -def main(args): - # User defined keyword arguments - kwargs = {"order": "NCHW", "use_cudnn": False} - kwargs.update(dict(args.kwargs)) - - model = ModelHelper(name=args.benchmark_name) - - op_type = args.operator # assumes a brew type op name - input_name = args.input_name - output_name = args.output_name - - iters = int(args.instances) - for i in range(iters): - input_blob_name = input_name + (str(i) if i > 0 and args.chain else "") - output_blob_name = output_name + str(i + 1) - add_op = getattr(brew, op_type) - add_op(model, input_blob_name, output_blob_name, **kwargs) - if args.chain: - input_name, output_name = output_name, input_name - - workspace.RunNetOnce(model.param_init_net) - - init_net, predict_net = mobile_exporter.Export(workspace, model.net, model.params) - - if args.debug: - print("init_net:") - for op in init_net.op: - print(" ", op.type, op.input, "-->", op.output) - print("predict_net:") - for op in predict_net.op: - print(" ", op.type, op.input, "-->", op.output) - - with open(args.predict_net, "wb") as f: - f.write(predict_net.SerializeToString()) - with open(args.init_net, "wb") as f: - f.write(init_net.SerializeToString()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Utility to generate Caffe2 benchmark models." 
- ) - parser.add_argument("operator", help="Caffe2 operator to benchmark.") - parser.add_argument( - "-b", - "--blob", - help="Instantiate a blob --blob name=dim1,dim2,dim3", - action="append", - ) - parser.add_argument("--context", help="Context to run on.", default="CPU") - parser.add_argument( - "--kwargs", - help="kwargs to pass to operator.", - nargs="*", - type=parse_kwarg, - default=[], - ) - parser.add_argument( - "--init-net", - "--init_net", - help="Output initialization net.", - default="init_net.pb", - ) - parser.add_argument( - "--predict-net", - "--predict_net", - help="Output prediction net.", - default="predict_net.pb", - ) - parser.add_argument( - "--benchmark-name", - "--benchmark_name", - help="Name of the benchmark network", - default="benchmark", - ) - parser.add_argument( - "--input-name", "--input_name", help="Name of the input blob.", default="data" - ) - parser.add_argument( - "--output-name", - "--output_name", - help="Name of the output blob.", - default="output", - ) - parser.add_argument( - "--instances", help="Number of instances to run the operator.", default="1" - ) - parser.add_argument( - "-d", "--debug", help="Print debug information.", action="store_true" - ) - parser.add_argument( - "-c", - "--chain", - help="Chain ops together (create data dependencies)", - action="store_true", - ) - args = parser.parse_args() - main(args) diff --git a/buckbuild.bzl b/buckbuild.bzl index 649ebe668365..1d668117e910 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -383,6 +383,7 @@ def get_aten_generated_files(enabled_backends): "core/TensorMethods.cpp", "core/aten_interned_strings.h", "core/enum_tag.h", + "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", ] + get_aten_derived_type_srcs(enabled_backends) # This is tiresome. 
A better strategy would be to unconditionally @@ -467,6 +468,7 @@ def gen_aten_files( cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([ "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT), "--install_dir $OUT", + "--aoti_install_dir $OUT/torch/csrc/inductor/aoti_torch/generated" ] + extra_params), visibility = visibility, compatible_with = compatible_with, diff --git a/build.bzl b/build.bzl index 5ab9f92acecc..8fd15f4e9c42 100644 --- a/build.bzl +++ b/build.bzl @@ -73,6 +73,7 @@ def define_targets(rules): "$(execpath //torchgen:gen)", "--install_dir=$(RULEDIR)", "--source-path aten/src/ATen", + "--aoti_install_dir=$(RULEDIR)/torch/csrc/inductor/aoti_torch/generated" ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else [])) gen_aten_outs_cuda = ( @@ -83,6 +84,7 @@ def define_targets(rules): gen_aten_outs = ( GENERATED_H + GENERATED_H_CORE + GENERATED_CPP + GENERATED_CPP_CORE + + GENERATED_AOTI_CPP + aten_ufunc_generated_cpu_sources() + aten_ufunc_generated_cpu_kernel_sources() + [ "Declarations.yaml", @@ -316,3 +318,8 @@ GENERATED_AUTOGRAD_CPP = [ "torch/csrc/lazy/generated/RegisterAutogradLazy.cpp", "torch/csrc/lazy/generated/RegisterLazy.cpp", ] + _GENERATED_AUTOGRAD_CPP_HEADERS + GENERATED_LAZY_H + +GENERATED_AOTI_CPP = [ + "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", + "torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp", +] diff --git a/build_variables.bzl b/build_variables.bzl index 8b5ac4f46d7c..323588c15b4c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -515,6 +515,8 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/socket.cpp", "torch/csrc/distributed/c10d/Work.cpp", + "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", ] # These files are only supported on Linux (and others) but not on Windows. 
@@ -825,6 +827,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/guards.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", + "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 1cf1782fa570..8ecaa7be7377 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -65,7 +65,7 @@ enum class Backend { NumOptions }; -static inline Backend dispatchKeyToBackend(DispatchKey t) { +inline Backend dispatchKeyToBackend(DispatchKey t) { if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) { return Backend::CPU; } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) { @@ -142,7 +142,7 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { } } -static inline DispatchKey backendToDispatchKey(Backend b) { +inline DispatchKey backendToDispatchKey(Backend b) { switch (b) { case Backend::CPU: return DispatchKey::CPU; @@ -217,7 +217,7 @@ static inline DispatchKey backendToDispatchKey(Backend b) { } } -static inline DeviceType backendToDeviceType(Backend b) { +inline DeviceType backendToDeviceType(Backend b) { switch (b) { case Backend::CPU: case Backend::MkldnnCPU: @@ -281,8 +281,7 @@ static inline DeviceType backendToDeviceType(Backend b) { } } -// TODO: This probably shouldn't actually be static inline -static inline const char* toString(Backend b) { +inline const char* toString(Backend b) { switch (b) { case Backend::CPU: return "CPU"; @@ -357,7 +356,7 @@ static inline const char* toString(Backend b) { } } -static inline bool isSparse(Backend b) { +inline bool isSparse(Backend b) { switch (b) { case Backend::SparseXPU: case Backend::SparseCPU: @@ -371,7 +370,7 @@ static inline bool isSparse(Backend b) { } } -static inline bool isSparseCsr(Backend b) { +inline bool isSparseCsr(Backend b) { switch (b) { case Backend::SparseCsrXPU: case Backend::SparseCsrCPU: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index f7461ea73a6d..4c391d60f2b0 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -901,7 +901,7 @@ C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); // legacy code that is still using DispatchKey for things like instanceof // checks; if at all possible, refactor the code to stop using DispatchKey in // those cases. -static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { +inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // NB: If you add any extra keys that can be stored in TensorImpl on // top of existing "backend" keys like CPU/CUDA, you need to add it // here. 
At the moment, autograd keys and ADInplaceOrView key need this diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 590b24a7bc20..f7f059fd513d 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -315,7 +315,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) #undef DEFINE_CONSTANT -static inline const char* toString(ScalarType t) { +inline const char* toString(ScalarType t) { #define DEFINE_CASE(_, name) \ case ScalarType::name: \ return #name; @@ -328,7 +328,7 @@ static inline const char* toString(ScalarType t) { #undef DEFINE_CASE } -static inline size_t elementSize(ScalarType t) { +inline size_t elementSize(ScalarType t) { #define CASE_ELEMENTSIZE_CASE(ctype, name) \ case ScalarType::name: \ return sizeof(ctype); @@ -341,7 +341,7 @@ static inline size_t elementSize(ScalarType t) { #undef CASE_ELEMENTSIZE_CASE } -static inline bool isIntegralType(ScalarType t, bool includeBool) { +inline bool isIntegralType(ScalarType t, bool includeBool) { bool isIntegral = (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || t == ScalarType::Long || t == ScalarType::Short || @@ -353,44 +353,44 @@ static inline bool isIntegralType(ScalarType t, bool includeBool) { C10_DEPRECATED_MESSAGE( "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.") -static inline bool isIntegralType(ScalarType t) { +inline bool isIntegralType(ScalarType t) { return isIntegralType(t, /*includeBool=*/false); } -static inline bool isFloat8Type(ScalarType t) { +inline bool isFloat8Type(ScalarType t) { return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz; } -static inline bool isReducedFloatingType(ScalarType t) { +inline bool isReducedFloatingType(ScalarType t) { return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); } -static inline bool isFloatingType(ScalarType t) { +inline bool isFloatingType(ScalarType t) { return t == ScalarType::Double || t == ScalarType::Float || isReducedFloatingType(t); } -static inline bool isComplexType(ScalarType t) { +inline bool isComplexType(ScalarType t) { return ( t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble); } -static inline bool isQIntType(ScalarType t) { +inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || t == ScalarType::QUInt2x4; } -static inline bool isBitsType(ScalarType t) { +inline bool isBitsType(ScalarType t) { return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || t == ScalarType::Bits16; } -static inline bool isBarebonesUnsignedType(ScalarType t) { +inline bool isBarebonesUnsignedType(ScalarType t) { return t == ScalarType::UInt1 || t == ScalarType::UInt2 || t == ScalarType::UInt3 || t == ScalarType::UInt4 || t == ScalarType::UInt5 || t == ScalarType::UInt6 || @@ -398,7 +398,7 @@ static inline bool isBarebonesUnsignedType(ScalarType t) { t == ScalarType::UInt32 || t == ScalarType::UInt64; } -static inline ScalarType toQIntType(ScalarType t) { +inline ScalarType toQIntType(ScalarType t) { switch (t) { case ScalarType::Byte: return ScalarType::QUInt8; @@ -411,7 +411,7 @@ static inline ScalarType toQIntType(ScalarType t) { } } 
-static inline ScalarType toUnderlying(ScalarType t) { +inline ScalarType toUnderlying(ScalarType t) { switch (t) { case ScalarType::QUInt8: case ScalarType::QUInt4x2: @@ -427,7 +427,7 @@ static inline ScalarType toUnderlying(ScalarType t) { } } -static inline bool isSignedType(ScalarType t) { +inline bool isSignedType(ScalarType t) { #define CASE_ISSIGNED(name) \ case ScalarType::name: \ return std::numeric_limits< \ @@ -484,11 +484,11 @@ static inline bool isSignedType(ScalarType t) { #undef CASE_ISSIGNED } -static inline bool isUnderlying(ScalarType type, ScalarType qtype) { +inline bool isUnderlying(ScalarType type, ScalarType qtype) { return type == toUnderlying(qtype); } -static inline ScalarType toRealValueType(ScalarType t) { +inline ScalarType toRealValueType(ScalarType t) { switch (t) { case ScalarType::ComplexHalf: return ScalarType::Half; @@ -501,7 +501,7 @@ static inline ScalarType toRealValueType(ScalarType t) { } } -static inline ScalarType toComplexType(ScalarType t) { +inline ScalarType toComplexType(ScalarType t) { switch (t) { case ScalarType::BFloat16: // BFloat16 has range equivalent to Float, @@ -526,7 +526,7 @@ static inline ScalarType toComplexType(ScalarType t) { // see tensor_attributes.rst for detailed explanation and examples // of casting rules. -static inline bool canCast(const ScalarType from, const ScalarType to) { +inline bool canCast(const ScalarType from, const ScalarType to) { // We disallow complex -> non complex, e.g., float_tensor *= complex is // disallowed. if (isComplexType(from) && !isComplexType(to)) { diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h index 910e0d24b0a3..d2694c96221e 100644 --- a/c10/core/ScalarTypeToTypeMeta.h +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -13,21 +13,21 @@ namespace c10 { /** * convert ScalarType enum values to TypeMeta handles */ -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { +inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { return caffe2::TypeMeta::fromScalarType(scalar_type); } /** * convert TypeMeta handles to ScalarType enum values */ -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { +inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { return dtype.toScalarType(); } /** * typeMetaToScalarType(), lifted to optional */ -static inline optional optTypeMetaToScalarType( +inline optional optTypeMetaToScalarType( optional type_meta) { if (!type_meta.has_value()) { return c10::nullopt; @@ -38,19 +38,19 @@ static inline optional optTypeMetaToScalarType( /** * convenience: equality across TypeMeta/ScalarType conversion */ -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { +inline bool operator==(ScalarType t, caffe2::TypeMeta m) { return m.isScalarType(t); } -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { +inline bool operator==(caffe2::TypeMeta m, ScalarType t) { return t == m; } -static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { +inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { return !(t == m); } -static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { +inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { return !(t == m); } diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 9dd6f5f43131..2b5bbdb86c8a 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -18,7 +18,7 @@ void throwNullDataPtrError() { "If you're using torch.compile/export/fx, it is likely that we are erroneously " "tracing 
into a custom kernel. To fix this, please wrap the custom kernel into " "an opaque custom op. Please see the following for details: " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ"); + "https://pytorch.org/docs/main/notes/custom_operators.html"); } // NOTE: [FakeTensor.data_ptr deprecation] diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 9ffab5065109..bb92b09775b7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 47f83c78e578..516a61f02004 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -577,6 +577,11 @@ void TensorImpl::copy_generic_tensor_metadata( dest_impl->numel_ = src_impl->numel_; if (src_impl->extra_meta_ != nullptr) { dest_impl->extra_meta_ = src_impl->extra_meta_->clone(); + } else if (dest_impl->extra_meta_ != nullptr) { + // Clean dest_impl extra meta data, cause shallow_copy_from dest impl is a + // real tensor impl, which maybe take extra meta data. This info will + // contaminate the new dest_impl metadata info. + dest_impl->extra_meta_.reset(nullptr); } // NB: symbolic sizes and strides are copied as is custom policy, but python diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index e49a66c916ff..877c1c09543c 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1580,7 +1580,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { "If you're using torch.compile/export/fx, it is likely that we are erroneously " "tracing into a custom kernel. To fix this, please wrap the custom kernel into " "an opaque custom op. Please see the following for details: " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ\n" + "https://pytorch.org/docs/main/notes/custom_operators.html\n" "If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call " "mutable_data() or raw_mutable_data() to actually allocate memory."); // Caller does the type check. 
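A note on the SymNodeImpl.h hunk above: the new operation-specific virtuals (float_truediv, int_truediv, float_pow, pow_by_natural, int_floordiv) all default to the legacy generic methods, so a SymNode backend that only overrides truediv, pow, or floordiv keeps working unchanged while newer callers can dispatch on the more precise hooks. The Python sketch below only illustrates that fallback pattern; SymNodeBackend and LegacySymNode are hypothetical stand-ins, not the actual C++ or Python SymNode types.

class SymNodeBackend:
    """Hypothetical base class mirroring the virtual-method layout above."""

    # Legacy, type-agnostic hook (corresponds to the virtual truediv).
    def truediv(self, other):
        raise NotImplementedError("NYI")

    # Newer, operation-specific hooks default to the legacy one, matching the
    # C++ defaults float_truediv/int_truediv -> truediv.
    def float_truediv(self, other):
        return self.truediv(other)

    def int_truediv(self, other):
        return self.truediv(other)


class LegacySymNode(SymNodeBackend):
    """A backend written before the split: it only implements truediv."""

    def __init__(self, value):
        self.value = value

    def truediv(self, other):
        return LegacySymNode(self.value / other.value)


# New callers can use the precise hook; the legacy backend still works because
# float_truediv falls back to truediv.
a, b = LegacySymNode(3.0), LegacySymNode(2.0)
assert a.float_truediv(b).value == 1.5

The same backward-compatible defaulting applies to float_pow/pow_by_natural (falling back to pow) and int_floordiv (falling back to floordiv) in the hunk.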
diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 893a85562976..3327dab4779b 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -14,6 +14,8 @@ configure_file( if(BUILD_LIBTORCHLESS) find_library(C10_CUDA_LIB c10_cuda PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH) +else() + set(C10_CUDA_LIB c10_cuda) endif() # Note: if you want to add ANY dependency to the c10 library, make sure you @@ -75,7 +77,6 @@ if(NOT BUILD_LIBTORCHLESS) $ $ $) - set(C10_CUDA_LIB c10_cuda) # ---[ Installation # Note: for now, we will put all export path into one single Caffe2Targets group diff --git a/c10/test/util/small_vector_test.cpp b/c10/test/util/small_vector_test.cpp index e05d21ce88f1..1efe4d4910e0 100644 --- a/c10/test/util/small_vector_test.cpp +++ b/c10/test/util/small_vector_test.cpp @@ -576,8 +576,8 @@ TYPED_TEST(SmallVectorTest, EraseTest) { SCOPED_TRACE("EraseTest"); this->makeSequence(this->theVector, 1, 3); - const auto& theConstVector = this->theVector; - this->theVector.erase(theConstVector.begin()); + auto& theVector = this->theVector; + this->theVector.erase(theVector.begin()); this->assertValuesInOrder(this->theVector, 2u, 2, 3); } @@ -586,8 +586,8 @@ TYPED_TEST(SmallVectorTest, EraseRangeTest) { SCOPED_TRACE("EraseRangeTest"); this->makeSequence(this->theVector, 1, 3); - const auto& theConstVector = this->theVector; - this->theVector.erase(theConstVector.begin(), theConstVector.begin() + 2); + auto& theVector = this->theVector; + this->theVector.erase(theVector.begin(), theVector.begin() + 2); this->assertValuesInOrder(this->theVector, 1u, 3); } diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 1f62adb9bb00..fe2044f507d4 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -54,11 +54,6 @@ make_unique_base(Args&&... 
args) { return std::unique_ptr(new Child(std::forward(args)...)); } -template -using conjunction = std::conjunction; -template -using disjunction = std::disjunction; - #if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__) template diff --git a/c10/util/Float8_fnuz_cvt.h b/c10/util/Float8_fnuz_cvt.h index 983063a0230f..327f90d11a71 100644 --- a/c10/util/Float8_fnuz_cvt.h +++ b/c10/util/Float8_fnuz_cvt.h @@ -4,6 +4,10 @@ #include +#if defined(SYCL_LANGUAGE_VERSION) +#include +#endif + namespace c10::detail { /* @@ -33,6 +37,8 @@ inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) uint32_t renorm_shift = __clz(mantissa); +#elif defined(__SYCL_DEVICE_ONLY__) + uint32_t renorm_shift = sycl::clz(mantissa); #elif defined(_MSC_VER) unsigned long nonsign_bsr; _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); diff --git a/c10/util/Half.h b/c10/util/Half.h index af3435941e48..afc90f106a6f 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -330,20 +331,12 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) { } #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) -constexpr inline float16_t fp16_from_bits(uint16_t h) { - union { - uint16_t as_bits; - float16_t as_value; - } fp16 = {h}; - return fp16.as_value; +inline float16_t fp16_from_bits(uint16_t h) { + return c10::bit_cast(h); } -constexpr inline uint16_t fp16_to_bits(float16_t f) { - union { - float16_t as_value; - uint16_t as_bits; - } fp16 = {.as_value = f}; - return fp16.as_bits; +inline uint16_t fp16_to_bits(float16_t f) { + return c10::bit_cast(f); } // According to https://godbolt.org/z/8s14GvEjo it would translate to single diff --git a/c10/util/SmallVector.cpp b/c10/util/SmallVector.cpp index 14b2fa9eb671..e30cdbf8dd3b 100644 --- a/c10/util/SmallVector.cpp +++ b/c10/util/SmallVector.cpp @@ -123,7 +123,7 @@ void* SmallVectorBase::mallocForGrow( // Note: Moving this function into the header may cause performance regression. template void SmallVectorBase::grow_pod( - void* FirstEl, + const void* FirstEl, size_t MinSize, size_t TSize) { size_t NewCapacity = getNewCapacity(MinSize, TSize, this->capacity()); diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 919553811454..cbcfbc52cb8a 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -38,11 +38,6 @@ #include #include -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - namespace c10 { /// This is all the stuff common to all SmallVectors. @@ -75,7 +70,7 @@ class C10_API SmallVectorBase { /// This is an implementation of the grow() method which only works /// on POD-like data types and is out of line to reduce code duplication. /// This function will report a fatal error if it cannot increase capacity. - void grow_pod(void* FirstEl, size_t MinSize, size_t TSize); + void grow_pod(const void* FirstEl, size_t MinSize, size_t TSize); public: SmallVectorBase() = delete; @@ -112,8 +107,10 @@ using SmallVectorSizeType = /// Figure out the offset of the first element. 
template struct SmallVectorAlignmentAndSize { + // NOLINTNEXTLINE(*c-arrays*) alignas(SmallVectorBase>) char Base[sizeof( SmallVectorBase>)]; + // NOLINTNEXTLINE(*c-arrays*) alignas(T) char FirstEl[sizeof(T)]; }; @@ -246,7 +243,7 @@ class SmallVectorTemplateCommon bool ReferencesStorage = false; int64_t Index = -1; - if (!U::TakesParamByValue) { + if constexpr (!U::TakesParamByValue) { if (C10_UNLIKELY(This->isReferenceToStorage(&Elt))) { ReferencesStorage = true; Index = &Elt - This->begin(); @@ -306,7 +303,7 @@ class SmallVectorTemplateCommon size_type size_in_bytes() const { return size() * sizeof(T); } - size_type max_size() const { + constexpr size_type max_size() const { return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T)); } @@ -475,6 +472,7 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon { this->set_size(this->size() + 1); } + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) void push_back(T&& Elt) { T* EltPtr = reserveForParamAndGetAddress(Elt); ::new ((void*)this->end()) T(::std::move(*EltPtr)); @@ -788,13 +786,9 @@ class SmallVectorImpl : public SmallVectorTemplateBase { assign(RHS.begin(), RHS.end()); } - iterator erase(const_iterator CI) { - // Just cast away constness because this is a non-const member function. - iterator I = const_cast(CI); - + iterator erase(iterator I) { assert( - this->isReferenceToStorage(CI) && - "Iterator to erase is out of bounds."); + this->isReferenceToStorage(I) && "Iterator to erase is out of bounds."); iterator N = I; // Shift all elts down one. @@ -804,11 +798,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { return (N); } - iterator erase(const_iterator CS, const_iterator CE) { - // Just cast away constness because this is a non-const member function. - iterator S = const_cast(CS); - iterator E = const_cast(CE); - + iterator erase(iterator S, iterator E) { assert(this->isRangeInStorage(S, E) && "Range to erase is out of bounds."); iterator N = S; @@ -1402,6 +1392,7 @@ class /* LLVM_GSL_OWNER */ SmallVector : public SmallVectorImpl, .end())>::iterator_category, std::input_iterator_tag>, int> = 0> + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) SmallVector& operator=(Container&& C) { this->assign(C.begin(), C.end()); return *this; @@ -1439,6 +1430,7 @@ using ValueTypeFromRangeType = std::remove_const_t< /// SmallVector with elements of the vector. This is useful, for example, /// when you want to iterate a range and then sort the results. 
template +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) SmallVector, Size> to_vector(R&& Range) { return {std::begin(Range), std::end(Range)}; } @@ -1447,6 +1439,7 @@ SmallVector< ValueTypeFromRangeType, CalculateSmallVectorDefaultInlinedElements< ValueTypeFromRangeType>::value> +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) to_vector(R&& Range) { return {std::begin(Range), std::end(Range)}; } @@ -1472,5 +1465,3 @@ inline void swap( } } // end namespace std - -C10_CLANG_DIAGNOSTIC_POP() diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index 084c59c7d161..1f5254a3deda 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -41,10 +41,15 @@ std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString); #ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +// TODO (huydhn) https://en.cppreference.com/w/cpp/header/codecvt has been +// deprecated in C++17 but there is no alternative yet, so I just ack it std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString) { std::wstring_convert> converter; return _str(ss, converter.to_bytes(wString)); } +#pragma GCC diagnostic pop #else // #ifndef _WIN32 // The WIN32 implementation of wstring_convert leaks memory; see diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index 157a4f4be28d..88a91c84ef0f 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -142,7 +142,7 @@ struct C10_API SourceLocation { std::ostream& operator<<(std::ostream& out, const SourceLocation& loc); // unix isprint but insensitive to locale -inline static bool isPrint(char s) { +inline bool isPrint(char s) { return s > 0x1f && s < 0x7f; } diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 7eb6d61c122e..2853ff48d183 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -16,7 +16,7 @@ namespace c10 { /// Returns false since we cannot have x < 0 if x is unsigned. template -static inline constexpr bool is_negative( +inline constexpr bool is_negative( const T& /*x*/, std::true_type /*is_unsigned*/) { return false; @@ -24,9 +24,7 @@ static inline constexpr bool is_negative( /// Returns true if a signed variable x < 0 template -static inline constexpr bool is_negative( - const T& x, - std::false_type /*is_unsigned*/) { +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { return x < T(0); } @@ -42,15 +40,13 @@ inline constexpr bool is_negative(const T& x) { /// Returns the sign of an unsigned variable x as 0, 1 template -static inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { +inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { return T(0) < x; } /// Returns the sign of a signed variable x as -1, 0, 1 template -static inline constexpr int signum( - const T& x, - std::false_type /*is_unsigned*/) { +inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { return (T(0) < x) - (x < T(0)); } @@ -92,7 +88,7 @@ inline constexpr bool greater_than_max(const T& x) { /// Returns true if x < lowest(Limit). 
Standard comparison template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& x, std::false_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { @@ -102,7 +98,7 @@ static inline constexpr bool less_than_lowest( /// Returns false since all the limit is signed and therefore includes /// negative values but x cannot be negative because it is unsigned template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& /*x*/, std::false_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { @@ -112,7 +108,7 @@ static inline constexpr bool less_than_lowest( /// Returns true if x < 0, where 0 is constructed from T. /// Limit is not signed, so its lower value is zero template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& x, std::true_type /*limit_is_unsigned*/, std::false_type /*x_is_unsigned*/) { @@ -121,7 +117,7 @@ static inline constexpr bool less_than_lowest( /// Returns false sign both types are unsigned template -static inline constexpr bool less_than_lowest( +inline constexpr bool less_than_lowest( const T& /*x*/, std::true_type /*limit_is_unsigned*/, std::true_type /*x_is_unsigned*/) { diff --git a/c10/util/int128.h b/c10/util/int128.h index b97a59446da2..7da595b79178 100644 --- a/c10/util/int128.h +++ b/c10/util/int128.h @@ -49,7 +49,7 @@ struct uint128_pod; #endif class uint128; -static inline uint128& operator<<=(uint128& self, int amount); +inline uint128& operator<<=(uint128& self, int amount); // An unsigned 128-bit integer type. Thread-compatible. class C10_API uint128 { @@ -277,7 +277,7 @@ inline uint128 operator>>(const uint128& val, int amount) { } } -static inline uint128& operator<<=(uint128& self, int amount) { +inline uint128& operator<<=(uint128& self, int amount) { // uint64_t shifts of >= 64 are undefined, so we will need some // special-casing. if (amount < 64) { diff --git a/c10/util/strides.h b/c10/util/strides.h index 980540b5b97a..d3d38fd7d011 100644 --- a/c10/util/strides.h +++ b/c10/util/strides.h @@ -6,7 +6,7 @@ namespace c10 { // Computes the contiguous strides of a tensor, given its sizes. -static inline DimVector contiguous_strides(const IntArrayRef sizes) { +inline DimVector contiguous_strides(const IntArrayRef sizes) { using Int = IntArrayRef::value_type; const Int dims = static_cast(sizes.size()); diff --git a/c10/xpu/CMakeLists.txt b/c10/xpu/CMakeLists.txt index d06d0f0aa92a..b5c63d4f7cca 100644 --- a/c10/xpu/CMakeLists.txt +++ b/c10/xpu/CMakeLists.txt @@ -8,6 +8,12 @@ if(NOT BUILD_LIBTORCHLESS) find_library(C10_XPU_LIB c10_xpu PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH) endif() +# ---[ Configure macro file. 
+set(C10_XPU_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in xpu_cmake_macros.h.in +configure_file( + ${CMAKE_CURRENT_LIST_DIR}/impl/xpu_cmake_macros.h.in + ${CMAKE_BINARY_DIR}/c10/xpu/impl/xpu_cmake_macros.h) + set(C10_XPU_SRCS XPUCachingAllocator.cpp XPUFunctions.cpp @@ -50,3 +56,7 @@ foreach(file ${C10_XPU_HEADERS}) get_filename_component(dir ${file} DIRECTORY) install(FILES ${file} DESTINATION include/c10/xpu/${dir}) endforeach() + +if(MSVC AND C10_XPU_BUILD_SHARED_LIBS) + install(FILES $ DESTINATION lib OPTIONAL) +endif() diff --git a/c10/xpu/XPUFunctions.cpp b/c10/xpu/XPUFunctions.cpp index 15e24d94f5dc..cc885776a916 100644 --- a/c10/xpu/XPUFunctions.cpp +++ b/c10/xpu/XPUFunctions.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include namespace c10::xpu { @@ -53,10 +51,20 @@ inline void initGlobalDevicePoolState() { return; } +#ifdef _WIN32 + // default context feature is disabled by default on Windows. + std::vector deviceList; + for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end(); + ++it) { + deviceList.push_back(*(*it)); + } + gDevicePool.context = std::make_unique(deviceList); +#else // The default context is utilized for each Intel GPU device, allowing the // retrieval of the context from any GPU device. gDevicePool.context = std::make_unique( gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context()); +#endif } inline void initDevicePoolCallOnce() { diff --git a/c10/xpu/XPUMacros.h b/c10/xpu/XPUMacros.h index fc6aad92229c..d51eab989d25 100644 --- a/c10/xpu/XPUMacros.h +++ b/c10/xpu/XPUMacros.h @@ -1,15 +1,29 @@ #pragma once +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif + // See c10/macros/Export.h for a detailed explanation of what the function // of these macros are. We need one set of macros for every separate library // we build. +#ifdef _WIN32 +#if defined(C10_XPU_BUILD_SHARED_LIBS) +#define C10_XPU_EXPORT __declspec(dllexport) +#define C10_XPU_IMPORT __declspec(dllimport) +#else +#define C10_XPU_EXPORT +#define C10_XPU_IMPORT +#endif +#else // _WIN32 #if defined(__GNUC__) #define C10_XPU_EXPORT __attribute__((__visibility__("default"))) #else // defined(__GNUC__) #define C10_XPU_EXPORT #endif // defined(__GNUC__) #define C10_XPU_IMPORT C10_XPU_EXPORT +#endif // _WIN32 // This one is being used by libc10_xpu.so #ifdef C10_XPU_BUILD_MAIN_LIB diff --git a/c10/xpu/impl/xpu_cmake_macros.h.in b/c10/xpu/impl/xpu_cmake_macros.h.in new file mode 100644 index 000000000000..48ed78c07e1d --- /dev/null +++ b/c10/xpu/impl/xpu_cmake_macros.h.in @@ -0,0 +1,6 @@ +#pragma once + +// Automatically generated header file for the C10 XPU library. Do not +// include this file directly. 
Instead, include c10/xpu/XPUMacros.h + +#cmakedefine C10_XPU_BUILD_SHARED_LIBS diff --git a/c2_defs.bzl b/c2_defs.bzl deleted file mode 100644 index 3cca448b394c..000000000000 --- a/c2_defs.bzl +++ /dev/null @@ -1,509 +0,0 @@ -load("@bazel_skylib//lib:collections.bzl", "collections") -load("@bazel_skylib//lib:paths.bzl", "paths") -load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule") -load("@fbsource//tools/build_defs:default_platform_defs.bzl", "compose_platform_setting_list") -load("@fbsource//tools/build_defs:dict_defs.bzl", "dict_defs") -load("@fbsource//tools/build_defs:expect.bzl", "expect") -load("@fbsource//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") -load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode", "is_fbcode_mode_mac") -load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") -load("@fbsource//tools/build_defs/apple:build_mode_defs.bzl", "is_production_build") -load("@fbsource//xplat/caffe2:buckbuild.bzl", "read_bool") -load("@fbsource//xplat/pfh/Msgr/Mobile/ProductInfra:DEFS.bzl", "Msgr_Mobile_ProductInfra") - -def get_c2_expose_op_to_c10(): - c2_op_to_c10 = native.read_config("caffe2", "expose_op_to_c10", "0") - - expect( - c2_op_to_c10 in ("0", "1"), - c2_op_to_c10, - ) - - return bool(int(c2_op_to_c10)) - -def get_c2_mpscnn(): - c2_mpscnn = native.read_config("caffe2", "enable_mpscnn", "1") - - expect( - c2_mpscnn in ("0", "1"), - c2_mpscnn, - ) - - return bool(int(c2_mpscnn)) - -def get_c2_mpscnn_test(): - c2_mpscnn_test = native.read_config("caffe2", "enable_mpscnn_test", "0") - - expect( - c2_mpscnn_test in ("0", "1"), - c2_mpscnn_test, - ) - - return bool(int(c2_mpscnn_test)) - -def get_c2_qpl(): - c2_qpl = native.read_config("caffe2", "enable_qpl", "1") - - expect( - c2_qpl in ("0", "1"), - c2_qpl, - ) - - return bool(int(c2_qpl)) - -def get_c2_strip_debug_info(): - c2_strip_debug_info = native.read_config("caffe2", "strip_debug_info", "0") - - expect( - c2_strip_debug_info in ("0", "1"), - c2_strip_debug_info, - ) - - return bool(int(c2_strip_debug_info)) - -def get_c2_strip_glog(): - c2_strip_glog = native.read_config("caffe2", "strip_glog", "1") - - expect( - c2_strip_glog in ("0", "1"), - c2_strip_glog, - ) - - return bool(int(c2_strip_glog)) - -def get_c2_tvm(): - c2_tvm = native.read_config("caffe2", "enable_tvm", "1") - - expect( - c2_tvm in ("0", "1"), - c2_tvm, - ) - - return bool(int(c2_tvm)) - -_C2_XPLAT_NO_HPTT_PREPROCESSOR_FLAGS = [ - "-Icaffe2", - "-Imodules", - "-DEIGEN_NO_DEBUG", - "-DCAFFE2_USE_LITE_PROTO", - "-DCAFFE2_USE_GOOGLE_GLOG", - "-DCAFFE2_RNN_NO_TEXT_FORMAT", - "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK=1", - "-DCAFFE2_IS_XPLAT_BUILD", - "-DSTRIP_ERROR_MESSAGES", - "-DUSE_INTERNAL_PTHREADPOOL_IMPL", -] - -def get_c2_xplat_no_hptt_preprocessor_flags(): - flags = [] - flags += _C2_XPLAT_NO_HPTT_PREPROCESSOR_FLAGS - if is_arvr_mode() and get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=1"] - if get_c2_expose_op_to_c10(): - flags += ["-DEXPOSE_C2_OPS", "-frtti"] - return flags - -C2_XPLAT_SERVER_PREPROCESSOR_FLAGS = [ - "-DCAFFE2_USE_EIGEN_FOR_BLAS", - "-DC10_DISABLE_SIGNAL_HANDLERS", - "-DCAFFE2_DISABLE_NUMA", -] - -C2_XPLAT_HPTT_PREPROCESSOR_FLAGS = [ - "-DCAFFE2_USE_HPTT", -] - -def get_c2_xplat_preprocessor_flags(): - flags = get_c2_xplat_no_hptt_preprocessor_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS - return flags - -def get_c2_xplat_no_hptt_compiler_flags(): - return [ - "-Os", - "-fexceptions", - "-frtti", - 
"-Wno-shadow", - "-Wno-unknown-pragmas", - "-Wno-unused-variable", - "-Wno-sign-compare", - ] - -def get_c2_xplat_compiler_flags(): - return get_c2_xplat_no_hptt_compiler_flags() + C2_XPLAT_HPTT_PREPROCESSOR_FLAGS - -def get_c2_fbobjc_xplat_compiler_flags(): - flags = [] - - if is_production_build(): - flags.append("-DCAFFE2_NO_OPERATOR_SCHEMA") - - flags.append("-DCAFFE2_NO_GRADIENT_OPS") - - # For iOS production builds (and all Android builds), strip GLOG logging to - # save size. We can disable by setting caffe2.strip_glog=0 in .buckconfig.local. - if is_production_build() or get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=3"] - else: - flags.append("-UGOOGLE_STRIP_LOG") - - return flags - -def get_c2_fbandroid_xplat_compiler_flags(): - flags = [ - "-Wno-unused-but-set-variable", - "-DHAVE_MMAP", - ] - - if get_c2_strip_glog(): - flags += ["-UGOOGLE_STRIP_LOG", "-DGOOGLE_STRIP_LOG=1"] - - if get_c2_strip_debug_info(): - flags.append("-g0") - - return flags - -_C2_FBOBJC_COMPILER_FLAGS = [ - "-Wno-missing-prototypes", - "-Wno-global-constructors", - "-Wno-unknown-pragmas", - "-Wno-invalid-partial-specialization", - "-Wno-missing-braces", - "-Wno-range-loop-analysis", -] - -def get_c2_fbobjc_compiler_flags(): - flags = list(_C2_FBOBJC_COMPILER_FLAGS) - - # Avoid linking Accelerate on MacOS because we have - # inconsistent LAPACK headers (see problems in D19257077). - flags.append("-DCAFFE2_USE_ACCELERATE" if not is_arvr_mode() else "-DCAFFE2_USE_EIGEN_FOR_BLAS") - if get_c2_mpscnn(): - flags.append( - # TODO(t19120552) - fix this. MPSCNNConvolutionDescriptor.strideInPixelsX - # is marked as iOS 11+, but it's been available since iOS 10. - "-Wno-unguarded-availability", - ) - return flags - -C2_FBOBJC_MACOSX_COMPILER_FLAGS = [ - "-msse4.2", -] - -C2_FBOBJC_IPHONE_COMPILER_FLAGS = [ - "-mfpu=neon-fp16", -] - -def get_c2_fbobjc_frameworks(): - frameworks = [] - if not is_arvr_mode(): - frameworks.append( - # On iOS, presumably Accelerate is a faster BLAS - "$SDKROOT/System/Library/Frameworks/Accelerate.framework", - ) - return frameworks - -def get_c2_fbobjc_ios_frameworks(): - frameworks = [] - - if get_c2_mpscnn(): - frameworks.extend([ - "$SDKROOT/System/Library/Frameworks/Metal.framework", - "$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", - ]) - - return frameworks - -def get_c2_fbobjc_exported_preprocessor_flags(): - flags = [] - - if get_c2_mpscnn(): - flags.append("-DCAFFE2_USE_MPSCNN") - - if get_c2_mpscnn_test(): - flags.append("-DCAFFE2_USE_MPSCNN_TEST") - - return flags - -def get_c2_fbandroid_exported_preprocessor_flags(): - flags = [] - - BUILD_MODE_DO_NOT_USE_WITHOUT_ASKING_SERIOUSLY = native.read_config( - "fbandroid", - "build_mode", - "dev", - ) - if BUILD_MODE_DO_NOT_USE_WITHOUT_ASKING_SERIOUSLY == "opt": - flags.append("-DCAFFE2_NO_OPERATOR_SCHEMA") - - flags.append("-DCAFFE2_NO_GRADIENT_OPS") - - return flags - -C2_FBANDROID_COMPILER_FLAGS = [ - "-DCAFFE2_USE_EIGEN_FOR_BLAS", - "-Wno-unknown-pragmas", - "-Wno-deprecated-declarations", - "-Wno-invalid-partial-specialization", - "-Wno-missing-braces", -] - -C2_FBANDROID_ARMV7_COMPILER_FLAGS = [ - "-mfpu=neon-fp16", -] - -C2_FBANDROID_X86_COMPILER_FLAGS = [ - "-mssse3", -] - -C2_FBANDROID_LINKER_FLAGS = [] - -C2_FBOBJC_EXTRA_TARGET_CONFIG = { - "MTL_LANGUAGE_REVISION": "Metal12", -} - -def get_c2_torch_vulkan_compiler_flags(): - return ["-Wno-missing-prototypes"] - -def get_c2_default_cxx_args(): - return dict( - header_namespace = "", - apple_sdks = (IOS, MACOSX), - 
compiler_flags = get_c2_xplat_compiler_flags(), - fbandroid_compiler_flags = C2_FBANDROID_COMPILER_FLAGS + get_c2_fbandroid_xplat_compiler_flags(), - fbandroid_exported_platform_preprocessor_flags = [ - ( - "android-armv7", - get_c2_fbandroid_exported_preprocessor_flags(), - ), - ], - fbandroid_linker_flags = C2_FBANDROID_LINKER_FLAGS, - fbandroid_platform_compiler_flags = [ - ("android-armv7", C2_FBANDROID_ARMV7_COMPILER_FLAGS), - (".*x86.*", C2_FBANDROID_X86_COMPILER_FLAGS), - ], - fbobjc_compiler_flags = get_c2_fbobjc_compiler_flags() + get_c2_fbobjc_xplat_compiler_flags(), - fbobjc_exported_platform_preprocessor_flags = [ - ( - "iphoneos", - get_c2_fbobjc_exported_preprocessor_flags(), - ), - ], - fbobjc_frameworks = get_c2_fbobjc_frameworks() + get_c2_fbobjc_ios_frameworks(), - fbobjc_platform_compiler_flags = [ - ("iphoneos", C2_FBOBJC_IPHONE_COMPILER_FLAGS), - ], - macosx_compiler_flags = C2_FBOBJC_MACOSX_COMPILER_FLAGS, - macosx_frameworks_override = get_c2_fbobjc_frameworks(), - preprocessor_flags = [ - # Use the internal pthreadpool impl for all Caffe2 targets on all - # platforms but do not export the preprocessor flag downstream. - "-DUSE_INTERNAL_PTHREADPOOL_IMPL", - ], - visibility = ["PUBLIC"], - windows_preferred_linkage = "static" if is_arvr_mode() else None, - ) - -def get_c2_aten_cpu_fbobjc_macosx_deps(): - return select({ - "DEFAULT": [], - "ovr_config//os:macos-x86_64": ["fbsource//xplat/deeplearning/fbgemm:fbgemm"], - }) if is_arvr_mode() else [] - -def build_cpukernel_avx2(): - return read_bool("caffe2", "build_cpukernel_avx2", not is_arvr_mode()) - -def get_c2_aten_cpu_fbobjc_macosx_platform_deps(): - return compose_platform_setting_list([ - { - "cpu": "x86_64", - "flags": [ - "fbsource//xplat/deeplearning/fbgemm:fbgemmAppleMac", - ] + ([ - "fbsource//xplat/caffe2:cpukernel_avx2AppleMac", - ] if build_cpukernel_avx2() else []), - "os": "macosx", - }, - { - "cpu": "arm64", - "flags": ["fbsource//xplat/third-party/XNNPACK:XNNPACKAppleMac"], - "os": "macosx", - }, - ]) - -def using_protobuf_v3(): - # Consider migrating this to `read_config("protobuf", "use_v3")` - # The `is_fbcode_mode_mac()` clause was added rather than changing to `read_config` to minimize changes in behavior - return is_arvr_mode() or is_fbcode_mode_mac() - -def get_c2_protobuf_dep(): - return "fbsource//third-party/protobuf:libprotobuf" if using_protobuf_v3() else "fbsource//xplat/third-party/protobuf:fb-protobuf-lite" - -def c2_cxx_library(fbobjc_compiler_flags = [], **kwargs): - args = get_c2_default_cxx_args() - args.update(kwargs) - args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS)) - - # Make sure we don't overwrite custom `fbobjc_compiler_flags` - args["fbobjc_compiler_flags"] = args.pop("fbobjc_compiler_flags", []) + fbobjc_compiler_flags - - fb_xplat_cxx_library( - labels = [ - "supermodule:android/default/caffe2", - "supermodule:ios/default/public.caffe2", - ], - feature = Msgr_Mobile_ProductInfra, - **args - ) - -def c2_protobuf_rule(protos): - cpps = [] - headers = {} - raw_headers = {} - for p in protos: - proto = paths.basename(p) - protocexe = "$(exe fbsource//third-party/protobuf:protoc-host)" if is_arvr_mode() else "$(location fbsource//xplat/third-party/protobuf:protoc.Windows)" - protocmd_exe = "powershell.exe -file $(location fbsource//xplat/caffe2/scripts:proto)\\proto.ps1 -Protoc {} -Unprocessed $SRCDIR/{} -Processed $SRCDIR/{} -out $OUT -srcdir $SRCDIR".format(protocexe, p, proto) - protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && chmod +w $SRCDIR/{} && echo \"option 
optimize_for = LITE_RUNTIME;\" >> $SRCDIR/{} && ".format(p, proto, proto, proto) + - "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && chmod +w $SRCDIR/caffe2.proto && echo \"option optimize_for = LITE_RUNTIME;\" >> $SRCDIR/caffe2.proto && " + - "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + - "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) - buck_genrule( - name = proto, - srcs = sorted(collections.uniq([p, "caffe2/proto/caffe2.proto"])), - cmd_exe = protocmd_exe, - bash = protocmd, - out = ".", - ) - (name, _) = paths.split_extension(proto) - cpp = name + ".pb.cc" - h = name + ".pb.h" - buck_genrule( - name = h, - cmd_exe = "@powershell -Command \" & { " + "(Get-Content $(location :{})\\{}".format(proto, h) + ") -replace \\\"caffe2.pb.h\\\", \\\"caffe2/proto/caffe2.pb.h\\\" | Set-Content $OUT } \"", - bash = "cp -f $(location :{})/{} $OUT && ".format(proto, h) + - "sed -i -e 's/caffe2.pb.h/caffe2\\/proto\\/caffe2.pb.h/g' $OUT", - out = h, - ) - headers["caffe2/proto/" + h] = ":{}".format(h) - raw_headers[h] = ":{}".format(h) - buck_genrule( - name = cpp, - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(proto, cpp), - bash = "cp -f $(location :{})/{} $OUT".format(proto, cpp), - out = cpp, - ) - cpps.append(":{}".format(cpp)) - return (cpps, headers, raw_headers) - -# C2 uses lite version of protobuf while torch/jit uses some method only exists -# in full protobuf. This is a temporary workaround to enable experiment build. -# DO NOT USE IT IN PRODUCTION BUILD! 
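# A note on the distinction (inferred from the genrule commands above and below):
# the lite rule appends "option optimize_for = LITE_RUNTIME;" to each .proto
# before invoking protoc, whereas c2_full_protobuf_rule below leaves the protos
# at their default, so its generated C++ is built against the full protobuf
# runtime (reflection and the complete Message API) that torch/jit relies on.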
-def c2_full_protobuf_rule(protos): - prefix = "full_" - cpps = [] - headers = {} - raw_headers = {} - for p in protos: - proto = paths.basename(p) - protocexe = "$(exe fbsource//third-party/protobuf:protoc-host)" if is_arvr_mode() else "$(location fbsource//xplat/third-party/protobuf:protoc.Windows)" - protocmd_exe = "powershell.exe -file $(location fbsource//xplat/caffe2/scripts:proto)\\proto.ps1 -Protoc {} -Unprocessed $SRCDIR/{} -Processed $SRCDIR/{} -out $OUT -srcdir $SRCDIR".format(protocexe, p, proto) - protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && ".format(p, proto) + - "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && " + - "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + - "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) - buck_genrule( - name = prefix + proto, - srcs = sorted(collections.uniq([p, "caffe2/proto/caffe2.proto"])), - cmd = protocmd, - cmd_exe = protocmd_exe, - out = ".", - ) - (name, _) = paths.split_extension(proto) - cpp = name + ".pb.cc" - h = name + ".pb.h" - buck_genrule( - name = prefix + h, - cmd_exe = "@powershell -Command \" & { " + "(Get-Content $(location :{})\\{}".format(prefix + proto, h) + ") -replace \\\"caffe2.pb.h\\\", \\\"caffe2/proto/caffe2.pb.h\\\" | Set-Content $OUT } \"", - bash = "cp -f $(location :{})/{} $OUT && ".format(prefix + proto, h) + - "sed -i -e 's/caffe2.pb.h/caffe2\\/proto\\/caffe2.pb.h/g' $OUT", - out = h, - ) - headers["caffe2/proto/" + h] = ":{}".format(prefix + h) - raw_headers[h] = ":{}".format(prefix + h) - buck_genrule( - name = prefix + cpp, - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(prefix + proto, cpp), - bash = "cp -f $(location :{})/{} $OUT".format(prefix + proto, cpp), - out = cpp, - ) - cpps.append(":{}".format(prefix + cpp)) - return (cpps, headers, raw_headers) - -def libcaffe2_cxx_library(name, use_hptt, **kwargs): - c2_cxx_library( - name = name, - exported_deps = [ - "fbsource//xplat/caffe2/c10:c10", - get_c2_protobuf_dep(), - ":caffe2_protobuf_headers", - ":pthreadpool", - ":common_core", - ":caffe2_proto_types", - ], - compiler_flags = get_c2_xplat_compiler_flags() if use_hptt else get_c2_xplat_no_hptt_compiler_flags(), - exported_preprocessor_flags = get_c2_xplat_preprocessor_flags() if use_hptt else get_c2_xplat_no_hptt_preprocessor_flags(), - cxx_preprocessor_flags = C2_XPLAT_SERVER_PREPROCESSOR_FLAGS, - fbandroid_exported_preprocessor_flags = get_c2_fbandroid_xplat_compiler_flags(), - fbobjc_exported_preprocessor_flags = get_c2_fbobjc_xplat_compiler_flags(), - # Hack to work around lack of platform_srcs support in Xcode project generation. - macosx_extra_xcode_sources_override = [], - link_whole = True, - **kwargs - ) - -def c2_operator_library(name, **kwargs): - dict_defs.key_extend( - kwargs, - "deps", - [ - "fbsource//xplat/folly:molly", - "fbsource//third-party/glog:glog", - ":caffe2", - ] + ([":aten_cpu"] if get_c2_expose_op_to_c10() else []), - ) - - # NOTE: Currently operators can "depend" on other operators, which is used - # so that loading one will implicitly load the dependencies. So, make sure - # that no `--as-needed` flags pulled in from dependencies cause these - # operator deps to get dropped. 
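# The "-Wl,--no-as-needed" flag selected below (used unless caffe2.link_as_needed
# is set to "1") keeps each listed operator library in DT_NEEDED even when no
# symbol is referenced from it directly, so loading one operator library still
# pulls in the operator libraries it depends on.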
- linker_flags = [] if (read_config("caffe2", "link_as_needed", "0") == "1") else ["-Wl,--no-as-needed"] - c2_cxx_library( - name = name, - soname = "lib" + name + ".$(ext)", - fbandroid_compiler_flags = get_c2_default_cxx_args()["fbandroid_compiler_flags"] + ["-Os"], - fbobjc_compiler_flags = get_c2_default_cxx_args()["fbobjc_compiler_flags"] + ["-Oz", "-DCOMPILING_FOR_MIN_SIZE=1"], - link_whole = True, - cxx_exported_linker_flags = linker_flags, - fbandroid_exported_linker_flags = linker_flags, - exported_deps = [ - ":caffe2", - ], - **kwargs - ) - -def c2_genrule(genrule, genfiles, prefix = "", src_path = "", header_namespace = ""): - headers = {} - srcs = [] - for generated_filename in genfiles: - buck_genrule( - name = prefix + generated_filename, - bash = "cp -f $(location :{})/{} $OUT".format(genrule, src_path + generated_filename), - cmd_exe = "@powershell -Command copy $(location :{})/{} $OUT".format(genrule, src_path + generated_filename), - out = generated_filename, - ) - rule = ":{}{}".format(prefix, generated_filename) - headers[header_namespace + generated_filename] = rule - srcs.append(rule) - return {"headers": headers, "srcs": srcs} diff --git a/c2_test_defs.bzl b/c2_test_defs.bzl deleted file mode 100644 index 8ef83073d6fa..000000000000 --- a/c2_test_defs.bzl +++ /dev/null @@ -1,20 +0,0 @@ -load("@fbsource//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") -load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX") -load("@fbsource//xplat/caffe2:c2_defs.bzl", "get_c2_default_cxx_args") - -def c2_cxx_test(**kwargs): - args = get_c2_default_cxx_args() - args.update(kwargs) - args["fbandroid_use_instrumentation_test"] = True - for flag in [ - "macosx_compiler_flags", - "fbobjc_macosx_configs_override", - "macosx_frameworks_override", - "xcode_public_headers_symlinks", - "macosx_inherited_buck_flags_override", - ]: - args.pop(flag, None) - args["apple_sdks"] = (IOS, MACOSX) - args["platforms"] = (CXX, APPLE, ANDROID) - args["contacts"] = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"] - fb_xplat_cxx_test(**args) diff --git a/caffe2/BUILD_MODE.bzl b/caffe2/BUILD_MODE.bzl deleted file mode 100644 index 1fbd3e6f7a47..000000000000 --- a/caffe2/BUILD_MODE.bzl +++ /dev/null @@ -1,23 +0,0 @@ -""" build mode definitions for caffe2/caffe2 """ - -load("@fbcode//:BUILD_MODE.bzl", get_parent_modes = "all_modes_keep_gpu_sections_all_modes_use_lld") -load("@fbcode_macros//build_defs:create_build_mode.bzl", "extend_build_mode") - -def update_mode_struct(name, mode_struct): - if name == "dev": - return extend_build_mode( - mode_struct, - # TODO(ipbrady): Modules introduce floating point inaccuracies (T43879333) - cxx_modules = False, - ) - else: - return mode_struct - -_modes = { - mode_name: update_mode_struct(mode_name, mode_struct) - for mode_name, mode_struct in get_parent_modes().items() -} - -def get_modes(): - """ Return modes for this file """ - return _modes diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 1e29044e19fd..89c31fab1134 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -104,10 +104,6 @@ if(NOT USE_FBGEMM) add_subdirectory(perfkernels) endif() -if(NOT INTERN_BUILD_MOBILE) - add_subdirectory(proto) -endif() - # Advanced: if we have allow list specified, we will do intersections for all # main lib srcs. 
if(CAFFE2_ALLOWLISTED_FILES) @@ -218,44 +214,6 @@ if(PRINT_CMAKE_DEBUG_INFO) endif() -if(NOT INTERN_BUILD_MOBILE) - # ---[ List of libraries to link with - add_library(caffe2_protos STATIC $) - add_dependencies(caffe2_protos Caffe2_PROTO) - # If we are going to link protobuf locally inside caffe2 libraries, what we will do is - # to create a helper static library that always contains libprotobuf source files, and - # link the caffe2 related dependent libraries to it. - target_include_directories(caffe2_protos INTERFACE $) - # Reason for this public dependency is as follows: - # (1) Strictly speaking, we should not expose any Protobuf related functions. We should - # only use function interfaces wrapped with our own public API, and link protobuf - # locally. - # (2) However, currently across the Caffe2 codebase, we have extensive use of protobuf - # functionalities. For example, not only libcaffe2.so uses it, but also other - # binaries such as python extensions etc. As a result, we will have to have a - # transitive dependency to libprotobuf. - # - # Good thing is that, if we specify CAFFE2_LINK_LOCAL_PROTOBUF, then we do not need to - # separately deploy protobuf binaries - libcaffe2.so will contain all functionalities - # one needs. One can verify this via ldd. - # - # TODO item in the future includes: - # (1) Enable using lite protobuf - # (2) Properly define public API that do not directly depend on protobuf itself. - # (3) Expose the libprotobuf.a file for dependent libraries to link to. - # - # What it means for users/developers? - # (1) Users: nothing affecting the users, other than the fact that CAFFE2_LINK_LOCAL_PROTOBUF - # avoids the need to deploy protobuf. - # (2) Developers: if one simply uses core caffe2 functionality without using protobuf, - # nothing changes. If one has a dependent library that uses protobuf, then one needs to - # have the right protobuf version as well as linking to libprotobuf.a. 
- target_link_libraries(caffe2_protos PUBLIC protobuf::libprotobuf) - if(NOT BUILD_SHARED_LIBS) - install(TARGETS caffe2_protos ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") - endif() -endif() - # ========================================================== # formerly-libtorch # ========================================================== @@ -388,7 +346,7 @@ add_custom_command( OUTPUT ${TORCH_GENERATED_CODE} COMMAND - "${Python_EXECUTABLE}" tools/setup_helpers/generate_code.py + Python::Interpreter tools/setup_helpers/generate_code.py --native-functions-path "aten/src/ATen/native/native_functions.yaml" --tags-path "aten/src/ATen/native/tags.yaml" $<$:--disable-autograd> @@ -814,7 +772,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_PYTORCH_METAL AND NOT USE_COREML_DELEGATE) +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) @@ -939,7 +897,6 @@ if(USE_ROCM) hip_add_library(torch_hip ${Caffe2_HIP_SRCS}) if(USE_FLASH_ATTENTION) target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) - add_dependencies(torch_hip aotriton_external) endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake @@ -1105,8 +1062,15 @@ if(USE_XPU) message(WARNING "Failed to include ATen XPU implementation target") else() target_link_libraries(torch_xpu PRIVATE torch_xpu_ops) - target_link_libraries(torch_xpu PRIVATE - "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + if(MSVC) + # Windows + target_link_libraries(torch_xpu PRIVATE + "-WHOLEARCHIVE:\"$\"") + else() + # Linux + target_link_libraries(torch_xpu PRIVATE + "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + endif() endif() endif() @@ -1118,13 +1082,6 @@ endif() # formerly-libtorch flags # ========================================================== -if(NOT INTERN_BUILD_MOBILE) - # Forces caffe2.pb.h to be generated before its dependents are compiled. - # Adding the generated header file to the ${TORCH_SRCS} list is not sufficient - # to establish the dependency, since the generation procedure is declared in a different CMake file. 
- # See https://samthursfield.wordpress.com/2015/11/21/cmake-dependencies-between-targets-and-files-and-custom-commands/#custom-commands-in-different-directories - add_dependencies(torch_cpu Caffe2_PROTO) -endif() # Build model tracer for tracing-based selective build if(TRACING_BASED AND NOT BUILD_LITE_INTERPRETER AND NOT INTERN_BUILD_MOBILE) @@ -1144,7 +1101,7 @@ if(BUILD_LITE_INTERPRETER AND SELECTED_OP_LIST) add_custom_command( OUTPUT ${CMAKE_BINARY_DIR}/aten/src/ATen/selected_mobile_ops.h COMMAND - "${Python_EXECUTABLE}" + Python::Interpreter -m tools.code_analyzer.gen_oplist --model_file_list_path "${SELECTED_OP_LIST}" --output_dir "${CMAKE_BINARY_DIR}/aten/src/ATen" @@ -1159,7 +1116,7 @@ if(BUILD_LITE_INTERPRETER AND SELECTED_OP_LIST) add_custom_command( OUTPUT ${CMAKE_BINARY_DIR}/aten/src/ATen/selected_mobile_ops.h COMMAND - "${Python_EXECUTABLE}" + Python::Interpreter -m tools.lite_interpreter.gen_selected_mobile_ops_header --yaml_file_path "${SELECTED_OP_LIST}" --output_file_path "${CMAKE_BINARY_DIR}/aten/src/ATen" @@ -1229,6 +1186,9 @@ if(USE_KINETO) ${TORCH_ROOT}/third_party/kineto/libkineto/src) endif() +target_include_directories(torch_cpu PRIVATE + ${TORCH_ROOT}/third_party/cpp-httplib) + install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") @@ -1361,6 +1321,9 @@ if(USE_ROCM) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() + if(USE_MEM_EFF_ATTENTION) + target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) + endif() endif() if(BUILD_LITE_INTERPRETER) @@ -1414,8 +1377,6 @@ if(USE_DISTRIBUTED) endif() if(NOT INTERN_BUILD_MOBILE) - caffe2_interface_library(caffe2_protos caffe2_protos_whole) - target_link_libraries(torch_cpu PRIVATE caffe2_protos_whole) if(${CAFFE2_LINK_LOCAL_PROTOBUF}) target_link_libraries(torch_cpu INTERFACE protobuf::libprotobuf) else() @@ -1811,9 +1772,6 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE $) target_include_directories(${test_name} PRIVATE $) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) - if(NOT MSVC) - target_compile_options(${test_name} PRIVATE -Wno-unused-variable) - endif() add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) install(TARGETS ${test_name} DESTINATION test) @@ -1939,56 +1897,10 @@ endif() # only rerun when needed. if(BUILD_PYTHON) - # Python site-packages - # Get canonical directory for python site packages (relative to install - # location). It varies from system to system. - # We should pin the path separator to the forward slash on Windows. - # More details can be seen at - # https://github.com/pytorch/pytorch/tree/main/tools/build_pytorch_libs.bat#note-backslash-munging-on-windows - pycmd(PYTHON_SITE_PACKAGES " - import os - import sysconfig - relative_site_packages = sysconfig.get_path('purelib').replace(sysconfig.get_path('data'), '').lstrip(os.path.sep) - print(relative_site_packages) - ") - file(TO_CMAKE_PATH ${PYTHON_SITE_PACKAGES} PYTHON_SITE_PACKAGES) - set(PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES} PARENT_SCOPE) # for Summary # ---[ Options. 
- set(PYTHON_LIB_REL_PATH "${PYTHON_SITE_PACKAGES}" CACHE STRING "Python installation path (relative to CMake installation prefix)") + set(PYTHON_LIB_REL_PATH "${Python_SITELIB}" CACHE STRING "Python installation path (relative to CMake installation prefix)") message(STATUS "Using ${PYTHON_LIB_REL_PATH} as python relative installation path") - # Python extension suffix - # Try to get from python through sysconfig.get_env_var('EXT_SUFFIX') first, - # fallback to ".pyd" if windows and ".so" for all others. - pycmd(PY_EXT_SUFFIX " - def get_ext_suffix(): - import sys - import sysconfig - return sysconfig.get_config_var('EXT_SUFFIX') - - suffix = get_ext_suffix() - if suffix is not None: - print(suffix) - else: - print() - ") - if("${PY_EXT_SUFFIX}" STREQUAL "") - if(MSVC) - set(PY_EXT_SUFFIX ".pyd") - else() - set(PY_EXT_SUFFIX ".so") - endif() - endif() - - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - # Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80947 in EmbeddingBag.cpp - set_source_files_properties(../aten/src/ATen/native/EmbeddingBag.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes) - set_source_files_properties(${TORCH_SRC_DIR}/../caffe2/operators/box_with_nms_limit_op.cc PROPERTIES COMPILE_FLAGS -Wno-attributes) - endif() - # generated pb files are copied from build/caffe2 to caffe2 - # if we copied them back to build this would create a build cycle - # consider removing the need for globs - filter_list_exclude(PYTHON_SRCS PYTHON_SRCS "proto/.*_pb") set(build_files) foreach(python_src ${PYTHON_SRCS}) @@ -2007,10 +1919,4 @@ if(BUILD_PYTHON) # Pick up static python files install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} FILES_MATCHING PATTERN "*.py") - # Caffe proto files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") - # Caffe2 proto files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") endif() diff --git a/caffe2/README.md b/caffe2/README.md deleted file mode 100644 index 13171fca23bb..000000000000 --- a/caffe2/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Caffe2 - -Caffe2 is a lightweight, modular, and scalable deep learning framework. Building on the original [Caffe](http://caffe.berkeleyvision.org), Caffe2 is designed with expression, speed, and modularity in mind. - -## Questions and Feedback - -Please use GitHub issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. 
- -### Further Resources on [Caffe2.ai](http://caffe2.ai) - -* [Installation](http://caffe2.ai/docs/getting-started.html) -* [Learn More](http://caffe2.ai/docs/learn-more.html) -* [Upgrading to Caffe2](http://caffe2.ai/docs/caffe-migration.html) -* [Datasets](http://caffe2.ai/docs/datasets.html) -* [Model Zoo](http://caffe2.ai/docs/zoo.html) -* [Tutorials](http://caffe2.ai/docs/tutorials.html) -* [Operators Catalogue](http://caffe2.ai/docs/operators-catalogue.html) -* [C++ API](http://caffe2.ai/doxygen-c/html/classes.html) -* [Python API](http://caffe2.ai/doxygen-python/html/namespaces.html) diff --git a/caffe2/VERSION_NUMBER b/caffe2/VERSION_NUMBER deleted file mode 100644 index 100435be135a..000000000000 --- a/caffe2/VERSION_NUMBER +++ /dev/null @@ -1 +0,0 @@ -0.8.2 diff --git a/caffe2/__init__.py b/caffe2/__init__.py deleted file mode 100644 index f319e8e2dc15..000000000000 --- a/caffe2/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import warnings -from torch.onnx import _CAFFE2_ATEN_FALLBACK - -if not _CAFFE2_ATEN_FALLBACK: - warnings.warn("Caffe2 support is no longer present in PyTorch.") diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h deleted file mode 100644 index 582328092b26..000000000000 --- a/caffe2/core/blob.h +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef CAFFE2_CORE_BLOB_H_ -#define CAFFE2_CORE_BLOB_H_ - -#include -#include -#include -#include -#include -#include "caffe2/core/common.h" - -#include -#include -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/tensor_int8.h" - -namespace caffe2 { - -inline bool BlobIsInt8TensorCPUType(const Blob& blob) { - return blob.meta().Match(); -} - -inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) { - bool is_match = blob.meta().Match(); - if (!is_match) { - return false; - } - const Tensor* tensor = &blob.Get(); - return tensor && *tensor && tensor->GetDeviceType() == device_type; -} - -inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) { - return blob->Reset(new Tensor(std::move(tensor))); -} - -inline Tensor GetSizedTensorWithOptions( - Tensor&& previous_tensor, - at::IntArrayRef dims, - at::TensorOptions options) { - Tensor tensor = std::move(previous_tensor); - if (!tensor.defined()) { - return caffe2::empty(dims, options); - } - if (tensor.GetDevice() == options.device() || - (!tensor.GetDevice().has_index() && - tensor.GetDeviceType() == options.device().type())) { - if (tensor.sizes() != dims) { - // Resize when the dims doesn't match - tensor.Resize(dims); - } - if (tensor.dtype() == options.dtype()) { - tensor.raw_mutable_data(); - } else { - // create a new Tensor when the data_type doesn't match - return caffe2::empty(dims, options); - } - return tensor; - } - return caffe2::empty(dims, options); -} - -// need to keep both functions that returns Tensor* and the one -// returns Tensor for clangr codemod -inline Tensor* -BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) { - if (blob->IsType()) { - Tensor* tensor = blob->GetMutable(); - if (*tensor) { - // We only compare device_type if the index is not set since there are Tensors - // TODO: remove the extra check when all the Tensors are properly initialized - const auto tensorDevice = tensor->GetDevice(); - if (tensorDevice == options.device() || (!tensorDevice.has_index() && tensor->GetDeviceType() == options.device().type())) { - if (tensor->sizes() != dims) { - // Resize when the dims doesn't match - tensor->Resize(dims); - } - tensor->raw_mutable_data(options.dtype()); - return 
tensor; - } - // create a new Tensor when device doesn't match - } - } - - VLOG(1) << "Create new mutable object " << TypeMeta::TypeName() - << " dims: " << dims; - // << " options: " << options; (operator<< for Options is in at:: now) - return BlobSetTensor(blob, caffe2::empty(dims, options)); -} - -inline Tensor -XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) { - return BlobGetMutableTensor(blob, dims, options)->UnsafeSharedInstance(); -} - -inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) { - if (blob->IsType()) { - Tensor* tensor = blob->GetMutable(); - if (*tensor && tensor->GetDeviceType() == device_type) { - return tensor; - } - } - - // if we're here, then either Blob didn't hold a Tensor - // or that Tensor had the wrong DeviceType. - VLOG(1) << "Create new mutable object " << TypeMeta::TypeName() - << " DeviceType:" << device_type; - - return BlobSetTensor(blob, Tensor(device_type)); -} - -inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) { - if (blob.IsType()) { - const auto& tensor = blob.Get(); - if (tensor.GetDeviceType() == device_type) { - return tensor; - } - } - CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match"); -} - -inline Tensor BlobGetTensorOrUndefined(const Blob& blob) { - if (blob.IsType()) { - return blob.Get().UnsafeSharedInstance(); - } else { - return Tensor(); - } -} - -} // namespace caffe2 -#endif // CAFFE2_CORE_BLOB_H_ diff --git a/caffe2/core/blob_gpu_test.cc b/caffe2/core/blob_gpu_test.cc deleted file mode 100644 index de6ea99c0395..000000000000 --- a/caffe2/core/blob_gpu_test.cc +++ /dev/null @@ -1,227 +0,0 @@ -#include // NOLINT - -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -namespace { - -template class TensorGPUTest : public ::testing::Test {}; -template class TensorGPUDeathTest : public ::testing::Test {}; -typedef ::testing::Types TensorTypes; -TYPED_TEST_CASE(TensorGPUTest, TensorTypes); -TYPED_TEST_CASE(TensorGPUDeathTest, TensorTypes); - -TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) { - if (!caffe2::HasCudaGPU()) return; - Tensor tensor(CUDA); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_EQ(tensor.dim(), 1); - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - dims[0] = 7; - dims[1] = 11; - dims[2] = 13; - dims.push_back(17); - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 4); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 11); - EXPECT_EQ(tensor.dim32(2), 13); - EXPECT_EQ(tensor.dim32(3), 17); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorGPUTest, TensorAlias) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; 
- dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); -} - -TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - vector alternate_dims(1); - alternate_dims[0] = 2 * 3 * 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); - EXPECT_EQ(other_tensor.dim(), 1); - EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); -} - -TYPED_TEST(TensorGPUTest, NoLongerAliasAfterNumelChanges) { - if (!HasCudaGPU()) return; - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CUDA); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - dims[0] = 7; - tensor.Resize(dims); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) { - if (!HasCudaGPU()) return; - ::testing::FLAGS_gtest_death_test_style = "threadsafe"; - Tensor tensor(CUDA); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_THROW(tensor.data(), EnforceNotMet); -} - -#define TEST_SERIALIZATION_GPU_WITH_TYPE(TypeParam, field_name) \ - TEST(TensorGPUTest, TensorSerialization_##TypeParam) { \ - if (!HasCudaGPU()) { \ - return; \ - } \ - Blob blob; \ - Tensor cpu_tensor(CPU); \ - cpu_tensor.Resize(2, 3); \ - for (int i = 0; i < 6; ++i) { \ - cpu_tensor.mutable_data()[i] = static_cast(i); \ - } \ - BlobGetMutableTensor(&blob, CUDA)->CopyFrom(cpu_tensor); \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CAFFE_ENFORCE(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 6); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ(tensor_proto.field_name(i), static_cast(i)); \ - } \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CUDA)); \ - Tensor new_cpu_tensor(blob.Get(), CPU); \ - EXPECT_EQ(new_cpu_tensor.dim(), 2); \ - EXPECT_EQ(new_cpu_tensor.size(0), 2); \ - EXPECT_EQ(new_cpu_tensor.size(1), 3); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ( \ - cpu_tensor.data()[i], \ - new_cpu_tensor.data()[i]); \ - } \ - } - -TEST_SERIALIZATION_GPU_WITH_TYPE(bool, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(double, double_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(float, float_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int8_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int16_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(uint8_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(uint16_t, int32_data) -TEST_SERIALIZATION_GPU_WITH_TYPE(int64_t, int64_data) - -TEST(TensorConstruction, ReinitializeTensorTest) { - 
if (!HasCudaGPU()) return; - Tensor x = caffe2::empty({1}, at::dtype().device(CUDA, 0)); - auto* data_before = x.template mutable_data(); - // We'll only compare device_type in ReinitializeTensor, - // so no tensor reallocation will happen here - ReinitializeTensor(&x, {1}, at::dtype().device(CUDA)); - auto* data_after = x.template mutable_data(); - EXPECT_EQ(data_before, data_after); -} - -TEST(TensorTest, TensorSerializationMultiDevices) { - Blob blob; - Tensor tensor(CPU); - tensor.Resize(2, 3); - for (int i = 0; i < 6; ++i) { - tensor.mutable_data()[i] = i; - } - for (int gpu_id = 0; gpu_id < NumCudaDevices(); ++gpu_id) { - CUDAGuard guard(gpu_id); - CUDAContext context(gpu_id); // switch to the current gpu - blob.Reset(new Tensor(tensor, CUDA)); - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CAFFE_ENFORCE(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_TRUE(proto.has_tensor()); - const TensorProto& tensor_proto = proto.tensor(); - EXPECT_EQ(tensor_proto.data_type(), TensorProto::FLOAT); - EXPECT_EQ(tensor_proto.float_data_size(), 6); - for (int i = 0; i < 6; ++i) { - EXPECT_EQ(tensor_proto.float_data(i), i); - } - EXPECT_TRUE(tensor_proto.has_device_detail()); - EXPECT_EQ(tensor_proto.device_detail().device_type(), PROTO_CUDA); - EXPECT_EQ(tensor_proto.device_detail().device_id(), gpu_id); - // Test if the restored blob is still of the same device. - blob.Reset(); - EXPECT_NO_THROW(DeserializeBlob(serialized, &blob)); - EXPECT_TRUE(BlobIsTensorType(blob, CUDA)); - EXPECT_EQ(GetGPUIDForPointer(blob.Get().data()), - gpu_id); - // Test if we force the restored blob on a different device, we - // can still get so. - blob.Reset(); - proto.mutable_tensor()->mutable_device_detail()->set_device_id(0); - EXPECT_NO_THROW(DeserializeBlob(proto.SerializeAsString(), &blob)); - EXPECT_TRUE(BlobIsTensorType(blob, CUDA)); - EXPECT_EQ(GetGPUIDForPointer(blob.Get().data()), 0); - } -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/blob_serialization_gpu.cc b/caffe2/core/blob_serialization_gpu.cc deleted file mode 100644 index 4d675354531c..000000000000 --- a/caffe2/core/blob_serialization_gpu.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/context_gpu.h" - -namespace caffe2 { - -namespace { -REGISTER_BLOB_DESERIALIZER(TensorCUDA, TensorDeserializer); -} -} // namespace caffe2 diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc deleted file mode 100644 index a7e3a8d27e23..000000000000 --- a/caffe2/core/blob_test.cc +++ /dev/null @@ -1,1306 +0,0 @@ -#include -#include -#include - -#include -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/db.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/qtensor.h" -#include "caffe2/core/qtensor_serialization.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/test_utils.h" -#include "caffe2/core/types.h" -#include "caffe2/core/workspace.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -C10_DEFINE_int64(caffe2_test_big_tensor_size, 100000000, ""); -C10_DECLARE_int(caffe2_tensor_chunk_size); -C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes); -C10_DECLARE_bool(caffe2_serialize_using_bytes_as_holder); - -namespace caffe2 { -using namespace ::caffe2::db; -namespace { -class BlobTestFoo { - public: - int32_t 
val; -}; -class BlobTestBar {}; -class BlobTestNonDefaultConstructible { - public: - BlobTestNonDefaultConstructible() = delete; - BlobTestNonDefaultConstructible(int x) : val(x) {} - int32_t val; -}; -} // namespace - -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestFoo); -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestBar); -CAFFE_KNOWN_TYPE_NOEXPORT(BlobTestNonDefaultConstructible); - -class BlobTestFooSerializer : public BlobSerializerBase { - public: - // NOLINTNEXTLINE(modernize-use-equals-default) - BlobTestFooSerializer() {} - // NOLINTNEXTLINE(modernize-use-equals-default) - ~BlobTestFooSerializer() override {} - /** - * Serializes a Blob. Note that this blob has to contain Tensor, - * otherwise this function produces a fatal error. - */ - void Serialize( - const void* pointer, - TypeMeta typeMeta, - const string& name, - SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(typeMeta.Match()); - - BlobProto blob_proto; - blob_proto.set_name(name); - blob_proto.set_type("BlobTestFoo"); - // For simplicity we will just serialize the 4-byte content as a string. - blob_proto.set_content(std::string( - reinterpret_cast( - &static_cast(pointer)->val), - sizeof(int32_t))); - acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); - } -}; - -class BlobTestFooDeserializer : public BlobDeserializerBase { - public: - void Deserialize(const BlobProto& proto, Blob* blob) override { - blob->GetMutable()->val = - reinterpret_cast(proto.content().c_str())[0]; - } -}; - -REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), BlobTestFooSerializer); -REGISTER_BLOB_DESERIALIZER(BlobTestFoo, BlobTestFooDeserializer); - -namespace { - -TEST(BlobTest, Blob) { - Blob blob; - - int* int_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(BlobIsTensorType(blob, CPU)); - - BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(BlobIsTensorType(blob, CPU)); - - Tensor* tensor_unused CAFFE2_UNUSED = BlobGetMutableTensor(&blob, CPU); - EXPECT_TRUE(BlobIsTensorType(blob, CPU)); - EXPECT_FALSE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); -} - -TEST(BlobTest, BlobUninitialized) { - Blob blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); -} - -TEST(BlobTest, BlobWrongType) { - Blob blob; - BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable(); - EXPECT_TRUE(blob.IsType()); - EXPECT_FALSE(blob.IsType()); - // When not null, we should only call with the right type. - EXPECT_NE(&blob.Get(), nullptr); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); -} - -TEST(BlobTest, BlobReset) { - Blob blob; - std::unique_ptr foo(new BlobTestFoo()); - EXPECT_TRUE(blob.Reset(foo.release()) != nullptr); - // Also test that Reset works. 
- blob.Reset(); -} - -TEST(BlobTest, BlobMove) { - Blob blob1; - std::unique_ptr foo(new BlobTestFoo()); - auto* fooPtr = foo.get(); - EXPECT_TRUE(blob1.Reset(foo.release()) != nullptr); - Blob blob2; - blob2 = std::move(blob1); - // NOLINTNEXTLINE(bugprone-use-after-move,hicpp-avoid-goto,clang-analyzer-cplusplus.Move,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob1.Get(), EnforceNotMet); - EXPECT_EQ(&blob2.Get(), fooPtr); - Blob blob3{std::move(blob2)}; - EXPECT_EQ(&blob3.Get(), fooPtr); -} - -TEST(BlobTest, BlobNonConstructible) { - Blob blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(blob.Get(), EnforceNotMet); - // won't work because it's not default constructible - // blob.GetMutable(); - EXPECT_FALSE( - blob.GetMutableOrNull() != nullptr); - EXPECT_TRUE(blob.Reset(new BlobTestNonDefaultConstructible(42)) != nullptr); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_NO_THROW(blob.Get()); - ASSERT_TRUE( - blob.GetMutableOrNull() != nullptr); - EXPECT_EQ(blob.Get().val, 42); - blob.GetMutableOrNull()->val = 37; - EXPECT_EQ(blob.Get().val, 37); -} - -TEST(BlobTest, BlobShareExternalPointer) { - Blob blob; - std::unique_ptr foo(new BlobTestFoo()); - EXPECT_EQ(blob.ShareExternal(foo.get()), foo.get()); - EXPECT_TRUE(blob.IsType()); - // Also test that Reset works. - blob.Reset(); -} - -TEST(BlobTest, BlobShareExternalObject) { - Blob blob; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - BlobTestFoo foo; - EXPECT_EQ(blob.ShareExternal(&foo), &foo); - EXPECT_TRUE(blob.IsType()); - // Also test that Reset works. - blob.Reset(); -} - -TEST(BlobTest, StringSerialization) { - const std::string kTestString = "Hello world?"; - Blob blob; - *blob.GetMutable() = kTestString; - - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "std::string"); - EXPECT_FALSE(proto.has_tensor()); - EXPECT_EQ(proto.content(), kTestString); -} - -TEST(TensorNonTypedTest, TensorChangeType) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - - auto* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // int and float are same size, so should retain the pointer - // NB: this is only true when the use_count of the underlying Storage is 1, if - // the underlying Storage is shared between multiple Tensors We'll create a - // new Storage when the data type changes - EXPECT_TRUE(tensor.mutable_data() == (float*)ptr); - EXPECT_TRUE(tensor.data() == (const float*)ptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // at::Half is smaller, so still should share buffer - EXPECT_TRUE(tensor.mutable_data() == (at::Half*)ptr); - EXPECT_TRUE(tensor.data() == (const at::Half*)ptr); - EXPECT_TRUE(tensor.dtype().Match()); - - // share the data with other tensor so that the pointer won't be reused - // when we reallocate - Tensor other_tensor = tensor.Alias(); - // but double is bigger, so it should allocate a new one - auto* doubleptr = tensor.mutable_data(); - EXPECT_TRUE(doubleptr != (double*)ptr); - EXPECT_TRUE(doubleptr != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(tensor.dtype().Match()); -} - -TEST(TensorNonTypedTest, NonDefaultConstructible) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - - // this doesn't compile - good! 
- // auto* ptr = tensor.mutable_data(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW( - tensor.raw_mutable_data( - TypeMeta::Make()), - EnforceNotMet); -} - -template -class TensorCPUTest : public ::testing::Test {}; -template -class TensorCPUDeathTest : public ::testing::Test {}; -typedef ::testing::Types TensorTypes; -TYPED_TEST_CASE(TensorCPUTest, TensorTypes); -TYPED_TEST_CASE(TensorCPUDeathTest, TensorTypes); - -TYPED_TEST(TensorCPUTest, TensorInitializedEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_EQ(tensor.numel(), 2 * 3 * 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedNonEmpty) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - dims[0] = 7; - dims[1] = 11; - dims[2] = 13; - dims.push_back(17); - tensor.Resize(dims); - EXPECT_EQ(tensor.dim(), 4); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 11); - EXPECT_EQ(tensor.dim32(2), 13); - EXPECT_EQ(tensor.dim32(3), 17); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedZeroDim) { - vector dims(3); - dims[0] = 2; - dims[1] = 0; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 0); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() == nullptr); - EXPECT_TRUE(tensor.data() == nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorResizeZeroDim) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 2); - EXPECT_EQ(tensor.dim32(1), 3); - EXPECT_EQ(tensor.dim32(2), 5); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); - - dims[0] = 7; - dims[1] = 0; - dims[2] = 13; - tensor.Resize(dims); - EXPECT_EQ(tensor.numel(), 0); - EXPECT_EQ(tensor.dim(), 3); - EXPECT_EQ(tensor.dim32(0), 7); - EXPECT_EQ(tensor.dim32(1), 0); - EXPECT_EQ(tensor.dim32(2), 13); - // output value can be arbitrary, but the call to data() shouldn't crash - tensor.mutable_data(); - tensor.data(); -} - -TYPED_TEST(TensorCPUTest, TensorInitializedScalar) { - vector dims; - Tensor tensor(dims, CPU); - EXPECT_EQ(tensor.dim(), 0); - EXPECT_EQ(tensor.numel(), 1); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - EXPECT_TRUE(tensor.data() != nullptr); -} - -TYPED_TEST(TensorCPUTest, TensorAlias) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - tensor.mutable_data()[i] = i; - EXPECT_EQ(other_tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, 
TensorShareDataRawPointer) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) - std::unique_ptr raw_buffer(new TypeParam[2 * 3 * 5]); - Tensor tensor(dims, CPU); - tensor.ShareExternalPointer(raw_buffer.get()); - EXPECT_EQ(tensor.mutable_data(), raw_buffer.get()); - EXPECT_EQ(tensor.data(), raw_buffer.get()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - raw_buffer.get()[i] = i; - EXPECT_EQ(tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, TensorShareDataRawPointerWithMeta) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) - std::unique_ptr raw_buffer(new TypeParam[2 * 3 * 5]); - Tensor tensor(dims, CPU); - TypeMeta meta = TypeMeta::Make(); - tensor.ShareExternalPointer(raw_buffer.get(), meta); - EXPECT_EQ(tensor.mutable_data(), raw_buffer.get()); - EXPECT_EQ(tensor.data(), raw_buffer.get()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - raw_buffer.get()[i] = i; - EXPECT_EQ(tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - vector alternate_dims(1); - alternate_dims[0] = 2 * 3 * 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); - EXPECT_EQ(other_tensor.dim(), 1); - EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); - EXPECT_TRUE(tensor.data() != nullptr); - EXPECT_TRUE(other_tensor.data() != nullptr); - EXPECT_EQ(tensor.data(), other_tensor.data()); - // Set one value, check the other - for (int i = 0; i < tensor.numel(); ++i) { - tensor.mutable_data()[i] = i; - EXPECT_EQ(other_tensor.data()[i], i); - } -} - -TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - dims[0] = 7; - tensor.Resize(dims); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorCPUTest, NoLongerAliasAfterFreeMemory) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - Tensor tensor(dims, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - EXPECT_EQ(tensor.data(), other_tensor.data()); - auto* old_pointer = other_tensor.data(); - - tensor.FreeMemory(); - EXPECT_EQ(old_pointer, other_tensor.data()); - EXPECT_NE(old_pointer, tensor.mutable_data()); -} - -TYPED_TEST(TensorCPUTest, KeepOnShrink) { - // Set flags (defaults) - FLAGS_caffe2_keep_on_shrink = true; - FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX; - - vector dims{2, 3, 5}; - Tensor tensor(dims, CPU); - TypeParam* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - // Expanding - will reallocate - tensor.Resize(3, 4, 6); - TypeParam* larger_ptr = tensor.mutable_data(); - EXPECT_TRUE(larger_ptr != nullptr); - - // This check can fail when malloc() returns the same recently freed address - // EXPECT_NE(ptr, larger_ptr); - - // Shrinking - will not reallocate - tensor.Resize(1, 2, 4); - TypeParam* 
smaller_ptr = tensor.mutable_data(); - EXPECT_TRUE(smaller_ptr != nullptr); - EXPECT_EQ(larger_ptr, smaller_ptr); - // resize to 0 in the meantime; - tensor.Resize(3, 0, 6); - // Expanding but still under capacity - will not reallocate - tensor.Resize(2, 3, 5); - TypeParam* new_ptr = tensor.mutable_data(); - EXPECT_TRUE(new_ptr != nullptr); - EXPECT_EQ(larger_ptr, new_ptr); -} - -TYPED_TEST(TensorCPUTest, MaxKeepOnShrink) { - // Set flags - FLAGS_caffe2_keep_on_shrink = true; - FLAGS_caffe2_max_keep_on_shrink_memory = 8 * 4 * sizeof(TypeParam); - - vector dims{1, 8, 8}; - Tensor tensor(dims, CPU); - TypeParam* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - // Shrinking - will not reallocate - tensor.Resize(1, 7, 8); - TypeParam* smaller_ptr = tensor.mutable_data(); - EXPECT_TRUE(smaller_ptr != nullptr); - EXPECT_EQ(ptr, smaller_ptr); - // Resize to more than maximum shrink, should reallocate - tensor.Resize(1, 1, 8); - TypeParam* new_ptr = tensor.mutable_data(); - EXPECT_TRUE(new_ptr != nullptr); - - // This check can fail when malloc() returns the same recently freed address - // EXPECT_NE(ptr, new_ptr); - - // Restore default flags - FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX; -} - -TYPED_TEST(TensorCPUDeathTest, CannotAccessRawDataWhenEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(tensor.raw_data()); -} - -TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) { - Tensor tensor(CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.numel(), 0); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(tensor.data()); -} - -TEST(TensorTest, TensorNonFundamentalType) { - Tensor tensor(vector{2, 3, 4}, CPU); - EXPECT_TRUE(tensor.mutable_data() != nullptr); - const std::string* ptr = tensor.data(); - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == ""); - } -} - -TEST(TensorTest, TensorNonFundamentalTypeClone) { - Tensor tensor(vector{2, 3, 4}, CPU); - std::string* ptr = tensor.mutable_data(); - EXPECT_TRUE(ptr != nullptr); - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == ""); - ptr[i] = "filled"; - } - Tensor dst_tensor = tensor.Clone(); - const std::string* dst_ptr = dst_tensor.data(); - for (int i = 0; i < dst_tensor.numel(); ++i) { - EXPECT_TRUE(dst_ptr[i] == "filled"); - } - // Change the original tensor - for (int i = 0; i < tensor.numel(); ++i) { - EXPECT_TRUE(ptr[i] == "filled"); - ptr[i] = "changed"; - } - // Confirm that the cloned tensor is not affect - for (int i = 0; i < dst_tensor.numel(); ++i) { - EXPECT_TRUE(dst_ptr[i] == "filled"); - } -} - -TEST(TensorTest, Tensor64BitDimension) { - // Initialize a large tensor. 
- int64_t large_number = - static_cast(std::numeric_limits::max()) + 1; - Tensor tensor(vector{large_number}, CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.size(0), large_number); - EXPECT_EQ(tensor.numel(), large_number); - try { - EXPECT_TRUE(tensor.mutable_data() != nullptr); - } catch (const EnforceNotMet& e) { - string msg = e.what(); - size_t found = msg.find("posix_memalign"); - if (found != string::npos) { - msg = msg.substr(0, msg.find('\n')); - LOG(WARNING) << msg; - LOG(WARNING) << "Out of memory issue with posix_memalign;\n"; - return; - } else { - throw e; - } - } - EXPECT_EQ(tensor.nbytes(), large_number * sizeof(char)); - EXPECT_EQ(tensor.itemsize(), sizeof(char)); - // Try to go even larger, but this time we will not do mutable_data because we - // do not have a large enough memory. - tensor.Resize(large_number, 100); - EXPECT_EQ(tensor.dim(), 2); - EXPECT_EQ(tensor.size(0), large_number); - EXPECT_EQ(tensor.size(1), 100); - EXPECT_EQ(tensor.numel(), large_number * 100); -} - -TEST(TensorTest, UndefinedTensor) { - Tensor x; - EXPECT_FALSE(x.defined()); -} - -TEST(TensorTest, CopyAndAssignment) { - Tensor x(CPU); - x.Resize(16, 17); - testing::randomFill(x.template mutable_data(), 16 * 17); - EXPECT_TRUE(x.defined()); - - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - Tensor y(x); - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - Tensor z = x; - testing::assertTensorEquals(x, y, 0); - testing::assertTensorEquals(x, z, 0); -} - -TEST(TensorDeathTest, CannotCastDownLargeDims) { - int64_t large_number = - static_cast(std::numeric_limits::max()) + 1; - Tensor tensor(vector{large_number}, CPU); - EXPECT_EQ(tensor.dim(), 1); - EXPECT_EQ(tensor.size(0), large_number); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(tensor.dim32(0), EnforceNotMet); -} - -#define TEST_SERIALIZATION_WITH_TYPE(TypeParam, field_name) \ - TEST(TensorTest, TensorSerialization_##TypeParam) { \ - Blob blob; \ - Tensor* tensor = BlobGetMutableTensor(&blob, CPU); \ - tensor->Resize(2, 3); \ - for (int i = 0; i < 6; ++i) { \ - tensor->mutable_data()[i] = static_cast(i); \ - } \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CHECK(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 6); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ(tensor_proto.field_name(i), static_cast(i)); \ - } \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \ - const TensorCPU& new_tensor = blob.Get(); \ - EXPECT_EQ(new_tensor.dim(), 2); \ - EXPECT_EQ(new_tensor.size(0), 2); \ - EXPECT_EQ(new_tensor.size(1), 3); \ - for (int i = 0; i < 6; ++i) { \ - EXPECT_EQ( \ - tensor->data()[i], new_tensor.data()[i]); \ - } \ - } \ - \ - TEST(EmptyTensorTest, TensorSerialization_##TypeParam) { \ - Blob blob; \ - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); \ - tensor->Resize(0, 3); \ - tensor->mutable_data(); \ - string serialized = SerializeBlob(blob, "test"); \ - BlobProto proto; \ - CHECK(proto.ParseFromString(serialized)); \ - EXPECT_EQ(proto.name(), "test"); \ - EXPECT_EQ(proto.type(), "Tensor"); \ - EXPECT_TRUE(proto.has_tensor()); \ - const TensorProto& tensor_proto = 
proto.tensor(); \ - EXPECT_EQ( \ - tensor_proto.data_type(), \ - TypeMetaToDataType(TypeMeta::Make())); \ - EXPECT_EQ(tensor_proto.field_name##_size(), 0); \ - Blob new_blob; \ - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \ - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \ - const TensorCPU& new_tensor = blob.Get(); \ - EXPECT_EQ(new_tensor.dim(), 2); \ - EXPECT_EQ(new_tensor.size(0), 0); \ - EXPECT_EQ(new_tensor.size(1), 3); \ - } - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(bool, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(double, double_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(float, float_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int8_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int16_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(uint8_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(uint16_t, int32_data) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) -TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data) - -TEST(TensorTest, TensorSerialization_CustomType) { - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(2, 3); - for (int i = 0; i < 6; ++i) { - tensor->mutable_data()[i].val = i; - } - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "Tensor"); - Blob new_blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); - const TensorCPU& new_tensor = blob.Get(); - EXPECT_EQ(new_tensor.dim(), 2); - EXPECT_EQ(new_tensor.size(0), 2); - EXPECT_EQ(new_tensor.size(1), 3); - for (int i = 0; i < 6; ++i) { - EXPECT_EQ( - new_tensor.data()[i].val, - tensor->data()[i].val); - } -} - -TEST(TensorTest, Half) { - const int64_t kSize = 3000000; - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(kSize); - for (int i = 0; i < tensor->numel(); ++i) { - tensor->mutable_data()[i].x = i % 10000; - } - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "Tensor"); - EXPECT_TRUE(proto.has_tensor()); - const TensorProto& tensor_proto = proto.tensor(); - EXPECT_EQ( - tensor_proto.data_type(), TypeMetaToDataType(TypeMeta::Make())); - if (FLAGS_caffe2_serialize_fp16_as_bytes) { - EXPECT_EQ(tensor_proto.byte_data().size(), 2 * kSize); - for (int i = 0; i < kSize; ++i) { - auto 
value = tensor->mutable_data()[i].x; - auto low_bits = static_cast(value & 0xff); - auto high_bits = static_cast(value >> 8); - EXPECT_EQ(tensor_proto.byte_data()[2 * i], low_bits); - EXPECT_EQ(tensor_proto.byte_data()[2 * i + 1], high_bits); - } - } else { - EXPECT_EQ(tensor_proto.int32_data().size(), kSize); - } - Blob new_blob; - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); - EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); - const TensorCPU& new_tensor = blob.Get(); - EXPECT_EQ(new_tensor.dim(), 1); - EXPECT_EQ(new_tensor.size(0), kSize); - for (int i = 0; i < kSize; ++i) { - EXPECT_EQ(new_tensor.data()[i].x, i % 10000); - } -} - -TEST(TensorTest, TensorFactory) { - Tensor a = empty({1, 2, 3}, at::device(CPU).dtype()); - EXPECT_NE(a.data(), nullptr); - a.mutable_data()[0] = 3.0; - Tensor b = empty({1, 2, 3}, at::device(CPU).dtype()); - EXPECT_NE(b.data(), nullptr); - b.mutable_data()[0] = 3; -} - -TEST(QTensorTest, QTensorSerialization) { - Blob blob; - QTensor* qtensor = blob.GetMutable>(); - qtensor->SetPrecision(5); - qtensor->SetSigned(false); - qtensor->SetScale(1.337); - qtensor->SetBias(-1.337); - qtensor->Resize(std::vector{2, 3}); - // "Randomly" set bits. - srand(0); - for (int i = 0; i < 6; ++i) { - for (int j = 0; j < 5; ++j) { - // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) - qtensor->SetBitAtIndex(j, i, rand() % 2); - } - } - - string serialized = SerializeBlob(blob, "test"); - BlobProto proto; - CHECK(proto.ParseFromString(serialized)); - EXPECT_EQ(proto.name(), "test"); - EXPECT_EQ(proto.type(), "QTensor"); - EXPECT_TRUE(proto.has_qtensor()); - const QTensorProto& qtensor_proto = proto.qtensor(); - - EXPECT_EQ(qtensor_proto.precision(), qtensor->precision()); - EXPECT_EQ(qtensor_proto.scale(), qtensor->scale()); - EXPECT_EQ(qtensor_proto.bias(), qtensor->bias()); - EXPECT_EQ(qtensor_proto.is_signed(), qtensor->is_signed()); - - Blob new_blob; - DeserializeBlob(serialized, &new_blob); - EXPECT_TRUE(new_blob.IsType>()); - const QTensor& new_qtensor = blob.Get>(); - EXPECT_EQ(new_qtensor.ndim(), 2); - EXPECT_EQ(new_qtensor.dim32(0), 2); - EXPECT_EQ(new_qtensor.dim32(1), 3); - for (int i = 0; i < 6; ++i) { - for (int j = 0; j < 5; ++j) { - EXPECT_EQ(qtensor->GetBitAtIndex(j, i), new_qtensor.GetBitAtIndex(j, i)); - } - } -} - -using StringMap = std::vector>; - -class VectorCursor : public db::Cursor { - public: - explicit VectorCursor(StringMap* data) : data_(data) { - pos_ = 0; - } - // NOLINTNEXTLINE(modernize-use-equals-default) - ~VectorCursor() override {} - void Seek(const string& /* unused */) override {} - void SeekToFirst() override {} - void Next() override { - ++pos_; - } - string key() override { - return (*data_)[pos_].first; - } - string value() override { - return (*data_)[pos_].second; - } - bool Valid() override { - return pos_ < data_->size(); - } - - private: - StringMap* data_ = nullptr; - size_t pos_ = 0; -}; - -class VectorDB : public db::DB { - public: - VectorDB(const string& source, db::Mode mode) - : DB(source, mode), name_(source) {} - ~VectorDB() override { - data_.erase(name_); - } - void Close() override {} - std::unique_ptr NewCursor() override { - return make_unique(getData()); - } - std::unique_ptr NewTransaction() override { - CAFFE_THROW("Not implemented"); - } - static void registerData(const string& name, StringMap&& data) { - std::lock_guard guard(dataRegistryMutex_); - data_[name] = std::move(data); - } - - private: - StringMap* getData() { - auto 
it = data_.find(name_); - CAFFE_ENFORCE(it != data_.end(), "Can't find ", name_); - return &(it->second); - } - - private: - string name_; - static std::mutex dataRegistryMutex_; - static std::map data_; -}; - -std::mutex VectorDB::dataRegistryMutex_; -std::map VectorDB::data_; - -REGISTER_CAFFE2_DB(vector_db, VectorDB); - -template -class TypedTensorTest : public ::testing::Test {}; -typedef ::testing:: - Types - TensorDataTypes; -TYPED_TEST_CASE(TypedTensorTest, TensorDataTypes); - -TYPED_TEST(TypedTensorTest, BigTensorSerialization) { - int64_t d1 = 2; - int64_t d2 = FLAGS_caffe2_test_big_tensor_size - ? FLAGS_caffe2_test_big_tensor_size / d1 - : static_cast(std::numeric_limits::max()) + 1; - int64_t size = d1 * d2; - string db_source = (string)std::tmpnam(nullptr); - VLOG(1) << "db_source: " << db_source; - - { - VLOG(1) << "Test begin"; - Blob blob; - Tensor* tensor = BlobGetMutableTensor(&blob, CPU); - VLOG(1) << "Allocating blob"; - tensor->Resize(d1, d2); - auto mutableData = tensor->mutable_data(); - VLOG(1) << "Filling out the blob"; - for (int64_t i = 0; i < size; ++i) { - mutableData[i] = static_cast(i); - } - StringMap data; - std::mutex mutex; - auto acceptor = [&](const std::string& key, const std::string& value) { - std::lock_guard guard(mutex); - data.emplace_back(key, value); - }; - SerializeBlob(blob, "test", acceptor); - VectorDB::registerData(db_source, std::move(data)); - VLOG(1) << "finished writing to DB"; - } - - { - DeviceOption option; - option.set_device_type(PROTO_CPU); - Argument db_type_arg = MakeArgument("db_type", "vector_db"); - Argument absolute_path_arg = MakeArgument("absolute_path", true); - Argument db_source_arg = MakeArgument("db", db_source); - auto op_def = CreateOperatorDef( - "Load", - "", - std::vector{}, - std::vector({"test"}), - std::vector{db_type_arg, db_source_arg, absolute_path_arg}, - option, - "DUMMY_ENGINE"); - Workspace ws; - auto load_op = CreateOperator(op_def, &ws); - EXPECT_TRUE(load_op != nullptr); - VLOG(1) << "Running operator"; - - load_op->Run(); - VLOG(1) << "Reading blob from workspace"; - auto new_blob = ws.GetBlob("test"); - EXPECT_TRUE(BlobIsTensorType(*new_blob, CPU)); - const auto& new_tensor = new_blob->Get(); - - EXPECT_EQ(new_tensor.dim(), d1); - EXPECT_EQ(new_tensor.size(0), d1); - EXPECT_EQ(new_tensor.size(1), d2); - for (int64_t i = 0; i < size; ++i) { - EXPECT_EQ(static_cast(i), new_tensor.data()[i]); - } - } -} - -struct DummyType { - /* This struct is used to test serialization and deserialization of huge - * blobs, that are not tensors. 
- */ - - /* implicit */ DummyType(int n_chunks_init = 0) : n_chunks(n_chunks_init) {} - std::string serialize(const std::string& name, const int32_t chunk_id) const { - BlobProto blobProto; - blobProto.set_name(name); - blobProto.set_type("DummyType"); - std::string content(""); - blobProto.set_content(content); - blobProto.set_content_num_chunks(n_chunks); - blobProto.set_content_chunk_id(chunk_id); - return blobProto.SerializeAsString(); - } - void deserialize(const BlobProto& /* unused */) { - ++n_chunks; - } - int n_chunks; -}; - -class DummyTypeSerializer : public BlobSerializerBase { - public: - // NOLINTNEXTLINE(modernize-use-equals-default) - DummyTypeSerializer() {} - // NOLINTNEXTLINE(modernize-use-equals-default) - ~DummyTypeSerializer() override {} - void Serialize( - const void* pointer, - TypeMeta typeMeta, - const string& name, - SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(typeMeta.Match()); - const auto& container = *static_cast(pointer); - for (int k = 0; k < container.n_chunks; ++k) { - std::string serialized_chunk = container.serialize(name, k); - acceptor( - c10::str(name, kChunkIdSeparator, k), std::move(serialized_chunk)); - } - } -}; - -class DummyTypeDeserializer : public BlobDeserializerBase { - public: - void Deserialize(const BlobProto& proto, Blob* blob) override { - auto* container = blob->GetMutable(); - container->deserialize(proto); - } -}; -} // namespace - -CAFFE_KNOWN_TYPE_NOEXPORT(DummyType); - -namespace { -REGISTER_BLOB_SERIALIZER((TypeMeta::Id()), DummyTypeSerializer); -C10_REGISTER_TYPED_CLASS( - BlobDeserializerRegistry, - "DummyType", - DummyTypeDeserializer); - -TEST(ContentChunks, Serialization) { - string db_source = (string)std::tmpnam(nullptr); - VLOG(1) << "db_source: " << db_source; - - { - VLOG(1) << "Test begin"; - Blob blob; - DummyType* container = blob.GetMutable(); - VLOG(1) << "Allocating blob"; - container->n_chunks = 10; - VLOG(1) << "Filling out the blob"; - StringMap data; - std::mutex mutex; - auto acceptor = [&](const std::string& key, const std::string& value) { - std::lock_guard guard(mutex); - data.emplace_back(key, value); - }; - SerializeBlob(blob, "test", acceptor); - VectorDB::registerData(db_source, std::move(data)); - VLOG(1) << "finished writing to DB"; - } - - { - DeviceOption option; - option.set_device_type(PROTO_CPU); - Argument db_type_arg = MakeArgument("db_type", "vector_db"); - Argument absolute_path_arg = MakeArgument("absolute_path", true); - Argument db_source_arg = MakeArgument("db", db_source); - auto op_def = CreateOperatorDef( - "Load", - "", - std::vector{}, - std::vector({"test"}), - std::vector{db_type_arg, db_source_arg, absolute_path_arg}, - option, - "DUMMY_ENGINE"); - Workspace ws; - auto load_op = CreateOperator(op_def, &ws); - EXPECT_TRUE(load_op != nullptr); - VLOG(1) << "Running operator"; - - load_op->Run(); - VLOG(1) << "Reading blob from workspace"; - auto new_blob = ws.GetBlob("test"); - EXPECT_TRUE(new_blob->IsType()); - const auto& container = new_blob->Get(); - EXPECT_EQ(container.n_chunks, 10); - } -} - -TEST(CustomChunkSize, BigTensorSerialization) { - int64_t d1 = 2; - int64_t d2 = FLAGS_caffe2_test_big_tensor_size - ? 
FLAGS_caffe2_test_big_tensor_size / d1 - : static_cast(std::numeric_limits::max()) + 1; - BlobSerializationOptions options; - - Blob blob; - TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); - tensor->Resize(d1, d2); - tensor->mutable_data(); - std::mutex mutex; - int counter = 0; - auto acceptor = [&](const std::string& /*key*/, - const std::string& /*value*/) { - std::lock_guard guard(mutex); - counter++; - }; - options.set_chunk_size(d1 * d2); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 1); - - counter = 0; - options.set_chunk_size((d1 * d2) / 2 + 1); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 2); - - counter = 0; - options.set_chunk_size(-1); - SerializeBlob(blob, "test", acceptor, options); - EXPECT_EQ(counter, 1); -} - -TEST(QTensor, QTensorSizingTest) { - vector dims(3); - dims[0] = 2; - dims[1] = 3; - dims[2] = 5; - QTensor qtensor(dims, 3); - EXPECT_TRUE(qtensor.mutable_data() != nullptr); - EXPECT_EQ(qtensor.nbytes(), 12); - EXPECT_EQ(qtensor.size(), 30); -} - -TEST(BlobTest, CastingMessage) { - Blob b; - b.GetMutable(); - b.Get(); - try { - b.Get(); - FAIL() << "Should have thrown"; - } catch (const EnforceNotMet& e) { - string msg = e.what_without_backtrace(); - LOG(INFO) << msg; - EXPECT_NE(msg.find("BlobTestFoo"), std::string::npos) << msg; - EXPECT_NE(msg.find("BlobTestBar"), std::string::npos) << msg; - } -} - -TEST(TensorConstruction, UninitializedCopyTest) { - Tensor x(CPU); - Tensor y(x, CPU); - Tensor z = x.Clone(); - EXPECT_FALSE(x.dtype_initialized()); - EXPECT_FALSE(y.dtype_initialized()); - LOG(INFO) << "z.size()" << z.numel(); - EXPECT_FALSE(z.dtype_initialized()); -} - -TEST(TensorConstruction, CopyConstructorTest) { - Tensor x(CPU); - x.Resize(5); - x.mutable_data()[0] = 1; - Tensor y = x.Clone(); - Tensor z(x, CPU); - - EXPECT_EQ(*x.data(), 1); - EXPECT_EQ(*y.data(), 1); - EXPECT_EQ(*z.data(), 1); - x.mutable_data()[0] = 5; - EXPECT_EQ(*x.data(), 5); - EXPECT_EQ(*y.data(), 1); - EXPECT_EQ(*z.data(), 1); -} - -TEST(TensorConstruction, MoveAssignmentOpTest) { - Tensor x(CPU); - x.Resize(5); - x.mutable_data()[0] = 1; - Tensor y(CPU); - y = std::move(x); - - EXPECT_EQ(*y.data(), 1); -} - -TEST(TensorSerialization, MistakenlySerializingDtypeUninitializedTensor) { - // This test preserves a legacy behavior that dtype-unitialized tensors can - // go through serialization. 
We want to kill this behavior - when it's done, - // remove this test - Blob blob; - Tensor* x = BlobGetMutableTensor(&blob, CPU); - x->Resize(0); - string output; - SerializeBlob( - blob, - "foo", - [&output](const string& /*blobName*/, const std::string& data) { - output = data; - }); - BlobProto b; - CHECK(b.ParseFromString(output)); - LOG(INFO) << "serialized proto: " << b.DebugString(); - - Blob new_blob; - // Deserializing an empty Tensor gives a {0}-dim, float CPU Tensor - DeserializeBlob(output, &new_blob); - const Tensor& new_tensor = new_blob.Get(); - LOG(INFO) << "tensor " << new_tensor.DebugString(); - EXPECT_TRUE(new_tensor.dtype_initialized()); - LOG(INFO) << "dtype:" << new_tensor.dtype(); - EXPECT_EQ(0, new_tensor.numel()); - EXPECT_EQ(1, new_tensor.dim()); -} - -static caffe2::BlobProto CreateProtoWithInt32Data( - const caffe2::TensorProto::DataType& dataType, - size_t numEl, - bool useCached = true) { - static std::map protos; - if (useCached && protos.count(dataType)) { - return protos[dataType]; - } - caffe2::BlobProto proto; - proto.set_type("Tensor"); - auto tensor = proto.mutable_tensor(); - tensor->add_dims(numEl); - tensor->add_dims(1); - tensor->set_data_type(dataType); - tensor->set_name("test_feature"); - tensor->mutable_device_detail()->set_device_type(0); - tensor->mutable_segment()->set_begin(0); - tensor->mutable_segment()->set_end(numEl); - for (size_t i = 0; i < numEl; ++i) { - int32_t data = 0; - switch (dataType) { - case caffe2::TensorProto_DataType_INT32: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0xffffffff); - break; - case caffe2::TensorProto_DataType_BOOL: - // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x00000001); - break; - case caffe2::TensorProto_DataType_UINT8: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x000000ff); - break; - case caffe2::TensorProto_DataType_INT8: - // NOLINTNEXTLINE(bugprone-signed-char-misuse,cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x000000ff); - break; - case caffe2::TensorProto_DataType_UINT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - case caffe2::TensorProto_DataType_INT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - case caffe2::TensorProto_DataType_FLOAT16: - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,clang-analyzer-security.insecureAPI.rand) - data = static_cast(rand() % 0x0000ffff); - break; - default: - continue; - } - tensor->add_int32_data(data); - } - protos[dataType] = proto; - return proto; -} - -void TestDataType( - const caffe2::TensorProto::DataType& dataType, - std::string dataTypeName) { - LOG(INFO) << dataTypeName; - FLAGS_caffe2_serialize_using_bytes_as_holder = true; - int numEl = 1000; - // Proto with int32 - auto protoInt32 = CreateProtoWithInt32Data(dataType, numEl, false); - caffe2::Blob blobInt32; - DeserializeBlob(protoInt32, &blobInt32); - auto serializedStr = SerializeBlob(blobInt32, protoInt32.name()); - caffe2::BlobProto protoBytes; - // Proto with bytes - protoBytes.ParseFromString(serializedStr); - caffe2::Blob blobBytes; - DeserializeBlob(protoBytes, &blobBytes); - 
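// blobBytes now holds the tensor that was deserialized from the byte-holder
// form of the proto. The test flips FLAGS_caffe2_serialize_using_bytes_as_holder
// back off and re-serializes below, so the same values come back as int32_data
// and can be compared element-by-element against the original int32 proto.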
FLAGS_caffe2_serialize_using_bytes_as_holder = false; - // Proto with int32 from proto with bytes - protoBytes.ParseFromString(SerializeBlob(blobBytes, protoBytes.name())); - EXPECT_EQ(numEl, protoInt32.tensor().int32_data_size()); - EXPECT_EQ(numEl, protoBytes.tensor().int32_data_size()); - for (int i = 0; i < numEl; ++i) { - EXPECT_EQ( - protoInt32.tensor().int32_data(i), protoBytes.tensor().int32_data(i)); - } -} - -TEST(TensorSerialization, TestCorrectness) { - FLAGS_caffe2_serialize_using_bytes_as_holder = true; - TestDataType( - caffe2::TensorProto_DataType_INT32, "TensorProto_DataType_INT32"); - TestDataType(caffe2::TensorProto_DataType_BOOL, "TensorProto_DataType_BOOL"); - TestDataType( - caffe2::TensorProto_DataType_UINT8, "TensorProto_DataType_UINT8"); - TestDataType(caffe2::TensorProto_DataType_INT8, "TensorProto_DataType_INT8"); - TestDataType( - caffe2::TensorProto_DataType_UINT16, "TensorProto_DataType_UINT16"); - TestDataType( - caffe2::TensorProto_DataType_INT16, "TensorProto_DataType_INT16"); - TestDataType( - caffe2::TensorProto_DataType_FLOAT16, "TensorProto_DataType_FLOAT16"); -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/common_cudnn.cc b/caffe2/core/common_cudnn.cc deleted file mode 100644 index f8186544054a..000000000000 --- a/caffe2/core/common_cudnn.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "caffe2/core/common_cudnn.h" -#include "caffe2/core/cudnn_wrappers.h" - -#include "caffe2/core/init.h" - -namespace caffe2 { - -CuDNNWrapper::PerGPUCuDNNStates& CuDNNWrapper::cudnn_states() { - // New it (never delete) to avoid calling the destructors on process - // exit and racing against the CUDA shutdown sequence. - static auto* p = new CuDNNWrapper::PerGPUCuDNNStates(); - TORCH_CHECK_NOTNULL(p); - return *p; -} - -namespace { -bool PrintCuDNNInfo(int*, char***) { - VLOG(1) << "Caffe2 is built with CuDNN version " << CUDNN_VERSION; - return true; -} - -REGISTER_CAFFE2_INIT_FUNCTION(PrintCuDNNInfo, &PrintCuDNNInfo, - "Print CuDNN Info."); - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h deleted file mode 100644 index b130103fb5cb..000000000000 --- a/caffe2/core/common_cudnn.h +++ /dev/null @@ -1,314 +0,0 @@ -#ifndef CAFFE2_CORE_COMMON_CUDNN_H_ -#define CAFFE2_CORE_COMMON_CUDNN_H_ - -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/types.h" - -#ifndef CAFFE2_USE_CUDNN -#error("This Caffe2 install is not built with cudnn, so you should not include this file."); -#endif - -#include - -static_assert( - CUDNN_VERSION >= 8200, - "Caffe2 requires cudnn version 8.2 or above."); - -#define CUDNN_VERSION_MIN(major, minor, patch) \ - (major >= 9 ? CUDNN_VERSION >= ((major) * 10000 + (minor) * 100 + (patch)) : \ - CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) - -namespace caffe2 { - -namespace internal { -/** - * A helper function to obtain cudnn error strings. 
- */ -inline const char* cudnnGetErrorString(cudnnStatus_t status) { - switch (status) { - case CUDNN_STATUS_SUCCESS: - return "CUDNN_STATUS_SUCCESS"; - case CUDNN_STATUS_NOT_INITIALIZED: - return "CUDNN_STATUS_NOT_INITIALIZED"; - case CUDNN_STATUS_ALLOC_FAILED: - return "CUDNN_STATUS_ALLOC_FAILED"; - case CUDNN_STATUS_BAD_PARAM: - return "CUDNN_STATUS_BAD_PARAM"; - case CUDNN_STATUS_INTERNAL_ERROR: - return "CUDNN_STATUS_INTERNAL_ERROR"; - case CUDNN_STATUS_INVALID_VALUE: - return "CUDNN_STATUS_INVALID_VALUE"; - case CUDNN_STATUS_ARCH_MISMATCH: - return "CUDNN_STATUS_ARCH_MISMATCH"; - case CUDNN_STATUS_MAPPING_ERROR: - return "CUDNN_STATUS_MAPPING_ERROR"; - case CUDNN_STATUS_EXECUTION_FAILED: - return "CUDNN_STATUS_EXECUTION_FAILED"; - case CUDNN_STATUS_NOT_SUPPORTED: - return "CUDNN_STATUS_NOT_SUPPORTED"; - case CUDNN_STATUS_LICENSE_ERROR: - return "CUDNN_STATUS_LICENSE_ERROR"; - default: - return "Unknown cudnn error number"; - } -} -} // namespace internal - -// A macro that wraps around a cudnn statement so we can check if the cudnn -// execution finishes or not. -#define CUDNN_ENFORCE(condition) \ - do { \ - cudnnStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CUDNN_STATUS_SUCCESS, \ - ", Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::internal::cudnnGetErrorString(status)); \ - } while (0) -#define CUDNN_CHECK(condition) \ - do { \ - cudnnStatus_t status = condition; \ - CHECK(status == CUDNN_STATUS_SUCCESS) \ - << ::caffe2::internal::cudnnGetErrorString(status); \ - } while (0) - -// report the version of cuDNN Caffe2 was compiled with -inline size_t cudnnCompiledVersion() { - return CUDNN_VERSION; -} -// report the runtime version of cuDNN -inline size_t cudnnRuntimeVersion() { - return cudnnGetVersion(); -} - -// Check compatibility of compiled and runtime cuDNN versions -inline void CheckCuDNNVersions() { - // Version format is major*1000 + minor*100 + patch - // If compiled with version < 7, major, minor and patch must all match - // If compiled with version >= 7, then either - // runtime_version > compiled_version - // major and minor match - bool version_match = cudnnCompiledVersion() == cudnnRuntimeVersion(); - bool compiled_with_7 = cudnnCompiledVersion() >= 7000; - bool backwards_compatible_7 = compiled_with_7 && cudnnRuntimeVersion() >= cudnnCompiledVersion(); - bool patch_compatible = compiled_with_7 && (cudnnRuntimeVersion() / 100) == (cudnnCompiledVersion() / 100); - CAFFE_ENFORCE(version_match || backwards_compatible_7 || patch_compatible, - "cuDNN compiled (", cudnnCompiledVersion(), ") and " - "runtime (", cudnnRuntimeVersion(), ") versions mismatch"); -} - -/** - * cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type - * in a template function. The class is specialized explicitly for different - * data types below. 
- */ -template -class cudnnTypeWrapper; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static const ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_INT32; - typedef const int ScalingParamType; - typedef int BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1; - return &v; - } - static const ScalingParamType* kZero() { - static ScalingParamType v = 0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - typedef const double ScalingParamType; - typedef double BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class cudnnTypeWrapper { - public: - static const cudnnDataType_t type = CUDNN_DATA_HALF; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() { - static ScalingParamType v = 0.0; - return &v; - } -}; - -/** - * A wrapper function to convert the Caffe storage order to cudnn storage order - * enum values. - */ -inline cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder& order) { - switch (order) { - case StorageOrder::NHWC: - return CUDNN_TENSOR_NHWC; - case StorageOrder::NCHW: - return CUDNN_TENSOR_NCHW; - default: - LOG(FATAL) << "Unknown cudnn equivalent for order: " << order; - } - // Just to suppress compiler warnings - return CUDNN_TENSOR_NCHW; -} - -/** - * cudnnTensorDescWrapper is the placeholder that wraps around a - * cudnnTensorDescriptor_t, allowing us to do descriptor change as-needed during - * runtime. - */ -class cudnnTensorDescWrapper { - public: - cudnnTensorDescWrapper() { - CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_)); - } - ~cudnnTensorDescWrapper() noexcept { - CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_)); - } - - inline cudnnTensorDescriptor_t Descriptor( - const cudnnTensorFormat_t format, - const cudnnDataType_t type, - const vector& dims, - bool* changed) { - if (type_ == type && format_ == format && dims_ == dims) { - // if not changed, simply return the current descriptor. - if (changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4U, "Currently only 4-dimensional descriptor supported."); - format_ = format; - type_ = type; - dims_ = dims; - CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( - desc_, - format, - type, - dims_[0], - (format == CUDNN_TENSOR_NCHW ? dims_[1] : dims_[3]), - (format == CUDNN_TENSOR_NCHW ? dims_[2] : dims_[1]), - (format == CUDNN_TENSOR_NCHW ? 
dims_[3] : dims_[2]))); - if (changed) - *changed = true; - return desc_; - } - - template - inline cudnnTensorDescriptor_t Descriptor( - const StorageOrder& order, - const vector& dims) { - return Descriptor( - GetCudnnTensorFormat(order), cudnnTypeWrapper::type, dims, nullptr); - } - - private: - cudnnTensorDescriptor_t desc_; - cudnnTensorFormat_t format_; - cudnnDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(cudnnTensorDescWrapper); -}; - -class cudnnFilterDescWrapper { - public: - cudnnFilterDescWrapper() { - CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&desc_)); - } - ~cudnnFilterDescWrapper() noexcept { - CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc_)); - } - - inline cudnnFilterDescriptor_t Descriptor( - const StorageOrder& order, - const cudnnDataType_t type, - const vector& dims, - bool* changed) { - if (type_ == type && order_ == order && dims_ == dims) { - // if not changed, simply return the current descriptor. - if (changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4U, "Currently only 4-dimensional descriptor supported."); - order_ = order; - type_ = type; - dims_ = dims; - CUDNN_ENFORCE(cudnnSetFilter4dDescriptor( - desc_, - type, - GetCudnnTensorFormat(order), - dims_[0], - // TODO - confirm that this is correct for NHWC - (order == StorageOrder::NCHW ? dims_[1] : dims_[3]), - (order == StorageOrder::NCHW ? dims_[2] : dims_[1]), - (order == StorageOrder::NCHW ? dims_[3] : dims_[2]))); - if (changed) - *changed = true; - return desc_; - } - - template - inline cudnnFilterDescriptor_t Descriptor( - const StorageOrder& order, - const vector& dims) { - return Descriptor(order, cudnnTypeWrapper::type, dims, nullptr); - } - - private: - cudnnFilterDescriptor_t desc_; - StorageOrder order_; - cudnnDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper); -}; - - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_CUDNN_H_ diff --git a/caffe2/core/common_gpu.cc b/caffe2/core/common_gpu.cc deleted file mode 100644 index e5a26359d3f2..000000000000 --- a/caffe2/core/common_gpu.cc +++ /dev/null @@ -1,253 +0,0 @@ -#include "caffe2/core/common_gpu.h" - -#include -#include -#include -#include - -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { - -int NumCudaDevices() { - if (getenv("CAFFE2_DEBUG_CUDA_INIT_ORDER")) { - static bool first = true; - if (first) { - first = false; - std::cerr << "DEBUG: caffe2::NumCudaDevices() invoked for the first time" - << std::endl; - } - } - // It logs warnings on first run - return c10::cuda::device_count(); -} - -namespace { -int gDefaultGPUID = 0; -} // namespace - -void SetDefaultGPUID(const int deviceid) { - CAFFE_ENFORCE_LT( - deviceid, - NumCudaDevices(), - "The default gpu id should be smaller than the number of gpus " - "on this machine: ", - deviceid, - " vs ", - NumCudaDevices()); - gDefaultGPUID = deviceid; -} - -int GetDefaultGPUID() { return gDefaultGPUID; } - -int CaffeCudaGetDevice() { - int gpu_id = 0; - CUDA_ENFORCE(cudaGetDevice(&gpu_id)); - return gpu_id; -} - -void CaffeCudaSetDevice(const int id) { - CUDA_ENFORCE(cudaSetDevice(id)); -} - -int GetGPUIDForPointer(const void* ptr) { - cudaPointerAttributes attr; - cudaError_t err = cudaPointerGetAttributes(&attr, ptr); - - if (err == cudaErrorInvalidValue) { - // Occurs when the pointer is in the CPU address space that is - // unmanaged by CUDA; make sure the last error state is cleared, - // since it is persistent - err = 
cudaGetLastError(); - CHECK(err == cudaErrorInvalidValue); - return -1; - } - - // Otherwise, there must be no error - CUDA_ENFORCE(err); - - if (attr.type == cudaMemoryTypeHost) { - return -1; - } - - return attr.device; -} - -struct CudaDevicePropWrapper { - CudaDevicePropWrapper() : props(NumCudaDevices()) { - for (int i = 0; i < NumCudaDevices(); ++i) { - CUDA_ENFORCE(cudaGetDeviceProperties(&props[i], i)); - } - } - - vector props; -}; - -const cudaDeviceProp& GetDeviceProperty(const int deviceid) { - // According to C++11 standard section 6.7, static local variable init is - // thread safe. See - // https://stackoverflow.com/questions/8102125/is-local-static-variable-initialization-thread-safe-in-c11 - // for details. - static CudaDevicePropWrapper props; - CAFFE_ENFORCE_LT( - deviceid, - NumCudaDevices(), - "The gpu id should be smaller than the number of gpus ", - "on this machine: ", - deviceid, - " vs ", - NumCudaDevices()); - return props.props[deviceid]; -} - -void DeviceQuery(const int device) { - const cudaDeviceProp& prop = GetDeviceProperty(device); - std::stringstream ss; - ss << std::endl; - ss << "Device id: " << device << std::endl; - ss << "Major revision number: " << prop.major << std::endl; - ss << "Minor revision number: " << prop.minor << std::endl; - ss << "Name: " << prop.name << std::endl; - ss << "Total global memory: " << prop.totalGlobalMem << std::endl; - ss << "Total shared memory per block: " << prop.sharedMemPerBlock - << std::endl; - ss << "Total registers per block: " << prop.regsPerBlock << std::endl; - ss << "Warp size: " << prop.warpSize << std::endl; -#if !defined(USE_ROCM) - ss << "Maximum memory pitch: " << prop.memPitch << std::endl; -#endif - ss << "Maximum threads per block: " << prop.maxThreadsPerBlock - << std::endl; - ss << "Maximum dimension of block: " - << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", " - << prop.maxThreadsDim[2] << std::endl; - ss << "Maximum dimension of grid: " - << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", " - << prop.maxGridSize[2] << std::endl; - ss << "Clock rate: " << prop.clockRate << std::endl; - ss << "Total constant memory: " << prop.totalConstMem << std::endl; -#if !defined(USE_ROCM) - ss << "Texture alignment: " << prop.textureAlignment << std::endl; - ss << "Concurrent copy and execution: " - << (prop.deviceOverlap ? "Yes" : "No") << std::endl; -#endif - ss << "Number of multiprocessors: " << prop.multiProcessorCount - << std::endl; -#if !defined(USE_ROCM) - ss << "Kernel execution timeout: " - << (prop.kernelExecTimeoutEnabled ? 
"Yes" : "No") << std::endl; -#endif - LOG(INFO) << ss.str(); - return; -} - -bool GetCudaPeerAccessPattern(vector >* pattern) { - int gpu_count; - if (cudaGetDeviceCount(&gpu_count) != cudaSuccess) return false; - pattern->clear(); - pattern->resize(gpu_count, vector(gpu_count, false)); - for (int i = 0; i < gpu_count; ++i) { - for (int j = 0; j < gpu_count; ++j) { - int can_access = true; - if (i != j) { - if (cudaDeviceCanAccessPeer(&can_access, i, j) - != cudaSuccess) { - return false; - } - } - (*pattern)[i][j] = static_cast(can_access); - } - } - return true; -} - -bool TensorCoreAvailable() { - int device = CaffeCudaGetDevice(); - auto& prop = GetDeviceProperty(device); - - return prop.major >= 7; -} - -const char* cublasGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; -#if !defined(USE_ROCM) - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; -#endif - } - // To suppress compiler warning. - return "Unrecognized cublas error string"; -} - -const char* curandGetErrorString(curandStatus_t error) { - switch (error) { - case CURAND_STATUS_SUCCESS: - return "CURAND_STATUS_SUCCESS"; - case CURAND_STATUS_VERSION_MISMATCH: - return "CURAND_STATUS_VERSION_MISMATCH"; - case CURAND_STATUS_NOT_INITIALIZED: - return "CURAND_STATUS_NOT_INITIALIZED"; - case CURAND_STATUS_ALLOCATION_FAILED: - return "CURAND_STATUS_ALLOCATION_FAILED"; - case CURAND_STATUS_TYPE_ERROR: - return "CURAND_STATUS_TYPE_ERROR"; - case CURAND_STATUS_OUT_OF_RANGE: - return "CURAND_STATUS_OUT_OF_RANGE"; - case CURAND_STATUS_LENGTH_NOT_MULTIPLE: - return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; - case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: - return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; - case CURAND_STATUS_LAUNCH_FAILURE: - return "CURAND_STATUS_LAUNCH_FAILURE"; - case CURAND_STATUS_PREEXISTING_FAILURE: - return "CURAND_STATUS_PREEXISTING_FAILURE"; - case CURAND_STATUS_INITIALIZATION_FAILED: - return "CURAND_STATUS_INITIALIZATION_FAILED"; - case CURAND_STATUS_ARCH_MISMATCH: - return "CURAND_STATUS_ARCH_MISMATCH"; - case CURAND_STATUS_INTERNAL_ERROR: - return "CURAND_STATUS_INTERNAL_ERROR"; -#if defined(USE_ROCM) - case HIPRAND_STATUS_NOT_IMPLEMENTED: - return "HIPRAND_STATUS_NOT_IMPLEMENTED"; -#endif - } - // To suppress compiler warning. - return "Unrecognized curand error string"; -} - -// Turn on the flag g_caffe2_has_cuda_linked to true for HasCudaRuntime() -// function. 
-namespace { -class CudaRuntimeFlagFlipper { - public: - CudaRuntimeFlagFlipper() { - internal::SetCudaRuntimeFlag(); - } -}; -static CudaRuntimeFlagFlipper g_flipper; -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h deleted file mode 100644 index 011f46264b19..000000000000 --- a/caffe2/core/common_gpu.h +++ /dev/null @@ -1,475 +0,0 @@ -#ifndef CAFFE2_CORE_COMMON_GPU_H_ -#define CAFFE2_CORE_COMMON_GPU_H_ - -#include -#include -#include - -#if !defined(USE_ROCM) -#ifdef __GNUC__ -#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) -#pragma GCC diagnostic push -#endif -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ -#endif // USE_ROCM - -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" - -#include "c10/cuda/CUDAMacros.h" -#include "c10/cuda/CUDAMathCompat.h" -#include - -#define CAFFE2_CUDA_EXPORT C10_EXPORT - -// CAFFE2_CUDA_API gets translated to CAFFE2_HIP_API in hipify script, which -// causes a marco redefinition issue with the later definition of -// CAFFE2_HIP_API, so we exclude this definition when HIP is specified -#if !defined(USE_ROCM) -#define CAFFE2_CUDA_API TORCH_CUDA_CPP_API -#endif // USE_ROCM - -//TODO: [ROCm] Need to remove this after CUDA->HIP mapping is updated. -#define CAFFE2_HIP_EXPORT C10_EXPORT -#define CAFFE2_HIP_API TORCH_HIP_API - -// This is a macro defined for cuda fp16 support. In default, cuda fp16 is -// supported by NVCC 7.5, but it is also included in the Tegra X1 platform with -// a (custom?) NVCC 7.0. As a result, we would normally just check the cuda -// version here, but would also allow a use to pass in the flag -// CAFFE_HAS_CUDA_FP16 manually. - -#ifndef CAFFE_HAS_CUDA_FP16 -#define CAFFE_HAS_CUDA_FP16 -#endif // CAFFE_HAS_CUDA_FP16 - -#ifdef CAFFE_HAS_CUDA_FP16 -#include -#endif - -// cuda major revision number below which fp16 compute is not supoorted -#if !defined(USE_ROCM) -constexpr int kFp16CUDADevicePropMajor = 6; -#else -constexpr int kFp16CUDADevicePropMajor = 3; -#endif - -// Re-enable strict aliasing diagnostic if it was disabled. -#if !defined(USE_ROCM) -#ifdef __GNUC__ -#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6) -#pragma GCC diagnostic pop -#endif -#endif // __GNUC__ -#endif // USE_ROCM - -/** - * The maximum number of peers that each gpu can have when doing p2p setup. - * Currently, according to NVidia documentation, each device can support a - * system-wide maximum of eight peer connections. - * When Caffe2 sets up peer access resources, if we have more than 8 gpus, - * we will enable peer access in groups of 8. - */ -#define CAFFE2_CUDA_MAX_PEER_SIZE 8 - -namespace caffe2 { - -#if !defined(USE_ROCM) -/** - * Empty class to identify TensorCore-based math - */ -class TensorCoreEngine {}; -#endif // USE_ROCM - -/** - * A runtime function to report the cuda version that Caffe2 is built with. - */ -inline int CudaVersion() { -#if defined(USE_ROCM) - return ROCM_VERSION; -#else - return CUDA_VERSION; -#endif -} - -/** - * Returns the number of devices. - */ -CAFFE2_CUDA_API int NumCudaDevices(); - -/** - * Check if the current running session has a cuda gpu present. - * - * Note that this is different from having caffe2 built with cuda. Building - * Caffe2 with cuda only guarantees that this function exists. 
If there are no - * cuda gpus present in the machine, or there are hardware configuration - * problems like an insufficient driver, this function will still return false, - * meaning that there is no usable GPU present. - * - * In the open source build, it is possible that Caffe2's GPU code is - * dynamically loaded, and as a result a library could be only linked to the - * CPU code, but want to test if cuda is later available or not. In this case, - * one should use HasCudaRuntime() from common.h. - */ -inline bool HasCudaGPU() { - return NumCudaDevices() > 0; -} - -/** - * Gets the current GPU id. This is a simple wrapper around cudaGetDevice(). - */ -CAFFE2_CUDA_API int CaffeCudaGetDevice(); - -/** - * Gets the current GPU id. This is a simple wrapper around cudaGetDevice(). - */ -CAFFE2_CUDA_API void CaffeCudaSetDevice(const int id); - -/** - * Gets the GPU id that the current pointer is located at. - */ -CAFFE2_CUDA_API int GetGPUIDForPointer(const void* ptr); - -/** - * Gets the device property for the given device. This function is thread safe. - * The initial run on this function is ~1ms/device; however, the results are - * cached so subsequent runs should be much faster. - */ -CAFFE2_CUDA_API const cudaDeviceProp& GetDeviceProperty(const int device); - -/** - * Runs a device query function and prints out the results to LOG(INFO). - */ -CAFFE2_CUDA_API void DeviceQuery(const int deviceid); - -/** - * Return a peer access pattern by returning a matrix (in the format of a - * nested vector) of boolean values specifying whether peer access is possible. - * - * This function returns false if anything wrong happens during the query of - * the GPU access pattern. - */ -CAFFE2_CUDA_API bool GetCudaPeerAccessPattern(vector>* pattern); - -/** - * Return the availability of TensorCores for math - */ -CAFFE2_CUDA_API bool TensorCoreAvailable(); - -/** - * Return a human readable cublas error string. - */ -CAFFE2_CUDA_API const char* cublasGetErrorString(cublasStatus_t error); - -/** - * Return a human readable curand error string. - */ -CAFFE2_CUDA_API const char* curandGetErrorString(curandStatus_t error); - -// CUDA: various checks for different function calls. -#define CUDA_ENFORCE(condition, ...) 
\ - do { \ - cudaError_t error = condition; \ - CAFFE_ENFORCE_EQ( \ - error, \ - cudaSuccess, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - cudaGetErrorString(error), \ - ##__VA_ARGS__); \ - } while (0) -#define CUDA_CHECK(condition) \ - do { \ - cudaError_t error = condition; \ - CHECK(error == cudaSuccess) << cudaGetErrorString(error); \ - } while (0) - -#define CUDA_DRIVERAPI_ENFORCE(condition) \ - do { \ - CUresult result = condition; \ - if (result != CUDA_SUCCESS) { \ - const char* msg; \ - cuGetErrorName(result, &msg); \ - CAFFE_THROW("Error at: ", __FILE__, ":", __LINE__, ": ", msg); \ - } \ - } while (0) -#define CUDA_DRIVERAPI_CHECK(condition) \ - do { \ - CUresult result = condition; \ - if (result != CUDA_SUCCESS) { \ - const char* msg; \ - cuGetErrorName(result, &msg); \ - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ - << msg; \ - } \ - } while (0) - -#define CUBLAS_ENFORCE(condition) \ - do { \ - cublasStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CUBLAS_STATUS_SUCCESS, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::cublasGetErrorString(status)); \ - } while (0) -#define CUBLAS_CHECK(condition) \ - do { \ - cublasStatus_t status = condition; \ - CHECK(status == CUBLAS_STATUS_SUCCESS) \ - << ::caffe2::cublasGetErrorString(status); \ - } while (0) - -#define CURAND_ENFORCE(condition) \ - do { \ - curandStatus_t status = condition; \ - CAFFE_ENFORCE_EQ( \ - status, \ - CURAND_STATUS_SUCCESS, \ - "Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::curandGetErrorString(status)); \ - } while (0) -#define CURAND_CHECK(condition) \ - do { \ - curandStatus_t status = condition; \ - CHECK(status == CURAND_STATUS_SUCCESS) \ - << ::caffe2::curandGetErrorString(status); \ - } while (0) - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) \ - for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \ - j += blockDim.y * gridDim.y) - -// The following helper functions are here so that you can write a kernel call -// when you are not particularly interested in maxing out the kernels' -// performance. Usually, this will give you a reasonable speed, but if you -// really want to find the best performance, it is advised that you tune the -// size of the blocks and grids more reasonably. -// A legacy note: this is derived from the old good Caffe days, when I simply -// hard-coded the number of threads and wanted to keep backward compatibility -// for different computation capabilities. -// For more info on CUDA compute capabilities, visit the NVidia website at: -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities - -// The number of cuda threads to use. Since work is assigned to SMs at the -// granularity of a block, 128 is chosen to allow utilizing more SMs for -// smaller input sizes. -// 1D grid -constexpr int CAFFE_CUDA_NUM_THREADS = 128; -// 2D grid -constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMX = 16; -constexpr int CAFFE_CUDA_NUM_THREADS_2D_DIMY = 16; - -// The maximum number of blocks to use in the default kernel call. We set it to -// 4096 which would work for compute capability 2.x (where 65536 is the limit). -// This number is very carelessly chosen. 
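As a quick illustration of how these launch helpers were meant to be used together (a sketch, not code from this PR; the kernel, function names, and stream handling here are made up for illustration):

#include "caffe2/core/common_gpu.h"

// Each thread strides through the index space, so the capped grid produced by
// CAFFE_GET_BLOCKS (defined just below) still covers all N elements.
__global__ void ScaleKernel(const int N, const float alpha, float* x) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    x[i] *= alpha;
  }
}

void Scale(const int N, const float alpha, float* x, cudaStream_t stream) {
  ScaleKernel<<<
      caffe2::CAFFE_GET_BLOCKS(N),
      caffe2::CAFFE_CUDA_NUM_THREADS,
      0,
      stream>>>(N, alpha, x);
}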
Ideally, one would like to look at -// the hardware at runtime, and pick the number of blocks that makes most -// sense for the specific runtime environment. This is a todo item. -// 1D grid -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS = 4096; -// 2D grid -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX = 128; -constexpr int CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY = 128; - -constexpr int kCUDAGridDimMaxX = 2147483647; -constexpr int kCUDAGridDimMaxY = 65535; -constexpr int kCUDAGridDimMaxZ = 65535; - -/** - * @brief Compute the number of blocks needed to run N threads. - */ -inline int CAFFE_GET_BLOCKS(const int N) { - return std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS, - CAFFE_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block - 1); -} - -/** - * @brief Compute the number of blocks needed to run N threads for a 2D grid - */ -inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) { - dim3 grid; - // Not calling the 1D version for each dim to keep all constants as literals - - grid.x = std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS_2D_DIMX - 1) / - CAFFE_CUDA_NUM_THREADS_2D_DIMX, - CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMX), - // Use at least 1 block, since CUDA does not allow empty block - 1); - - grid.y = std::max( - std::min( - (N + CAFFE_CUDA_NUM_THREADS_2D_DIMY - 1) / - CAFFE_CUDA_NUM_THREADS_2D_DIMY, - CAFFE_MAXIMUM_NUM_BLOCKS_2D_DIMY), - // Use at least 1 block, since CUDA does not allow empty block - 1); - - return grid; -} - -using CUDAGuard = c10::cuda::CUDAGuard; - -template -struct SimpleArray { - T data[N]; -}; - -constexpr int kCUDATensorMaxDims = 8; - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(val, Func, T, ...) \ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2(val, Func, T1, T2, ...) \ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -#define DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3(val, Func, T1, T2, T3, ...) 
\ - do { \ - CAFFE_ENFORCE_LE(val, kCUDATensorMaxDims); \ - switch (val) { \ - case 1: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 4: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 5: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 6: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 7: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - case 8: { \ - Func(__VA_ARGS__); \ - break; \ - } \ - default: { \ - break; \ - } \ - } \ - } while (false) - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_GPU_H_ diff --git a/caffe2/core/context.h b/caffe2/core/context.h deleted file mode 100644 index eb46f78f8b0d..000000000000 --- a/caffe2/core/context.h +++ /dev/null @@ -1,227 +0,0 @@ -#ifndef CAFFE2_CORE_CONTEXT_H_ -#define CAFFE2_CORE_CONTEXT_H_ - -#include -#include -#include -#include - -#include -#include "caffe2/core/allocator.h" -#include "caffe2/core/context_base.h" -#include "caffe2/core/event.h" -#include "caffe2/core/logging.h" -#include "caffe2/proto/caffe2_pb.h" - -#include - -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -#include -#include -#include -#include -#else -#include "caffe2/core/distributions_stubs.h" -#endif - -C10_DECLARE_bool(caffe2_report_cpu_memory_usage); - -namespace caffe2 { - -/** - * A function to generate a random number seed that is unique in a best-effort - * basis, using an ever-incrementing seed and the current time. - */ -TORCH_API uint32_t RandomNumberSeed(); - -/** - * The CPU Context, representing the bare minimum of what a Context class in - * Caffe2 should implement. - * - * // TODO modify docs - * See operator.h, especially Operator, for how Context are used in - * actual operator implementations that are associated with specific devices. - * In general, the Context class is passed in as a template argument, and - * the operator can use the functions defined in the context to execute whatever - * computation it has. - * - */ -class TORCH_API CPUContext final : public BaseContext { - public: -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - class rand_gen_type { - public: - explicit rand_gen_type(uint64_t seed_in = default_rng_seed_val) - : engine_{seed_in} {} - - uint32_t random() { - return engine_(); - } - uint64_t random64() { - uint32_t random1 = engine_(); - uint32_t random2 = engine_(); - return (static_cast(random1) << 32) | random2; - } - - std::optional next_float_normal_sample() { - return next_float_normal_sample_; - } - std::optional next_double_normal_sample() { - return next_double_normal_sample_; - } - void set_next_float_normal_sample(std::optional randn) { - next_float_normal_sample_ = randn; - } - void set_next_double_normal_sample(std::optional randn) { - next_double_normal_sample_ = randn; - } - - private: - at::mt19937 engine_; - std::optional next_float_normal_sample_; - std::optional next_double_normal_sample_; - }; -#else - typedef std::mt19937 rand_gen_type; -#endif - - CPUContext() {} - explicit CPUContext(const DeviceOption& option) - : random_seed_(option.has_random_seed() ? option.random_seed() : 1701), - random_seed_set_(option.has_random_seed() ? 
true : false) { - CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU); - } - explicit CPUContext(const at::Device& device) - : CPUContext(DeviceToOption(device)) {} - - ~CPUContext() noexcept override {} - - inline void SwitchToDevice(int64_t /*stream_id*/) override {} - - using BaseContext::SwitchToDevice; - - inline void WaitEvent(const Event& ev) override { - ev.Wait(CPU, this); - } - - inline void Record(Event* ev, const char* err_msg = nullptr) const override { - CAFFE_ENFORCE(ev, "Event must not be null."); - ev->Record(CPU, this, err_msg); - } - - inline void FinishDeviceComputation() override {} - - inline rand_gen_type* RandGenerator() { - if (!random_generator_.get()) { - random_generator_.reset(new rand_gen_type(RandSeed())); - } - return random_generator_.get(); - } - - inline uint32_t RandSeed() { - if (!random_seed_set_) { - random_seed_ = RandomNumberSeed(); - random_seed_set_ = true; - } - return static_cast(random_seed_); - } - - inline static at::DataPtr New(size_t nbytes) { - return GetCPUAllocator()->allocate(nbytes); - } - - void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override; - - void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytesSameDevice(nbytes, src, dst); - } - - void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytesSameDevice(nbytes, src, dst); - } - - bool SupportsNonFundamentalTypes() const override { - // CPU non fumdamental type copy OK - return true; - } - - template - inline void CopyBytes(size_t nbytes, const void* src, void* dst); - - template - inline void Copy(size_t n, const T* src, T* dst) { - if (c10::guts::is_fundamental::value) { - CopyBytes( - n * sizeof(T), - static_cast(src), - static_cast(dst)); - } else { - for (const auto i : c10::irange(n)) { - dst[i] = src[i]; - } - } - } - - template - inline void - CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { - if (meta.copy()) { - meta.copy()(src, dst, n); - } else { - CopyBytes(n * meta.itemsize(), src, dst); - } - } - - // By default CPU operators don't have async device parts - static bool HasAsyncPartDefault() { - return false; - } - - static bool SupportsAsyncScheduling() { - return false; - } - - // CPU streams are not implemented and are silently ignored by CPU ops, - // return true to signal executor to schedule a CPU op - static bool IsStreamFree( - const DeviceOption& /* option */, - int /* stream_id */) { - return true; - } - - at::Device device() const override { - // TODO: numa? - return at::Device(CPU); - } - - DeviceType device_type() const override { - return CPU; - } - - static constexpr DeviceType GetDeviceType() { - return CPU; - } - - protected: - // TODO(jiayq): instead of hard-coding a generator, make it more flexible. 
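For reference, a minimal sketch of how the generator above was typically consumed (assuming the non-mobile build, where rand_gen_type wraps at::mt19937; the function and the [0, 1) mapping are illustrative, not from this PR):

#include <cstddef>
#include <vector>

#include "caffe2/core/context.h"

void FillUniformSketch(std::vector<float>* out, size_t n) {
  caffe2::CPUContext ctx;            // seed comes from RandomNumberSeed()
                                     // unless a DeviceOption supplied one
  auto* gen = ctx.RandGenerator();   // lazily created on first use
  out->resize(n);
  for (size_t i = 0; i < n; ++i) {
    // random() returns a raw 32-bit draw; scale it into [0, 1).
    (*out)[i] = static_cast<float>(gen->random()) / 4294967296.0f;
  }
}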
- int random_seed_{1701}; - bool random_seed_set_{false}; - std::unique_ptr random_generator_; -}; - -template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - if (nbytes == 0) { - return; - } - CAFFE_ENFORCE(src); - CAFFE_ENFORCE(dst); - memcpy(dst, src, nbytes); -} - -} // namespace caffe2 - -#endif // CAFFE2_CORE_CONTEXT_H_ diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h deleted file mode 100644 index cc8cc4c5bb60..000000000000 --- a/caffe2/core/context_base.h +++ /dev/null @@ -1,168 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -class Event; - -} // namespace caffe2 -namespace at { - -class BaseContext; - -/** - * Virtual interface for the Context class in Caffe2. - * - * A Context defines all the necessities to run an operator on a specific - * device. Specific Context classes needs to implement all the pure virtual - * functions in the BaseContext class. - * TODO: add docs after this is finalized. - */ -class TORCH_API BaseContext { - public: - virtual ~BaseContext() noexcept {} - - virtual Device device() const = 0; - - /* Sorry for the naming, will get rid of this in future diff */ - virtual DeviceType device_type() const = 0; - - virtual void SwitchToDevice(int64_t /*stream_id*/) = 0; - - inline void SwitchToDevice() { - SwitchToDevice(0); - } - - virtual void WaitEvent(const caffe2::Event& ev) = 0; - - virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr) - const = 0; - - virtual void FinishDeviceComputation() = 0; - - // This used to be arbitrary cross-device copy, but it turns out everyone - // did direct CPU-X copy, so we just make three functions for it (to avoid - // double dispatch). This will get obsoleted by C10. where copies - // will be proper operators (and get to rely on multiple dispatch there.) 
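Before the copy interface itself, a hedged usage sketch (not from this PR): device_buf is assumed to be a buffer already allocated on whatever device ctx represents.

#include <vector>

#include "caffe2/core/context_base.h"

void RoundTripSketch(at::BaseContext* ctx,
                     const std::vector<float>& host_in,
                     float* device_buf,
                     std::vector<float>* host_out) {
  // Host -> device, via the typed wrapper around CopyBytesFromCPU.
  ctx->CopyFromCPU<float>(host_in.size(), host_in.data(), device_buf);
  host_out->resize(host_in.size());
  // Device -> host.
  ctx->CopyToCPU<float>(host_in.size(), device_buf, host_out->data());
  // Asynchronous contexts (e.g. CUDA) flush their pending work here.
  ctx->FinishDeviceComputation();
}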
- virtual void CopyBytesSameDevice( - size_t nbytes, - const void* src, - void* dst) = 0; - - virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0; - - virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0; - - template - inline void CopySameDevice(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, - "CopySameDevice requires fundamental types"); - CopyBytesSameDevice( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyFromCPU(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, - "CopyFromCPU requires fundamental types"); - CopyBytesFromCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyToCPU(size_t n, const T* src, T* dst) { - static_assert( - c10::guts::is_fundamental::value, "CopyToCPU requires fundamental types"); - CopyBytesToCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - virtual bool SupportsNonFundamentalTypes() const { - return false; - } - - inline void EnforceMetaCopyOK() { - AT_ASSERTM( - SupportsNonFundamentalTypes(), "Context requires fundamental types"); - } - - void CopyItemsSameDevice( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesSameDevice(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsFromCPU( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesFromCPU(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsToCPU( - const caffe2::TypeMeta meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesToCPU(n * meta.itemsize(), src, dst); - } - } -}; - -// Context constructor registry -C10_DECLARE_TYPED_REGISTRY( - ContextRegistry, - at::DeviceType, - at::BaseContext, - std::unique_ptr, - at::Device); - -#define REGISTER_CONTEXT(type, ...) \ - C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__) - -inline std::unique_ptr CreateContext( - const at::Device& device) { - return at::ContextRegistry()->Create(device.type(), device); -} - -} // namespace at - -namespace caffe2 { - -using at::BaseContext; -using at::CreateContext; -} // namespace caffe2 diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu deleted file mode 100644 index ecc933ac7fad..000000000000 --- a/caffe2/core/context_gpu.cu +++ /dev/null @@ -1,669 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include "cub/util_allocator.cuh" - -// Needed to be included first to check the CAFFE2_USE_CUDNN macros. -#include "caffe2/core/macros.h" - -#include "caffe2/core/blob_stats.h" -#ifdef CAFFE2_USE_CUDNN -#include "caffe2/core/common_cudnn.h" -#endif // CAFFE2_USE_CUDNN -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/utils/string_utils.h" -#include "caffe2/utils/cub_namespace.cuh" - -C10_DEFINE_string( - caffe2_cuda_memory_pool, - "", - "Sets the memory pool used by caffe2. 
Possible values are " - "none, cnmem, thc and cub."); - -// For description of CUB caching allocator configuration, see -// https://nvlabs.github.io/cub/structcub_1_1_caching_device_allocator.html -C10_DEFINE_int( - caffe2_cub_bin_growth, - 8, - "If using cub as the memory allocator, sets the growth of bins " - "used by the cub pool."); -C10_DEFINE_int( - caffe2_cub_min_bin, - 3, - "If using cub as the memory allocator, sets the min number of " - "bins."); -C10_DEFINE_int( - caffe2_cub_max_bin, - 10, - "If using cub as the memory allocator, sets the max number of " - "bins."); -C10_DEFINE_int( - caffe2_cub_max_managed_mb, - 10 * 1024, - "If using cub as the memory allocators, sets the maximum amount " - "of memory managed in gigabytes"); - -C10_DEFINE_bool( - caffe2_cub_print_allocation_events, - false, - "If true CachingDeviceAllocator will print allocation and deallocation " - "events to stdout."); - -C10_DEFINE_bool( - caffe2_gpu_memory_tracking, - false, - "If set, logs changes in GPU memory allocations"); -C10_DEFINE_int( - caffe2_gpu_memory_report_interval_mb, - 128, - "The threshold in MB on how frequently to report memory changes"); - -namespace at { - -REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext); -} // namespace at - -namespace caffe2 { - -// Generic implementation - CUDA will handle the right function to call for us -void CUDAContext::CopyBytesAsync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device) { - // TODO: verify that the CUDA handles copy from device to device correctly - // even without SetDevice() - // TODO: verify whether source or dest device should be a priority in picking - // the stream - // NB: right now the cross-device copy logic is invoked only in the contexts - // when surrounding code explicitly manages data dependencies and sets up - // events, so it's fine. In order to make it a standalone function proper - // synchronization between stream is required - int gpu_id = 0; - if (dst_device.is_cuda()) { - gpu_id = dst_device.index(); - } else if (src_device.is_cuda()) { - gpu_id = src_device.index(); - } else { - LOG(FATAL) << "shouldn't be called with non-cuda device"; - } - CUDA_ENFORCE(cudaMemcpyAsync( - dst, - src, - nbytes, - cudaMemcpyDefault, - CUDAContext::getCudaObjects().GetStream(gpu_id))); -} - -void CUDAContext::CopyBytesSync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device) { - // This emulates Caffe2 original behavior where sync copy doesn't change the - // device. It's probably better for clarity to switch to the target device - // explicitly here, but in the worst case CUDA would sync for us. - // TODO: change it to CUDAGuard - CUDAContext context(-1); // take current device - CUDA_ENFORCE(cudaMemcpyAsync( - dst, src, nbytes, cudaMemcpyDefault, context.cuda_stream())); - // destructor of context synchronizes -} - -// For the CPU context, we also allow a (probably expensive) function -// to copy the data from a cuda context. Inside the function, we create -// a temporary CUDAContext object to carry out the copy. From the caller's -// side, these functions are synchronous with respect to the host, similar -// to a normal CPUContext::CopyBytes call. 
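Both copy paths above lean on cudaMemcpyDefault, which lets the runtime infer the transfer direction from unified virtual addresses. A minimal standalone sketch of the same pattern, using only the CUDA runtime API (the helper names are illustrative):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

static void check(cudaError_t err) {
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    std::abort();
  }
}

// Queue a copy on `stream`; with cudaMemcpyDefault the runtime works out
// host-to-device, device-to-host, or device-to-device from the pointers.
void copy_bytes(const void* src, void* dst, size_t nbytes,
                cudaStream_t stream, bool synchronous) {
  check(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDefault, stream));
  if (synchronous) {
    // The deleted sync path gets the same effect by letting the temporary
    // CUDAContext's destructor synchronize its stream.
    check(cudaStreamSynchronize(stream));
  }
}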
-template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - CUDAContext context(GetGPUIDForPointer(src)); - context.CopyBytes(nbytes, src, dst); -} -template <> -inline void CPUContext::CopyBytes( - size_t nbytes, - const void* src, - void* dst) { - CUDAContext context(GetGPUIDForPointer(dst)); - context.CopyBytes(nbytes, src, dst); -} - -} // namespace caffe2 - -namespace caffe2 { - -ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() { - static thread_local ThreadLocalCUDAObjects cuda_objects_; - return cuda_objects_; -} - -// TODO(jiayq): these variables shouldn't be currently accessed during static -// initialization. We should consider moving them to a Mayer's singleton to -// be totally safe against SIOF. - -// Static global variables for setting up the memory pool. -CudaMemoryPoolType g_cuda_memory_pool_type; - -std::unique_ptr g_cub_allocator; - -// an unordered map that holds the map from the cuda memory pointer to the -// device id that it is allocated from. This is used in the cuda memory pool -// cases, where we need the device id to carry out the deletion. -// Note(jiayq): an alternate approach is to use cudaGetPointerAttributes, but -// that is usually quite slow. We might want to benchmark the speed difference -// though. -// Note(jiayq): another alternate approach is to augment the Tensor class that -// would allow one to record the device id. However, this does not address any -// non-tensor allocation and deallocation. -// Ideally, a memory pool should already have the device id information, as -// long as we are using UVA (as of CUDA 5 and later) so the addresses are -// unique. -static std::unordered_map g_cuda_device_affiliation; - -// Data structures for optional memory tracking. Access to these structures -// is guarded by the CUDAContext::mutex. -static std::unordered_map g_size_map; -static std::vector g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); -static std::vector g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0); - -static long g_total_mem = 0; -static long g_last_rep = 0; - -CudaMemoryPoolType GetCudaMemoryPoolType() { - return g_cuda_memory_pool_type; -} - -/////////////////////////////////////////////////////////////////////////////// -// A wrapper to allow us to lazily initialize all cuda environments that Caffe -// uses. This gets done the first time a caffe2::CUDAContext::New() gets called -// which is probably the decisive indication that this caffe2 run is going to -// use GPUs. We avoid cuda initialization with core/init.h functionalities so -// that we have minimal resource impact in case we will need to run multiple -// caffe2 instances on a GPU machine. -/////////////////////////////////////////////////////////////////////////////// - -static void Caffe2InitializeCuda() { - // If the current run does not have any cuda devices, do nothing. - if (!HasCudaGPU()) { - VLOG(1) << "No cuda gpu present. Skipping."; - return; - } - C10_LOG_API_USAGE_ONCE("caffe2.init.cuda"); - // Check if the number of GPUs matches the expected compile-time max number - // of GPUs. - CAFFE_ENFORCE_LE( - NumCudaDevices(), - C10_COMPILE_TIME_MAX_GPUS, - "Number of CUDA devices on the machine is larger than the compiled " - "max number of gpus expected (", - C10_COMPILE_TIME_MAX_GPUS, - "). Increase that and recompile."); - - for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) { - CUDAGuard g(i); - // Enable peer access. 
- const int peer_group = i / CAFFE2_CUDA_MAX_PEER_SIZE; - const int peer_start = peer_group * CAFFE2_CUDA_MAX_PEER_SIZE; - const int peer_end = std::min( - NumCudaDevices(), (peer_group + 1) * CAFFE2_CUDA_MAX_PEER_SIZE); - VLOG(1) << "Enabling peer access within group #" << peer_group - << ", from gpuid " << peer_start << " to " << peer_end - 1 - << ", for gpuid " << i << "."; - - for (int j = peer_start; j < peer_end; ++j) { - if (i == j) continue; - int can_access; - CUDA_ENFORCE(cudaDeviceCanAccessPeer(&can_access, i, j)); - if (can_access) { - VLOG(1) << "Enabling peer access from " << i << " to " << j; - // Note: just for future reference, the 0 here is not a gpu id, it is - // a reserved flag for cudaDeviceEnablePeerAccess that should always be - // zero currently. - // It is ok if peer access is already enabled... - cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0)); - if ((err != cudaErrorPeerAccessAlreadyEnabled) && - (err != cudaSuccess)) { - CAFFE_THROW(cudaGetErrorString(err)); - } - cudaGetLastError(); // reset cuda error code - } - } - } - -#ifdef CAFFE2_USE_CUDNN - // Check the versions of cuDNN that were compiled and linked with are compatible - CheckCuDNNVersions(); -#endif // CAFFE2_USE_CUDNN -} - -static void SetUpCub() { - VLOG(1) << "Setting up cub memory pool."; - // Sets up the cub memory pool - try { - g_cub_allocator.reset(new cub::CachingDeviceAllocator( - FLAGS_caffe2_cub_bin_growth, - FLAGS_caffe2_cub_min_bin, - FLAGS_caffe2_cub_max_bin, - size_t(FLAGS_caffe2_cub_max_managed_mb) * 1024L * 1024L, - false, - FLAGS_caffe2_cub_print_allocation_events)); - } catch (...) { - CAFFE_THROW("Some error happened at cub initialization."); - } - VLOG(1) << "Done setting up cub memory pool."; -} - -static void Caffe2SetCUDAMemoryPool() { - if (FLAGS_caffe2_cuda_memory_pool == "" || - FLAGS_caffe2_cuda_memory_pool == "none") { - g_cuda_memory_pool_type = CudaMemoryPoolType::NONE; - } else if (FLAGS_caffe2_cuda_memory_pool == "cnmem") { - CAFFE_THROW("CNMEM is no longer used by Caffe2. Use cub instead. " - "This error message may go away in the future."); - } else if (FLAGS_caffe2_cuda_memory_pool == "cub") { - // Sets up cub. - g_cuda_memory_pool_type = CudaMemoryPoolType::CUB; - SetUpCub(); - } else if (FLAGS_caffe2_cuda_memory_pool == "thc") { - g_cuda_memory_pool_type = CudaMemoryPoolType::THC; - // Initialize caching allocator - at::globalContext().lazyInitCUDA(); - } else { - CAFFE_THROW( - "Unrecognized cuda memory pool type: ", FLAGS_caffe2_cuda_memory_pool); - } -} - -/** - * An allocator that does the CPU memory allocation with pinned memory. - * - * This is needed because if we want to do any asynchronous cuda memcpy, - * the underlying CPU memory also needs to be allocated into pinned memory - * space. As a result, whenever Caffe2 is built with GPU and there is - * GPU present during runtime, at global initialization time we will set - * the CPU memory allocator to allocate pinned memory. - * - * NB: This behavior is probably too aggressive. We should consider asking users - * to do on-demand memory pinning (like exposed in PyTorch APIs) instead. 
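The peer-access setup above walks devices in fixed-size groups; stripped of the grouping and logging, the core runtime calls reduce to the following sketch (error handling simplified relative to the original):

#include <cuda_runtime.h>

// Enable access from `dev` to `peer` when the hardware allows it; an
// "already enabled" result is tolerated, matching the deleted code.
void enable_peer_access(int dev, int peer) {
  int can_access = 0;
  if (cudaDeviceCanAccessPeer(&can_access, dev, peer) != cudaSuccess || !can_access) {
    return;
  }
  cudaSetDevice(dev);
  cudaError_t err = cudaDeviceEnablePeerAccess(peer, /*flags=*/0);
  if (err != cudaSuccess && err != cudaErrorPeerAccessAlreadyEnabled) {
    // a real implementation should report cudaGetErrorString(err) here
  }
  cudaGetLastError();  // clear any sticky error left by the call above
}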
- */ -struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator { - PinnedCPUAllocator() { - baseAllocator_ = GetDefaultCPUAllocator(); - } - ~PinnedCPUAllocator() override {} - at::DataPtr allocate(size_t nbytes) override { - if (nbytes == 0) { - // replicate c10::alloc_cpu behavior - return nullptr - return {nullptr, nullptr, &Delete, at::Device(CPU)}; - } - void* data; - at::DataPtr data_ptr; - std::lock_guard lock(CUDAContext::mutex()); - if (IsNUMAEnabled()) { - at::DeleterFnPtr expected_deleter = baseAllocator_->raw_deleter(); - data_ptr = baseAllocator_->allocate(nbytes); - data = data_ptr.get(); - CAFFE_ENFORCE(data); - CUDA_ENFORCE(cudaHostRegister(data, nbytes, cudaHostRegisterDefault)); - CAFFE_ENFORCE( - data_ptr.compare_exchange_deleter(expected_deleter, &Delete), - "Failed to swap deleter (already swapped?)"); - } else { - CUDA_ENFORCE(cudaMallocHost(&data, nbytes)); - profiledCPUMemoryReporter().New(data, nbytes); - data_ptr = {data, data, &Delete, at::Device(CPU)}; - } - memset(data, 0, nbytes); - return data_ptr; - } - - at::DeleterFnPtr raw_deleter() const override { - return &Delete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for PinnedCPUAllocator"); - } - - private: - static void Delete(void* data) { - if (!data) { - return; - } - // Caffe2 uses a lazy way to figure out if one is actually going to use GPUs - // or not. If a CUDAContext::New() call is made, inside the CUDAContext - // function we will switch the cpu side allocator to a PinnedCPUAllocator. - // But, if one calls CPUContext::New() before any cuda allocations, - // PinnedCPUAllocator can still delete the corresponding memory. - std::lock_guard lock(CUDAContext::mutex()); - if (IsNUMAEnabled()) { - CUDA_ENFORCE(cudaHostUnregister(data)); - GetDefaultCPUAllocator()->raw_deleter()(data); - } else { - cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaFreeHost(data)); - profiledCPUMemoryReporter().Delete(data); - if (err == cudaErrorInvalidValue) { - free(data); - // Calling cudaGetLastError will reset the cuda error. - cudaError_t _err = cudaGetLastError(); - } else { - // For all other errors, still do a cuda check. - CUDA_ENFORCE(err); - } - } - } - - at::Allocator* baseAllocator_; -}; - -static PinnedCPUAllocator g_pinned_cpu_alloc; - -// An initialization function that sets the CPU side to use pinned cpu -// allocator. -void Caffe2UsePinnedCPUAllocator() { -#if C10_ASAN_ENABLED - // Note(jiayq): for more details, see - // https://github.com/google/sanitizers/issues/629 - LOG(WARNING) << "There are known issues between address sanitizer and " - "cudaMallocHost. As a result, caffe2 will not enable pinned " - "memory allocation in asan mode. If you are expecting any " - "behavior that depends on asan, be advised that it is not " - "turned on."; -#else - if (!HasCudaGPU()) { - VLOG(1) << "No GPU present. I won't use pinned allocator then."; - return; - } - VLOG(1) << "Caffe2 gpu: setting CPUAllocator to PinnedCPUAllocator."; - - // If CUDA is enabled, using CPU allocators other than PinnedCPUAllocator - // will cause memory corruptions. Therefore, we need to set the priority - // to highest to avoid being overwritten. 
- SetCPUAllocator( - &g_pinned_cpu_alloc, - std::numeric_limits::max() /* priority */); -#endif -} - -// Caffe2CudaInitializerHelper is a minimal struct whose sole purpose is to -// detect the first hint that this Caffe2 run is going to use GPU: either -// CUDAContext is initialized or CUDAContext::New is called. It then runs -// all the related cuda initialization functions. -namespace { -struct Caffe2CudaInitializerHelper { - Caffe2CudaInitializerHelper() { - // We cannot use bool because nvcc changes bool to __nv_bool which does - // not have a std::atomic instantiation. - static std::atomic first_call(1); - if (first_call.fetch_and((char)0)) { - Caffe2InitializeCuda(); - Caffe2SetCUDAMemoryPool(); - Caffe2UsePinnedCPUAllocator(); - } - } -}; -} // namespace - -/** - * A utility function to rectify the gpu id. If the context specifies the - * gpu id to be -1, it means that we will just use the current gpu id when - * the function is being called. - */ -static inline DeviceIndex RectifyGPUID(DeviceIndex gpu_id) { - return gpu_id == -1 ? CaffeCudaGetDevice() : gpu_id; -} - -CUDAContext::CUDAContext(DeviceIndex gpu_id) - : gpu_id_(RectifyGPUID(gpu_id)), random_seed_(RandomNumberSeed()) { - static Caffe2CudaInitializerHelper g_cuda_initializer_; -} - -CUDAContext::CUDAContext(const DeviceOption& option) - : gpu_id_( - option.has_device_id() ? RectifyGPUID(option.device_id()) - : CaffeCudaGetDevice()), - random_seed_( - option.has_random_seed() ? option.random_seed() - : RandomNumberSeed()) { - static Caffe2CudaInitializerHelper g_cuda_initializer_; - TORCH_DCHECK_EQ(option.device_type(), PROTO_CUDA); -} - -CUDAContext::~CUDAContext() { - try { - if (curand_generator_) { - CURAND_CHECK(curandDestroyGenerator(curand_generator_)); - } - // CUDAContext is used in 2 cases now: - // - long-lived instance inside OperatorBase in which case what happens in - // destructor doesn't really matter - // - short-lived on-the-fly instances that are utilized as CUDAGuard - in - // this case there's only one stream id (passed to SwitchToDevice) and - // it's preferrable to synchronize in the destructor - FinishDeviceComputation(); - } catch (const std::exception& e) { - LOG(ERROR) << "Encountered following in " << __FUNCTION__ << ": " << e.what(); - } -} - -// shared mutex to lock out alloc / free during NCCL launches -std::mutex& CUDAContext::mutex() { - static std::mutex m; - return m; -} - -std::vector CUDAContext::TotalMemoryByGpu() { - std::lock_guard lock(CUDAContext::mutex()); - CAFFE_ENFORCE( - FLAGS_caffe2_gpu_memory_tracking, - "Pass --caffe2_gpu_memory_tracking to enable memory stats"); - return g_total_by_gpu_map; -} - -std::vector CUDAContext::MaxMemoryByGpu() { - std::lock_guard lock(CUDAContext::mutex()); - CAFFE_ENFORCE( - FLAGS_caffe2_gpu_memory_tracking, - "Pass --caffe2_gpu_memory_tracking to enable memory stats"); - return g_max_by_gpu_map; -} - -namespace { -void TrackMemoryAlloc(size_t nbytes) { - int this_gpu = CaffeCudaGetDevice(); - g_total_by_gpu_map[this_gpu] += nbytes; - g_max_by_gpu_map[this_gpu] = - std::max(g_max_by_gpu_map[this_gpu], g_total_by_gpu_map[this_gpu]); - g_total_mem += nbytes; - if (g_total_mem - g_last_rep > - FLAGS_caffe2_gpu_memory_report_interval_mb * 1024 * 1024) { - for (int gpu = 0; gpu < g_total_by_gpu_map.size(); gpu++) { - long t = g_total_by_gpu_map[gpu]; - long max_t = g_max_by_gpu_map[gpu]; - if (max_t > 0) { - if (max_t != t) { - VLOG(1) << "GPU " << gpu << ": " << t / 1024 / 1024 << " MB" - << " (max: " << max_t / 1024 / 1024 << " MB)"; - } else { - 
VLOG(1) << "GPU " << gpu << ": " << t / 1024 / 1024 << " MB"; - } - } - } - VLOG(1) << "Total: " << g_total_mem / 1024 / 1024 << " MB"; - g_last_rep = g_total_mem; - } -} -} - -struct DefaultCUDAAllocator final : public at::Allocator { - DefaultCUDAAllocator() {} - ~DefaultCUDAAllocator() override {} - at::DataPtr allocate(size_t nbytes) override { - // Lock the mutex - std::lock_guard lock(CUDAContext::mutex()); - // A one-time caffe2 cuda initializer. - static Caffe2CudaInitializerHelper g_cuda_initializer_; - void* ptr = nullptr; - - if (FLAGS_caffe2_gpu_memory_tracking) { - TrackMemoryAlloc(nbytes); - } - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: - if (nbytes != 0) { - CUDA_ENFORCE(cudaMalloc(&ptr, nbytes)); - } - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - case CudaMemoryPoolType::CUB: - if (nbytes != 0) { - CUDA_ENFORCE(g_cub_allocator->DeviceAllocate(&ptr, nbytes)); - } - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - VLOG(2) << "CUB allocating pointer " << ptr << " on device " - << CaffeCudaGetDevice(); - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - case CudaMemoryPoolType::THC: - { - // The reason we have this stream guard here is to preserve - // the historical behavior of the 'thc' allocator in Caffe2, - // which is to put all allocations on the same (default) - // stream. This behavior is morally wrong (since passing - // allocations between streams allows for the possibility - // of you handing out some memory that an old stream - // is still working on), but it doesn't seem to cause issues - // in Caffe2 today. Our hypothesis for why this is the case - // is that Caffe2 doesn't really do very many allocations - // on the fly; instead they allocate once and then reuse - // the allocations for the whole program. In this case, - // the hazard is avoided. - // - // We intend to remove this stream guard, but the benefit - // to putting all allocations on the same stream is it - // reduces per-stream fragmentation, and this helps - // some models that are currently running with the thc - // allocator fit in memory. We will need to find some - // way of resolving this problem. 
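For reference, the pool-less ("none") branch of the allocator below boils down to a raw cudaMalloc plus a side table recording which device owns each pointer, and a deleter that tolerates frees racing with CUDA runtime teardown. A reduced sketch of that branch, with the enforcement and tracking flags left out:

#include <cuda_runtime.h>
#include <mutex>
#include <unordered_map>

static std::mutex g_alloc_mutex;
static std::unordered_map<void*, int> g_device_of;  // mirrors g_cuda_device_affiliation

void* raw_gpu_alloc(size_t nbytes) {
  std::lock_guard<std::mutex> lock(g_alloc_mutex);
  void* ptr = nullptr;
  if (nbytes != 0 && cudaMalloc(&ptr, nbytes) != cudaSuccess) {
    return nullptr;  // the original enforces (throws) instead of returning null
  }
  int device = 0;
  cudaGetDevice(&device);
  g_device_of[ptr] = device;
  return ptr;
}

void raw_gpu_free(void* ptr) {
  std::lock_guard<std::mutex> lock(g_alloc_mutex);
  cudaError_t err = cudaFree(ptr);
  if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
    // the original logs a fatal error here; teardown-time frees are ignored
  }
  g_device_of.erase(ptr);
}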
- c10::cuda::CUDAStreamGuard g( - Stream( - Stream::DEFAULT, - Device(kCUDA, CaffeCudaGetDevice()) - )); - ptr = c10::cuda::CUDACachingAllocator::raw_alloc(nbytes); - } - if (FLAGS_caffe2_gpu_memory_tracking) { - g_size_map[ptr] = nbytes; - g_cuda_device_affiliation[ptr] = CaffeCudaGetDevice(); - } - return {ptr, ptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - } - return {nullptr, nullptr, &Delete, at::Device(CUDA, CaffeCudaGetDevice())}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &Delete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for DefaultCUDAAllocator"); - } - - private: - static void Delete(void* ptr) { - // lock the mutex - std::lock_guard lock(CUDAContext::mutex()); - if (FLAGS_caffe2_gpu_memory_tracking) { - auto sz_it = g_size_map.find(ptr); - DCHECK(sz_it != g_size_map.end()); - auto aff_it = g_cuda_device_affiliation.find(ptr); - DCHECK(aff_it != g_cuda_device_affiliation.end()); - g_total_mem -= sz_it->second; - g_total_by_gpu_map[aff_it->second] -= sz_it->second; - g_size_map.erase(sz_it); - } - - switch (g_cuda_memory_pool_type) { - case CudaMemoryPoolType::NONE: { - // If memory pool is not set up, use simple cudaFree. - cudaError_t error = C10_CUDA_ERROR_HANDLED(cudaFree(ptr)); - // For some reason, in Python runtime we sometimes delete a data pointer - // after the cuda runtime exits - this is odd but is probably caused by - // a static workspace that pycaffe2 uses, and the destruction got - // entangled in some race condition. Anyway, since cuda runtime is - // exiting anyway, we will not need to worry about memory leak, so we - // basically ignore it. This is definitely not ideal but works for now. - if (error != cudaSuccess && error != cudaErrorCudartUnloading) { - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " - << cudaGetErrorString(error); - } - - if (FLAGS_caffe2_gpu_memory_tracking) { - g_cuda_device_affiliation.erase(g_cuda_device_affiliation.find(ptr)); - } - - break; - } - case CudaMemoryPoolType::CUB: { - auto it = g_cuda_device_affiliation.find(ptr); - DCHECK(it != g_cuda_device_affiliation.end()); - VLOG(2) << "CUB freeing pointer " << ptr << " on device " << it->second; - CUDA_ENFORCE(g_cub_allocator->DeviceFree(it->second, ptr)); - g_cuda_device_affiliation.erase(it); - break; - } - case CudaMemoryPoolType::THC: { - c10::cuda::CUDACachingAllocator::raw_delete(ptr); - if (FLAGS_caffe2_gpu_memory_tracking) { - g_cuda_device_affiliation.erase(g_cuda_device_affiliation.find(ptr)); - } - break; - } - } - } -}; - -static DefaultCUDAAllocator g_cuda_alloc; -REGISTER_ALLOCATOR(CUDA, &g_cuda_alloc); - -} // namespace caffe2 - -namespace at { -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CUDA, - DeviceType::CUDA, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); - -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CUDA, - DeviceType::CPU, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); - -REGISTER_COPY_BYTES_FUNCTION( - DeviceType::CPU, - DeviceType::CUDA, - caffe2::CUDAContext::CopyBytesSync, - caffe2::CUDAContext::CopyBytesAsync); -} // namespace at diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h deleted file mode 100644 index 8490a5002e5f..000000000000 --- a/caffe2/core/context_gpu.h +++ /dev/null @@ -1,354 +0,0 @@ -#ifndef CAFFE2_CORE_CONTEXT_GPU_H_ -#define CAFFE2_CORE_CONTEXT_GPU_H_ - -#include -#include - -#include "caffe2/core/common.h" -#include 
"caffe2/core/common_gpu.h" -#include "caffe2/core/context.h" -#include "caffe2/core/context_base.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/numa.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/types.h" -#include "caffe2/proto/caffe2_pb.h" - -// Since we are using the macro CAFFE2_USE_CUDNN, we will need to include this -// file after common.h is included. -#ifdef CAFFE2_USE_CUDNN -#include "caffe2/core/common_cudnn.h" -#endif // CAFFE2_USE_CUDNN - -#include -#include -#include -#include - -namespace caffe2 { - -enum class CudaMemoryPoolType { - NONE = 0, - CUB = 1, - THC = 2, -}; - -/** - * Gets the current memory pool type used by Caffe2. - * - * The memory pool is set up during caffe2's global initialization time. - */ -CAFFE2_CUDA_API CudaMemoryPoolType GetCudaMemoryPoolType(); - -/** - * A struct to host thread-local cuda objects. - * - * In Caffe2, each thread has its own non-default cuda stream as well as - * related objects such as cublas and curand handles. This is achieved by - * having the ThreadLocalCUDAObjects wrapper that takes care of allocating - * and deallocating these objects at the thread scope. This class is solely - * used inside CUDAContext and should not be used externally. - * - * This class manages the mapping from logical stream ID (int stream_id - * passed around in Caffe2) and CUDAStream objects. We intend to eventually - * deprecate the logical stream ID interface, but not for now. - */ -class CAFFE2_CUDA_API ThreadLocalCUDAObjects { - friend class CUDAContext; - - private: - ThreadLocalCUDAObjects() { - for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) { - cuda_streams_[i] = vector(); - } - } - - // Record current stream id for the current thread. - // This is the new API we're trying to migrate use cases to and get rid of - // explicit stream id passing. For now it's invoked in - // CUDAContext::SwitchToDevice - void SetCurrentStreamId(DeviceIndex gpu, StreamId stream_id) { - // TODO: use current device id from thread local instead of passing gpu in - if (stream_id != -1) { - c10::cuda::setCurrentCUDAStream(GetCUDAStream(gpu, stream_id)); - } - } - - // Retrieves the CUDAStream corresponding to a logical stream ID, ensuring - // that it exists in cuda_streams_ if it has not been allocated yet. - c10::cuda::CUDAStream GetCUDAStream(DeviceIndex gpu, StreamId stream_id) { - vector& gpu_streams = cuda_streams_[gpu]; - while (gpu_streams.size() <= static_cast(stream_id)) { - // NB: This streams are not guaranteed to be unique; we'll - // wrap around once we run out of streams in the pool. 
- gpu_streams.emplace_back(c10::cuda::getStreamFromPool(/* high priority */ false, gpu)); - } - return gpu_streams[stream_id]; - } - - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cudaStream_t GetStream(DeviceIndex gpu) { - return c10::cuda::getCurrentCUDAStream(gpu).stream(); - } - - cudaStream_t GetStream(DeviceIndex gpu, StreamId stream_id) { - return GetCUDAStream(gpu, stream_id).stream(); - } - - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cublasHandle_t GetHandle(DeviceIndex gpu) { - return GetHandle(c10::cuda::getCurrentCUDAStream(gpu)); - } - - cublasHandle_t GetHandle(c10::cuda::CUDAStream cuda_stream) { - CUDAGuard guard(cuda_stream.device_index()); - // Default construct in the map if it doesn't exist, and return a mutable - // reference to it. - auto& r = cublas_handles_[cuda_stream]; - if (r == nullptr) { - CUBLAS_ENFORCE(cublasCreate(&r)); - // The default is CUBLAS_POINTER_MODE_HOST. You can override - // it after obtaining the cublas handle, but do that with - // caution. - CUBLAS_ENFORCE(cublasSetPointerMode(r, CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSetStream(r, cuda_stream)); - } - return r; - } - -#ifdef CAFFE2_USE_CUDNN - // Uses the logical stream id from the thread local to pick the stream - // We're going to migrate all usages to this case API instead of passing the - // stream id directly - cudnnHandle_t GetCudnnHandle(DeviceIndex gpu) { - return GetCudnnHandle(c10::cuda::getCurrentCUDAStream(gpu)); - } - - cudnnHandle_t GetCudnnHandle(c10::cuda::CUDAStream cuda_stream) { - CUDAGuard guard(cuda_stream.device_index()); - auto& r = cudnn_handles_[cuda_stream]; - if (r == nullptr) { - CUDNN_ENFORCE(cudnnCreate(&r)); - CUDNN_ENFORCE(cudnnSetStream(r, cuda_stream)); - } - return r; - } -#endif // CAFFE2_USE_CUDNN - - ~ThreadLocalCUDAObjects() noexcept { - for (auto element : cublas_handles_) { - if (element.second) { - CUBLAS_CHECK(cublasDestroy(element.second)); - } - } -#ifdef CAFFE2_USE_CUDNN - for (auto element : cudnn_handles_) { - if (element.second) { -#ifdef _WIN32 - // this is because of something dumb in the ordering of - // destruction. Sometimes at exit, the cuda context would already - // be destroyed by the time this gets destroyed. This happens on - // windows with cuda 11 and cuda 12. - cudnnDestroy(element.second); -#else - CUDNN_CHECK(cudnnDestroy(element.second)); -#endif // _WIN32 - } - } -#endif // CAFFE2_USE_CUDNN - } - // WARNING: mapping from logical stream ID to c10::cuda::CUDAStream - // is NOT bijective; multiple logical stream IDs may map to the - // same underlying stream ID. - vector cuda_streams_[C10_COMPILE_TIME_MAX_GPUS]; - std::unordered_map cublas_handles_; -#ifdef CAFFE2_USE_CUDNN - std::unordered_map cudnn_handles_; -#endif // CAFFE2_USE_CUDNN -}; - -class CAFFE2_CUDA_API CUDAContext final : public BaseContext { - public: - // The default cuda context constructor. 
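The ThreadLocalCUDAObjects class above maps a small logical stream id onto streams drawn from the c10 pool. A sketch of that mapping in isolation, assuming only the c10 stream API already used by the deleted code (getStreamFromPool, setCurrentCUDAStream):

#include <c10/cuda/CUDAStream.h>
#include <vector>

// Grow a per-device table of pooled streams on demand and install the
// requested one as the thread's current stream for that device.
c10::cuda::CUDAStream stream_for_logical_id(
    std::vector<c10::cuda::CUDAStream>& table,
    c10::DeviceIndex device,
    int stream_id) {
  while (table.size() <= static_cast<size_t>(stream_id)) {
    table.push_back(
        c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device));
  }
  c10::cuda::setCurrentCUDAStream(table[stream_id]);
  return table[stream_id];
}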
- explicit CUDAContext(DeviceIndex gpu_id = -1); - explicit CUDAContext(const DeviceOption& option); - explicit CUDAContext(Device device) - : CUDAContext(DeviceToOption(device)) {} - - ~CUDAContext() override; - - inline void SwitchToDevice(StreamId stream_id) override { - getCudaObjects().SetCurrentStreamId(gpu_id_, stream_id); - CaffeCudaSetDevice(gpu_id_); - } - - // void SwitchToDevice() - using BaseContext::SwitchToDevice; - - inline void WaitEvent(const Event& ev) override { - ev.Wait(CUDA, this); - } - - inline void Record(Event* ev, const char* err_msg = nullptr) const override { - CAFFE_ENFORCE(ev, "Event must not be null."); - ev->Record(CUDA, this, err_msg); - } - - // Note on current use cases: - // FinishDeviceComputation must be called on the same cpu thread as - // SwitchToDevice() - void FinishDeviceComputation() override { - CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_))); - } - - inline int device_id() const { - return gpu_id_; - } - - inline c10::cuda::CUDAStream stream() const { - return at::cuda::getStreamFromExternal(getCudaObjects().GetStream(gpu_id_), gpu_id_); - } - - inline cudaStream_t cuda_stream() const { - return getCudaObjects().GetStream(gpu_id_); - } - - static cudaStream_t cuda_stream(DeviceIndex gpu_id, StreamId stream_id) { - return getCudaObjects().GetStream(gpu_id, stream_id); - } - - cublasHandle_t cublas_handle() { - return getCudaObjects().GetHandle(gpu_id_); - } - -#ifdef CAFFE2_USE_CUDNN - cudnnHandle_t cudnn_handle() { - return getCudaObjects().GetCudnnHandle(gpu_id_); - } -#endif // CAFFE2_USE_CUDNN - - curandGenerator_t& curand_generator() { - if (!curand_generator_) { - CUDAGuard guard(gpu_id_); - CURAND_ENFORCE( - curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - CURAND_ENFORCE( - curandSetPseudoRandomGeneratorSeed(curand_generator_, random_seed_)); - TORCH_CHECK_NOTNULL(curand_generator_); - } - CURAND_ENFORCE(curandSetStream(curand_generator_, cuda_stream())); - return curand_generator_; - } - - inline static at::DataPtr New(size_t nbytes) { - return GetAllocator(CUDA)->allocate(nbytes); - } - - // Get a mutex to lock out cudaMalloc / cudaFree calls when - // NCCL kernels are being launched. Should remove threat of - // deadlocks - static std::mutex& mutex(); - - // Functions to query memory stats. Only available if flag - // --caffe2_gpu_memory_tracking is enabled. 
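A short usage sketch of the CUDAContext surface declared here (buffer sizes and function names are illustrative and not taken from the original sources):

#include "caffe2/core/context_gpu.h"

// Round-trip a host buffer through device memory using the context API:
// switch device/stream, allocate, copy both ways, then synchronize.
void cuda_context_roundtrip(size_t nbytes, const void* host_src, void* host_dst) {
  caffe2::CUDAContext ctx(/*gpu_id=*/0);
  ctx.SwitchToDevice();                        // also installs the current stream
  at::DataPtr device_buf = caffe2::CUDAContext::New(nbytes);
  ctx.CopyBytesFromCPU(nbytes, host_src, device_buf.get());
  ctx.CopyBytesToCPU(nbytes, device_buf.get(), host_dst);
  ctx.FinishDeviceComputation();               // synchronizes the context stream
}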
- static std::vector TotalMemoryByGpu(); - static std::vector MaxMemoryByGpu(); - - template - inline void CopyBytes(size_t nbytes, const void* src, void* dst) { - CUDA_ENFORCE(cudaMemcpyAsync( - dst, - src, - nbytes, - cudaMemcpyDefault, - getCudaObjects().GetStream(gpu_id_))); - } - - void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override { - CopyBytes(nbytes, src, dst); - } - - template - inline void Copy(int n, const T* src, T* dst) { - CopyBytes(n * sizeof(T), - static_cast(src), - static_cast(dst)); - } - - template - inline void - CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { - CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types."); - CopyBytes(n * meta.itemsize(), src, dst); - } - - static void CopyBytesAsync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device); - static void CopyBytesSync( - size_t nbytes, - const void* src, - Device src_device, - void* dst, - Device dst_device); - - // By default CUDA operators have async device parts - static bool HasAsyncPartDefault() { - return true; - } - - static bool SupportsAsyncScheduling() { - return true; - } - - static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) { - const auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id); - const auto status = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream)); - if (status == cudaErrorNotReady) { - // ignore and clear the error if not ready - C10_CUDA_CLEAR_ERROR(); - } else { - C10_CUDA_CHECK(status); // Reraise error - } - return status == cudaSuccess; - } - - at::Device device() const override { - return at::Device(CUDA, gpu_id_); - } - - DeviceType device_type() const override { - return CUDA; - } - - static constexpr DeviceType GetDeviceType() { - return CUDA; - } - - protected: - int gpu_id_; - int random_seed_; - curandGenerator_t curand_generator_{nullptr}; - static ThreadLocalCUDAObjects& getCudaObjects(); -}; - -using TensorCUDA = Tensor; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_CONTEXT_GPU_H_ diff --git a/caffe2/core/context_gpu_test.cc b/caffe2/core/context_gpu_test.cc deleted file mode 100644 index 9eb92b429ef0..000000000000 --- a/caffe2/core/context_gpu_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -#include -#include -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include - -namespace caffe2 { - -TEST(CUDATest, HasCudaRuntime) { - EXPECT_TRUE(HasCudaRuntime()); -} - -TEST(CUDAContextTest, TestAllocDealloc) { - if (!HasCudaGPU()) return; - CUDAContext context(0); - context.SwitchToDevice(); - auto data = CUDAContext::New(10 * sizeof(float)); - EXPECT_NE(data.get(), nullptr); -} - -TEST(CUDAContextTest, TestSetGetDeviceWithoutCaffeMode) { - // For a while, set full device control to be true. 
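IsStreamFree above is a non-blocking completion check; in plain CUDA runtime terms the pattern is:

#include <cuda_runtime.h>

// Returns true when every kernel and copy queued on `stream` has finished.
// cudaErrorNotReady just means "still busy" and must not be left as a
// sticky error, so it is cleared before returning.
bool stream_is_free(cudaStream_t stream) {
  cudaError_t status = cudaStreamQuery(stream);
  if (status == cudaErrorNotReady) {
    cudaGetLastError();  // clear the not-ready status
    return false;
  }
  return status == cudaSuccess;  // anything else is a real error upstream
}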
- for (int i = 0; i < NumCudaDevices(); ++i) { - CaffeCudaSetDevice(i); - EXPECT_EQ(CaffeCudaGetDevice(), i); - } - for (int i = NumCudaDevices() - 1; i >= 0; --i) { - CaffeCudaSetDevice(i); - EXPECT_EQ(CaffeCudaGetDevice(), i); - } -} - -TEST(CUDAContextTest, MemoryPoolAllocateDealloc) { - if (!HasCudaGPU()) - return; - if (GetCudaMemoryPoolType() == CudaMemoryPoolType::NONE) { - LOG(ERROR) << "Choose a memory type that is not none to test memory pool."; - return; - } - const int nbytes = 1048576; - for (int i = 0; i < NumCudaDevices(); ++i) { - LOG(INFO) << "Device " << i << " of " << NumCudaDevices(); - CUDAGuard guard(i); - auto allocated = CUDAContext::New(nbytes); - EXPECT_NE(allocated, nullptr); - cudaPointerAttributes attr; - CUDA_ENFORCE(cudaPointerGetAttributes(&attr, allocated.get())); - EXPECT_EQ(attr.type, cudaMemoryTypeDevice); - EXPECT_EQ(attr.device, i); - void* prev_allocated = allocated.get(); - allocated.clear(); - auto new_allocated = CUDAContext::New(nbytes); - // With a pool, the above allocation should yield the same address. - EXPECT_EQ(new_allocated.get(), prev_allocated); - // But, if we are allocating something larger, we will have a different - // chunk of memory. - auto larger_allocated = CUDAContext::New(nbytes * 2); - EXPECT_NE(larger_allocated.get(), prev_allocated); - } -} - -cudaStream_t getStreamForHandle(cublasHandle_t handle) { - cudaStream_t stream = nullptr; - CUBLAS_ENFORCE(cublasGetStream(handle, &stream)); - TORCH_CHECK_NOTNULL(stream); - return stream; -} - -TEST(CUDAContextTest, TestSameThreadSameObject) { - if (!HasCudaGPU()) return; - CUDAContext context_a(0); - CUDAContext context_b(0); - EXPECT_EQ(context_a.cuda_stream(), context_b.cuda_stream()); - EXPECT_EQ(context_a.cublas_handle(), context_b.cublas_handle()); - EXPECT_EQ( - context_a.cuda_stream(), getStreamForHandle(context_b.cublas_handle())); - // CuRAND generators are context-local. 
- EXPECT_NE(context_a.curand_generator(), context_b.curand_generator()); -} - -TEST(CUDAContextTest, TestSameThreadTempObject) { - if (!HasCudaGPU()) - return; - CUDAContext context_outer(0); // gpu id - context_outer.SwitchToDevice(); - - if (NumCudaDevices() >= 2) { - auto before_stream = context_outer.cuda_stream(); - - // try to mess up current device - CUDAContext context_different_device(1); - context_different_device.SwitchToDevice(10); - - // go back - context_outer.SwitchToDevice(); - EXPECT_EQ(context_outer.cuda_stream(), before_stream); - - // do nothing - infers the current device and stream - CUDAContext context_noop; - EXPECT_EQ(context_outer.cuda_stream(), before_stream); - EXPECT_EQ(context_noop.cuda_stream(), before_stream); - - - // override stream - the previous context is not valid any more until - // SwitchToDevice is called again (needs to be refactored into proper guard) - CUDAContext context_override; - context_override.SwitchToDevice(1); // logical stream id - EXPECT_NE(context_override.cuda_stream(), before_stream); - // note, that accessing streams from context_outer and context_noop is not - // semantically valid any more - } -} - -TEST(CUDAContextTest, TestSameThreadDifferntObjectIfDifferentDevices) { - if (NumCudaDevices() > 1) { - CUDAContext context_a(0); - CUDAContext context_b(1); - EXPECT_NE(context_a.cuda_stream(), context_b.cuda_stream()); - EXPECT_NE(context_a.cublas_handle(), context_b.cublas_handle()); - EXPECT_NE( - context_a.cuda_stream(), getStreamForHandle(context_b.cublas_handle())); - EXPECT_NE(context_a.curand_generator(), context_b.curand_generator()); - } -} - -namespace { -// A test function to return a stream address from a temp CUDA context. You -// should not use that stream though, because the actual stream is destroyed -// after thread exit. 
-void TEST_GetStreamAddress(cudaStream_t* ptr) { - CUDAContext context(0); - context.SwitchToDevice(); - *ptr = context.cuda_stream(); - // Sleep for a while so we have concurrent thread executions - std::this_thread::sleep_for(std::chrono::seconds(1)); -} -} // namespace - -TEST(CUDAContextTest, TestDifferntThreadDifferentobject) { - if (!HasCudaGPU()) return; - std::array temp = {0}; - // Same thread - TEST_GetStreamAddress(&temp[0]); - TEST_GetStreamAddress(&temp[1]); - EXPECT_TRUE(temp[0] != nullptr); - EXPECT_TRUE(temp[1] != nullptr); - EXPECT_EQ(temp[0], temp[1]); - // Different threads - std::thread thread_a(TEST_GetStreamAddress, &temp[0]); - std::thread thread_b(TEST_GetStreamAddress, &temp[1]); - thread_a.join(); - thread_b.join(); - EXPECT_TRUE(temp[0] != nullptr); - EXPECT_TRUE(temp[1] != nullptr); - EXPECT_NE(temp[0], temp[1]); -} - -} // namespace caffe2 diff --git a/caffe2/core/context_test.cc b/caffe2/core/context_test.cc deleted file mode 100644 index 304f973576c1..000000000000 --- a/caffe2/core/context_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include - -#include -#include -#include "caffe2/core/context.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { - -TEST(CPUContextTest, TestAllocAlignment) { - for (int i = 1; i < 10; ++i) { - auto data = CPUContext::New(i); - EXPECT_EQ((reinterpret_cast(data.get()) % gAlignment), 0); - // data is freed when out of scope - } -} - -TEST(CPUContextTest, TestAllocDealloc) { - auto data_ptr = CPUContext::New(10 * sizeof(float)); - float* data = static_cast(data_ptr.get()); - EXPECT_NE(data, nullptr); - auto dst_data_ptr = CPUContext::New(10 * sizeof(float)); - float* dst_data = static_cast(dst_data_ptr.get()); - EXPECT_NE(dst_data, nullptr); - for (int i = 0; i < 10; ++i) { - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - data[i] = i; - } - DeviceOption option; - CPUContext context(option); - context.CopyToCPU(10, data, dst_data); - for (int i = 0; i < 10; ++i) { - EXPECT_FLOAT_EQ(dst_data[i], i); - } - // data_ptr is freed when out of scope -} - -} // namespace caffe2 diff --git a/caffe2/core/event_gpu.cc b/caffe2/core/event_gpu.cc deleted file mode 100644 index 82000de79011..000000000000 --- a/caffe2/core/event_gpu.cc +++ /dev/null @@ -1,227 +0,0 @@ -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/event_cpu.h" -#include "caffe2/core/operator.h" - -#include -#include - -namespace caffe2 { - -struct CudaEventWrapper { - explicit CudaEventWrapper(const DeviceOption& option) - : cuda_stream_(nullptr), - device_id_(option.device_id()), - status_(EventStatus::EVENT_INITIALIZED) { - CAFFE_ENFORCE(option.device_type(), PROTO_CUDA); - CUDAGuard g(device_id_); - try { - CUDA_ENFORCE(cudaEventCreateWithFlags( - &cuda_event_, cudaEventDefault | cudaEventDisableTiming)); - } catch (const Error&) { - std::cerr << "ERROR: Failed to load CUDA.\n" - << "HINT: Check that this binary contains GPU code." 
- << std::endl; - throw; - } - } - ~CudaEventWrapper() { - CUDAGuard g(device_id_); - CUDA_CHECK(cudaEventDestroy(cuda_event_)); - } - - cudaEvent_t cuda_event_; - cudaStream_t cuda_stream_; - int device_id_; - - std::atomic status_; - std::mutex mutex_recorded_; - std::condition_variable cv_recorded_; - std::string err_msg_; -}; - -namespace { -const std::string kNoError = "No error"; -} - -void EventCreateCUDA(const DeviceOption& option, Event* event) { - event->event_ = std::make_shared(option); -} - -void EventRecordCUDA(Event* event, const void* context, const char* err_msg) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - - // Possible state changes: - // INITIALIZED -> SCHEDULED/FAILED - // SCHEDULED -> SUCCESS/FAILED - // SUCCESS/FAILED - terminal - // - // No further changes to cuda_event_ and cuda_stream_ after transitioning - // from INITIALIZED - // No further changes to err_msg_ after transitioning into FAILED - - CAFFE_ENFORCE_EQ( - wrapper->status_, - EventStatus::EVENT_INITIALIZED, - "Calling Record multiple times"); - - if (!err_msg) { - // When recording, one needs to make sure that the current gpu id is - // correct. - // TODO(jiayq): move the enforce logic to the caller? - const auto& current_device = CaffeCudaGetDevice(); - CAFFE_ENFORCE_EQ( - current_device, - wrapper->device_id_, - "When you call EventRecordCUDA, your current device should be the same " - "as the device specified by the event."); - CAFFE_ENFORCE_EQ( - current_device, - static_cast(context)->device_id()); - CUDA_ENFORCE(cudaEventRecord( - wrapper->cuda_event_, - static_cast(context)->cuda_stream())); - wrapper->cuda_stream_ = - static_cast(context)->cuda_stream(); - wrapper->status_ = EventStatus::EVENT_SCHEDULED; - } else { - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } - wrapper->cv_recorded_.notify_all(); -} - -void EventFinishCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - while (wrapper->status_ == EventStatus::EVENT_INITIALIZED) { - wrapper->cv_recorded_.wait(lock); - } - } - - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - // ok, even if event is already completed and status was not yet updated - CUDAGuard g(wrapper->device_id_); - auto cudaResult = cudaEventSynchronize(wrapper->cuda_event_); - if (cudaResult == cudaSuccess) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else { - const auto& err_msg = cudaGetErrorString(cudaResult); - - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } -} - -// Both waiter and event are CUDA. 
Non-blocking -void EventWaitCUDACUDA(const Event* event, void* context) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - while (wrapper->status_ == EventStatus::EVENT_INITIALIZED) { - wrapper->cv_recorded_.wait(lock); - } - } - - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - // ok, even if event is already completed and status was not yet updated - auto context_stream = static_cast(context)->cuda_stream(); - auto event_stream = wrapper->cuda_stream_; - if (context_stream != event_stream) { - // CAFFE_ENFORCE_EQ( - // CaffeCudaGetDevice(), - // static_cast(context)->device_id()); - CUDA_CHECK(cudaStreamWaitEvent(context_stream, wrapper->cuda_event_, 0)); - } - } -} - -// Waiter is CPU, event is CUDA -void EventWaitCPUCUDA(const Event* event, void* context) { - EventFinishCUDA(event); -} - -// Waiter is CUDA, event is CPU -void EventWaitCUDACPU(const Event* event, void* context) { - event->Finish(); // calls EventFinishCPU -} - -EventStatus EventQueryCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) { - auto cudaResult = cudaEventQuery(wrapper->cuda_event_); - if (cudaResult == cudaSuccess) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else if (cudaResult != cudaErrorNotReady) { - const auto& err_msg = cudaGetErrorString(cudaResult); - - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } else { - // ignore and clear the error if not ready - (void)cudaGetLastError(); - } - } - return static_cast(wrapper->status_.load()); -} - -const std::string& EventErrorMessageCUDA(const Event* event) { - auto* wrapper = static_cast(event->event_.get()); - // supposed to be called after EventQueryCUDA to update status first - if (wrapper->status_ == EventStatus::EVENT_FAILED) { - return wrapper->err_msg_; - } else { - return kNoError; - } -} - -void EventSetFinishedCUDA(const Event* event, const char* err_msg) { - auto* wrapper = static_cast(event->event_.get()); - { - std::unique_lock lock(wrapper->mutex_recorded_); - - CAFFE_ENFORCE_EQ( - wrapper->status_, - EventStatus::EVENT_INITIALIZED, - "Calling SetFinished on recorded CUDA event"); - - if (!err_msg) { - wrapper->status_ = EventStatus::EVENT_SUCCESS; - } else { - wrapper->err_msg_ = err_msg; - wrapper->status_ = EventStatus::EVENT_FAILED; - } - } - wrapper->cv_recorded_.notify_all(); -} - -void EventResetCUDA(Event* event) { - auto* wrapper = static_cast(event->event_.get()); - std::unique_lock lock(wrapper->mutex_recorded_); - wrapper->status_ = EventStatus::EVENT_INITIALIZED; - wrapper->err_msg_ = ""; - wrapper->cuda_stream_ = nullptr; -} - -REGISTER_EVENT_CREATE_FUNCTION(CUDA, EventCreateCUDA); -REGISTER_EVENT_RECORD_FUNCTION(CUDA, EventRecordCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, CUDA, EventWaitCUDACUDA); -REGISTER_EVENT_WAIT_FUNCTION(CPU, CUDA, EventWaitCPUCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, CPU, EventWaitCUDACPU); -REGISTER_EVENT_FINISH_FUNCTION(CUDA, EventFinishCUDA); - -REGISTER_EVENT_QUERY_FUNCTION(CUDA, EventQueryCUDA); -REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(CUDA, EventErrorMessageCUDA); -REGISTER_EVENT_SET_FINISHED_FUNCTION(CUDA, EventSetFinishedCUDA); -REGISTER_EVENT_RESET_FUNCTION(CUDA, EventResetCUDA); - -REGISTER_EVENT_WAIT_FUNCTION(MKLDNN, CUDA, EventWaitCPUCUDA); -REGISTER_EVENT_WAIT_FUNCTION(CUDA, MKLDNN, EventWaitCUDACPU); - -} // namespace caffe2 diff --git 
a/caffe2/core/event_gpu_test.cc b/caffe2/core/event_gpu_test.cc deleted file mode 100644 index 18fe152198e2..000000000000 --- a/caffe2/core/event_gpu_test.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include "caffe2/core/context.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/event.h" - -namespace caffe2 { - -TEST(EventCUDATest, EventBasics) { - if (!HasCudaGPU()) - return; - DeviceOption device_cpu; - device_cpu.set_device_type(PROTO_CPU); - DeviceOption device_cuda; - device_cuda.set_device_type(PROTO_CUDA); - - CPUContext context_cpu(device_cpu); - CUDAContext context_cuda(device_cuda); - - Event event_cpu(device_cpu); - Event event_cuda(device_cuda); - - // CPU context and event interactions - context_cpu.Record(&event_cpu); - event_cpu.SetFinished(); - event_cpu.Finish(); - context_cpu.WaitEvent(event_cpu); - - event_cpu.Reset(); - event_cpu.Record(CPU, &context_cpu); - event_cpu.SetFinished(); - event_cpu.Wait(CPU, &context_cpu); - - // CUDA context and event interactions - context_cuda.SwitchToDevice(); - context_cuda.Record(&event_cuda); - context_cuda.WaitEvent(event_cuda); - event_cuda.Finish(); - - event_cuda.Reset(); - event_cuda.Record(CUDA, &context_cuda); - event_cuda.Wait(CUDA, &context_cuda); - - // CPU context waiting for CUDA event - context_cpu.WaitEvent(event_cuda); - - // CUDA context waiting for CPU event - context_cuda.WaitEvent(event_cpu); -} - -} // namespace caffe2 diff --git a/caffe2/core/event_test.cc b/caffe2/core/event_test.cc deleted file mode 100644 index ef25ae891e9a..000000000000 --- a/caffe2/core/event_test.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include -#include "caffe2/core/context.h" -#include "caffe2/core/event.h" - -namespace caffe2 { - -TEST(EventCPUTest, EventBasics) { - DeviceOption device_option; - device_option.set_device_type(PROTO_CPU); - Event event(device_option); - CPUContext context; - - context.Record(&event); - event.SetFinished(); - - context.WaitEvent(event); - event.Finish(); - - event.Reset(); - event.Record(CPU, &context); - event.SetFinished(); - event.Wait(CPU, &context); -} - -TEST(EventCPUTest, EventErrors) { - DeviceOption device_option; - device_option.set_device_type(PROTO_CPU); - Event event(device_option); - - event.SetFinished(); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(event.SetFinished("error"), caffe2::EnforceNotMet); - ASSERT_EQ(event.ErrorMessage(), "No error"); - - event.Reset(); - event.SetFinished("error 1"); - event.SetFinished("error 2"); - ASSERT_EQ(event.ErrorMessage(), "error 1"); -} - -} // namespace caffe2 diff --git a/caffe2/core/flags.h b/caffe2/core/flags.h deleted file mode 100644 index 54f1f41f2fb3..000000000000 --- a/caffe2/core/flags.h +++ /dev/null @@ -1,4 +0,0 @@ -#pragma once - -#include "c10/util/Flags.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/graph_test.cc b/caffe2/core/graph_test.cc deleted file mode 100644 index 8aa4f1610793..000000000000 --- a/caffe2/core/graph_test.cc +++ /dev/null @@ -1,200 +0,0 @@ -#include -#include "caffe2/core/graph.h" -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -namespace { - -using transform::Graph; - -static std::atomic counter; - -class GraphDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - counter.fetch_add(1); - return true; - } -}; - -REGISTER_CPU_OPERATOR(GraphDummyOp1, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp1) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - 
.AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(GraphDummyOp2, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(GraphDummyOp3, GraphDummyOp); - -OPERATOR_SCHEMA(GraphDummyOp3) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -// Checks if two netdefs are in terms of type, input, and output. -void compare_netdefs(const NetDef& net_a, const NetDef& net_b) { - EXPECT_EQ(net_a.op_size(), net_b.op_size()); - for (int i = 0; i < net_a.op_size(); i++) { - EXPECT_EQ(net_a.op(i).type(), net_b.op(i).type()); - EXPECT_EQ(net_a.op(i).input_size(), net_b.op(i).input_size()); - for (int j = 0; j < net_a.op(i).input_size(); j++) { - EXPECT_EQ(net_a.op(i).input(j), net_b.op(i).input(j)); - } - EXPECT_EQ(net_a.op(i).output_size(), net_b.op(i).output_size()); - for (int j = 0; j < net_a.op(i).output_size(); j++) { - EXPECT_EQ(net_a.op(i).output(j), net_b.op(i).output(j)); - } - } -} - -TEST(GraphTest, TestGenerateGraphChain) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "GraphDummyOp2", {"mid3"}, {"out"}); - Graph g(netdef); - EXPECT_EQ(g.size(), 4); - for (int i = 0; i < 4; i++) { - if (i < 3) { - EXPECT_EQ(g.node(i).children.size(), 1); - EXPECT_TRUE(g.node(i).children.count(i + 1)); - } - if (i > 0) { - EXPECT_EQ(g.node(i).parents.size(), 1); - EXPECT_TRUE(g.node(i).parents.count(i - 1)); - } - } - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -TEST(GraphTest, TestGenerateGraphChainInPlace) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"out"}); - AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"}); - AddOp(&netdef, "GraphDummyOp1", {"out"}, {"out"}); - AddOp(&netdef, "GraphDummyOp2", {"out"}, {"out"}); - Graph g(netdef); - EXPECT_EQ(g.size(), 4); - for (int i = 0; i < 4; i++) { - if (i < 3) { - EXPECT_EQ(g.node(i).children.size(), 1); - EXPECT_TRUE(g.node(i).children.count(i + 1)); - } - if (i > 0) { - EXPECT_EQ(g.node(i).parents.size(), 1); - EXPECT_TRUE(g.node(i).parents.count(i - 1)); - } - } - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -// Diamond Graph -TEST(GraphTest, TestGenerateGraphBranch) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp2", {"mid1"}, {"mid3"}); - AddOp(&netdef, "GraphDummyOp3", {"mid2", "mid3"}, {"out"}); - - Graph g(netdef); - - EXPECT_EQ(g.size(), 4); - EXPECT_EQ(g.node(0).parents.size(), 0); - EXPECT_EQ(g.node(0).children.size(), 2); - EXPECT_EQ(g.node(1).parents.size(), 1); - EXPECT_EQ(g.node(1).children.size(), 1); - EXPECT_EQ(g.node(2).parents.size(), 1); - EXPECT_EQ(g.node(2).children.size(), 1); - EXPECT_EQ(g.node(3).parents.size(), 2); - EXPECT_EQ(g.node(3).children.size(), 0); - - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -// Double Diamond Graph, reused names -TEST(GraphTest, TestReusedInputs) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, 
{"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - - Graph g(netdef); - - EXPECT_EQ(g.size(), 7); - EXPECT_EQ(g.node(0).parents.size(), 0); - EXPECT_EQ(g.node(0).children.size(), 2); - EXPECT_EQ(g.node(1).parents.size(), 1); - EXPECT_EQ(g.node(1).children.size(), 1); - EXPECT_EQ(g.node(2).parents.size(), 1); - EXPECT_EQ(g.node(2).children.size(), 1); - EXPECT_EQ(g.node(3).parents.size(), 2); - EXPECT_EQ(g.node(3).children.size(), 2); - EXPECT_EQ(g.node(4).parents.size(), 1); - EXPECT_EQ(g.node(4).children.size(), 1); - EXPECT_EQ(g.node(5).parents.size(), 1); - EXPECT_EQ(g.node(5).children.size(), 1); - EXPECT_EQ(g.node(6).parents.size(), 2); - EXPECT_EQ(g.node(6).children.size(), 0); - - NetDef retrieved_net = g.GetNetDef(); - compare_netdefs(retrieved_net, netdef); -} - -TEST(GraphTest, TestGetPerimeter) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "GraphDummyOp1", {"in"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp3", {"mid1", "mid2"}, {"in"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "GraphDummyOp2", {"in"}, {"mid2"}); - AddOp(&netdef, "GraphDummyOp1", {"mid1", "mid2"}, {"in"}); - - Graph g(netdef); - std::vector subgraph = {3}; - - auto subgraph_input = g.GetSubgraphInput(subgraph); - EXPECT_EQ(subgraph_input.size(), 2); - EXPECT_EQ(subgraph_input[0], std::make_pair(string("mid1"), 1)); - EXPECT_EQ(subgraph_input[1], std::make_pair(string("mid2"), 2)); - - auto subgraph_output = g.GetSubgraphOutput(subgraph); - EXPECT_EQ(subgraph_output.size(), 2); - EXPECT_EQ(subgraph_output[0], std::make_pair(string("in"), 4)); - EXPECT_EQ(subgraph_output[1], std::make_pair(string("in"), 5)); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/hip/common_miopen.h b/caffe2/core/hip/common_miopen.h deleted file mode 100644 index 6901055813cb..000000000000 --- a/caffe2/core/hip/common_miopen.h +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef CAFFE2_CORE_COMMON_MIOPEN_H_ -#define CAFFE2_CORE_COMMON_MIOPEN_H_ - -#include -#include -#include "miopen/miopen.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/types.h" -#include "caffe2/proto/caffe2_pb.h" - -#define MIOPEN_VERSION 1399 - -namespace caffe2 { - -namespace internal { -/** - * A helper function to obtain miopen error strings. 
- */ -inline const char* miopenGetErrorString(miopenStatus_t status) -{ - switch(status) - { - case miopenStatusSuccess: return "MIOPEN_STATUS_SUCCESS"; - case miopenStatusNotInitialized: return "MIOPEN_STATUS_NOT_INITIALIZED"; - case miopenStatusAllocFailed: return "MIOPEN_STATUS_ALLOC_FAILED"; - case miopenStatusBadParm: return "MIOPEN_STATUS_BAD_PARAM"; - case miopenStatusInternalError: return "MIOPEN_STATUS_INTERNAL_ERROR"; - case miopenStatusInvalidValue: return "MIOPEN_STATUS_INVALID_VALUE"; - case miopenStatusNotImplemented: return "MIOPEN_STATUS_NOT_SUPPORTED"; - case miopenStatusUnknownError: return "MIOPEN_STATUS_UNKNOWN_ERROR"; - default: return "MIOPEN_STATUS_UNKNOWN_ERROR"; - } -} -} // namespace internal - -// A macro that wraps around a miopen statement so we can check if the miopen -// execution finishes or not. -#define MIOPEN_ENFORCE(condition) \ - do \ - { \ - miopenStatus_t status = condition; \ - CAFFE_ENFORCE_EQ(status, \ - miopenStatusSuccess, \ - ", Error at: ", \ - __FILE__, \ - ":", \ - __LINE__, \ - ": ", \ - ::caffe2::internal::miopenGetErrorString(status)); \ - } while(0) -#define MIOPEN_CHECK(condition) \ - do \ - { \ - miopenStatus_t status = condition; \ - CHECK(status == miopenStatusSuccess) << ::caffe2::internal::miopenGetErrorString(status); \ - } while(0) - -// report the version of miopen Caffe2 was compiled with -inline size_t miopenCompiledVersion() { return MIOPEN_VERSION; } - -// report the runtime version of miopen -inline size_t miopenRuntimeVersion() { return MIOPEN_VERSION; } - -// Check compatibility of compiled and runtime miopen versions -inline void CheckMIOPENVersions() {} - -/** - * miopenTypeWrapper is a wrapper class that allows us to refer to the miopen type - * in a template function. The class is specialized explicitly for different - * data types below. - */ -template -class miopenTypeWrapper; - -template <> -class miopenTypeWrapper -{ - public: - static const miopenDataType_t type = miopenFloat; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() - { - static ScalingParamType v = 1.0; - return &v; - } - static const ScalingParamType* kZero() - { - static ScalingParamType v = 0.0; - return &v; - } -}; - -template <> -class miopenTypeWrapper -{ - public: - static const miopenDataType_t type = miopenHalf; - typedef const float ScalingParamType; - typedef float BNParamType; - static ScalingParamType* kOne() - { - static ScalingParamType v = 1.0; - return &v; - } - static ScalingParamType* kZero() - { - static ScalingParamType v = 0.0; - return &v; - } -}; - -/** - * miopenTensorDescWrapper is the placeholder that wraps around a - * miopenTensorDescriptor_t, allowing us to do descriptor change as-needed during - * runtime. - */ -class miopenTensorDescWrapper -{ - public: - miopenTensorDescWrapper() { MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&desc_)); } - ~miopenTensorDescWrapper() noexcept { MIOPEN_CHECK(miopenDestroyTensorDescriptor(desc_)); } - - inline miopenTensorDescriptor_t - Descriptor(const miopenDataType_t type, const vector& dims, bool* changed) - { - if(type_ == type && dims_ == dims) - { - // if not changed, simply return the current descriptor. 
- if(changed) - *changed = false; - return desc_; - } - CAFFE_ENFORCE_EQ( - dims.size(), 4, "MIOPEN currently only support 4-dimensional tensor descriptor"); - - type_ = type; - dims_ = dims; - MIOPEN_ENFORCE( - miopenSet4dTensorDescriptor(desc_, type, dims_[0], dims_[1], dims_[2], dims_[3])); - if(changed) - *changed = true; - return desc_; - } - - template - inline miopenTensorDescriptor_t Descriptor(const StorageOrder& order, const vector& dims) - { - return Descriptor(miopenTypeWrapper::type, dims, nullptr); - } - - private: - miopenTensorDescriptor_t desc_; - miopenDataType_t type_; - vector dims_; - C10_DISABLE_COPY_AND_ASSIGN(miopenTensorDescWrapper); -}; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_COMMON_MIOPEN_H_ diff --git a/caffe2/core/hip/common_miopen.hip b/caffe2/core/hip/common_miopen.hip deleted file mode 100644 index a617bad29a3d..000000000000 --- a/caffe2/core/hip/common_miopen.hip +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "caffe2/core/hip/common_miopen.h" -#include "caffe2/core/hip/miopen_wrapper.h" - -#include "caffe2/core/init.h" - -namespace caffe2 { - -MIOPENWrapper::PerGPUMIOPENStates& MIOPENWrapper::miopen_states() -{ - // New it (never delete) to avoid calling the destructors on process - // exit and racing against the CUDA shutdown sequence. - static auto* p = new MIOPENWrapper::PerGPUMIOPENStates(); - TORCH_CHECK_NOTNULL(p); - return *p; -} - -namespace { -bool PrintMIOPENInfo(int*, char***) -{ - VLOG(1) << "Caffe2 is built with MIOPEN version " << MIOPEN_VERSION; - return true; -} - -REGISTER_CAFFE2_INIT_FUNCTION(PrintMIOPENInfo, &PrintMIOPENInfo, "Print MIOPEN Info."); - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/hip/miopen_wrapper.h b/caffe2/core/hip/miopen_wrapper.h deleted file mode 100644 index f60bed6c277d..000000000000 --- a/caffe2/core/hip/miopen_wrapper.h +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. -#ifndef CAFFE2_CORE_MIOPEN_WRAPPERS_H_ -#define CAFFE2_CORE_MIOPEN_WRAPPERS_H_ - -#include "caffe2/core/hip/common_miopen.h" -#include "caffe2/core/hip/context_gpu.h" - -#include - -namespace caffe2 { - -class MIOPENWrapper; - -/** - * MIOPENWorkspace is a wrapper around a raw cuda pointer that holds the miopen - * scratch space. This struct is meant to be only used in MIOPENWrapper to - * provide a program-wide scratch space for MIOPEN. The reason behind it is that - * miopen function calls are usually very efficient, hence one probably does not - * want to run multiple miopen calls at the same time. As a result, one should - * not need more than one miopen workspace per device. 
- */ -struct MIOPENWorkspace -{ - ~MIOPENWorkspace() noexcept {} - - void* get(size_t nbytes) - { - if(nbytes_ < nbytes) - { - reset(); - data_ = HIPContext::New(nbytes); - nbytes_ = nbytes; - } - CAFFE_ENFORCE_GE(nbytes_, nbytes); - return data_.get(); - } - - void reset() - { - data_.clear(); - nbytes_ = 0; - } - - private: - at::DataPtr data_; - size_t nbytes_{0}; -}; - -// MIOPENState is the owner of the MIOPENWorkspace, and serializes all -// executions of operations that use the state onto it's own stream -// (so multiple Net workers can reuse the same workspace from -// different threads and HIP streams). -class MIOPENState -{ - public: - explicit MIOPENState(size_t gpu_id) : gpu_id_(gpu_id) - { - HIPGuard g(gpu_id_); - MIOPEN_ENFORCE(miopenCreate(&miopen_handle_)); - HIP_ENFORCE(hipEventCreate(&before_)); - HIP_ENFORCE(hipEventCreate(&after_)); - HIP_ENFORCE(hipStreamCreate(&stream_)); - MIOPEN_ENFORCE(miopenSetStream(miopen_handle_, stream_)); - } - - ~MIOPENState() noexcept - { - HIPGuard g(gpu_id_); - MIOPEN_CHECK(miopenDestroy(miopen_handle_)); - HIP_CHECK(hipStreamDestroy(stream_)); - HIP_CHECK(hipEventDestroy(after_)); - HIP_CHECK(hipEventDestroy(before_)); - } - - miopenHandle_t& miopen_handle() { return miopen_handle_; } - - MIOPENWorkspace& workspace() { return workspace_; } - - template - void execute(hipStream_t stream, F&& f) - { - HIP_ENFORCE(hipEventRecord(before_, stream)); - HIP_ENFORCE(hipStreamWaitEvent(stream_, before_, 0)); - f(this); - HIP_ENFORCE(hipEventRecord(after_, stream_)); - HIP_ENFORCE(hipStreamWaitEvent(stream, after_, 0)); - } - - private: - miopenHandle_t miopen_handle_{nullptr}; - hipEvent_t before_{nullptr}; - hipEvent_t after_{nullptr}; - hipStream_t stream_{nullptr}; - MIOPENWorkspace workspace_; - size_t gpu_id_{0}; - C10_DISABLE_COPY_AND_ASSIGN(MIOPENState); -}; - -/** - * MIOPENWrapper is a class that wraps the miopen handles and miopen workspaces. - * - * The wrapper ensures that for each thread and each gpu, there is one - * identical miopen handle, which is also associated with the thread-local - * per-device hip stream. The wrapper also hosts the device-specific miopen - * workspace (scratch space for some miopen functions). - * - */ -class MIOPENWrapper -{ - public: - /** - * Creates a miopen wrapper associated with a HIPContext object. Note that - * the HIPContext object should outlive the MIOPENWrapper. - */ - explicit MIOPENWrapper(HIPContext* context) : context_(context) {} - - /** - * Returns the inline miopen handle that executes on the current - * thread's hip_stream. - */ - miopenHandle_t inline_miopen_handle() { return context_->miopen_handle(); } - - // Executes the closure F on the MIOPENState associated with state_idx - template - void with_miopen_state(size_t state_idx, F&& f) - { - CAFFE_ENFORCE(state_idx < CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES, "Invalid state_idx"); - auto& sync_state = miopen_states()[context_->device_id()][state_idx]; - - HIPGuard dg(context_->device_id()); - - // We need to serialize execution on the MIOPENState as we can't - // allow multiple threads to race through the cudaEventRecord - // calls (so a worker thread might wait on another worker thread's - // execution) - std::lock_guard g(sync_state.mutex); - if(!sync_state.state.get()) - { - sync_state.state.reset(new MIOPENState(context_->device_id())); - } - TORCH_CHECK_NOTNULL(sync_state.state.get())->execute(context_->hip_stream(), f); - } - - protected: - // Pointer to an external cuda context that the miopen wrapper will use. 
- HIPContext* context_; - - static constexpr size_t CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES = 4; - - struct SyncedMIOPENState - { - std::mutex mutex; - std::unique_ptr state; - }; - - using PerGPUMIOPENStates = std::array< - std::array, - C10_COMPILE_TIME_MAX_GPUS>; - static PerGPUMIOPENStates& miopen_states(); - - C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper); -}; - -}; // namespace caffe2 - -#endif diff --git a/caffe2/core/init.h b/caffe2/core/init.h deleted file mode 100644 index 8d0fbd3f1557..000000000000 --- a/caffe2/core/init.h +++ /dev/null @@ -1,179 +0,0 @@ -#ifndef CAFFE2_CORE_INIT_H_ -#define CAFFE2_CORE_INIT_H_ - -#include "caffe2/core/common.h" -#include "caffe2/core/flags.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { - -namespace internal { -class TORCH_API Caffe2InitializeRegistry { - public: - typedef bool (*InitFunction)(int*, char***); - // Registry() is defined in .cpp file to make registration work across - // multiple shared libraries loaded with RTLD_LOCAL - static Caffe2InitializeRegistry* Registry(); - - void Register( - InitFunction function, - bool run_early, - const char* description, - const char* name = nullptr) { - if (name) { - named_functions_[name] = function; - } - if (run_early) { - // Disallow registration after GlobalInit of early init functions - CAFFE_ENFORCE(!early_init_functions_run_yet_); - early_init_functions_.emplace_back(function, description); - } else { - if (init_functions_run_yet_) { - // Run immediately, since GlobalInit already ran. This should be - // rare but we want to allow it in some cases. - LOG(WARNING) << "Running init function after GlobalInit: " - << description; - // TODO(orionr): Consider removing argc and argv for non-early - // registration. Unfortunately that would require a new InitFunction - // typedef, so not making the change right now. - // - // Note that init doesn't receive argc and argv, so the function - // might fail and we want to raise an error in that case. - int argc = 0; - char** argv = nullptr; - bool success = (function)(&argc, &argv); - CAFFE_ENFORCE(success); - } else { - // Wait until GlobalInit to run - init_functions_.emplace_back(function, description); - } - } - } - - bool RunRegisteredEarlyInitFunctions(int* pargc, char*** pargv) { - CAFFE_ENFORCE(!early_init_functions_run_yet_); - early_init_functions_run_yet_ = true; - return RunRegisteredInitFunctionsInternal( - early_init_functions_, pargc, pargv); - } - - bool RunRegisteredInitFunctions(int* pargc, char*** pargv) { - CAFFE_ENFORCE(!init_functions_run_yet_); - init_functions_run_yet_ = true; - return RunRegisteredInitFunctionsInternal(init_functions_, pargc, pargv); - } - - bool RunNamedFunction(const char* name, int* pargc, char*** pargv) { - if (named_functions_.count(name)) { - return named_functions_[name](pargc, pargv); - } - return false; - } - - private: - // Run all registered initialization functions. This has to be called AFTER - // all static initialization are finished and main() has started, since we are - // using logging. 
- bool RunRegisteredInitFunctionsInternal( - vector>& functions, - int* pargc, char*** pargv) { - for (const auto& init_pair : functions) { - VLOG(1) << "Running init function: " << init_pair.second; - if (!(*init_pair.first)(pargc, pargv)) { - LOG(ERROR) << "Initialization function failed."; - return false; - } - } - return true; - } - - Caffe2InitializeRegistry() {} - vector > early_init_functions_; - vector > init_functions_; - std::unordered_map named_functions_; - bool early_init_functions_run_yet_ = false; - bool init_functions_run_yet_ = false; -}; -} // namespace internal - -TORCH_API bool unsafeRunCaffe2InitFunction( - const char* name, - int* pargc = nullptr, - char*** pargv = nullptr); - -class TORCH_API InitRegisterer { - public: - InitRegisterer( - internal::Caffe2InitializeRegistry::InitFunction function, - bool run_early, - const char* description, - const char* name = nullptr) { - internal::Caffe2InitializeRegistry::Registry()->Register( - function, run_early, description, name); - } -}; - -#define REGISTER_CAFFE2_INIT_FUNCTION(name, function, description) \ - namespace { \ - ::caffe2::InitRegisterer \ - g_caffe2_initregisterer_##name(function, false, description, #name); \ - } // namespace - -#define REGISTER_CAFFE2_EARLY_INIT_FUNCTION(name, function, description) \ - namespace { \ - ::caffe2::InitRegisterer \ - g_caffe2_initregisterer_##name(function, true, description, #name); \ - } // namespace - -/** - * @brief Determine whether GlobalInit has already been run - */ -TORCH_API bool GlobalInitAlreadyRun(); - -class TORCH_API GlobalInitIsCalledGuard { - public: - GlobalInitIsCalledGuard() { - if (!GlobalInitAlreadyRun()) { - LOG(WARNING) - << "Caffe2 GlobalInit should be run before any other API calls."; - } - } -}; - -/** - * @brief Initialize the global environment of caffe2. - * - * Caffe2 uses a registration pattern for initialization functions. Custom - * initialization functions should take the signature - * bool (*func)(int*, char***) - * where the pointers to argc and argv are passed in. Caffe2 then runs the - * initialization in three phases: - * (1) Functions registered with REGISTER_CAFFE2_EARLY_INIT_FUNCTION. Note that - * since it is possible the logger is not initialized yet, any logging in - * such early init functions may not be printed correctly. - * (2) Parses Caffe-specific commandline flags, and initializes caffe logging. - * (3) Functions registered with REGISTER_CAFFE2_INIT_FUNCTION. - * If there is something wrong at each stage, the function returns false. If - * the global initialization has already been run, the function returns false - * as well. - * - * GlobalInit is re-entrant safe; a re-entrant call will no-op and exit. - * - * GlobalInit is safe to call multiple times but not idempotent; - * successive calls will parse flags and re-set caffe2 logging levels from - * flags as needed, but NOT re-run early init and init functions. - * - * GlobalInit is also thread-safe and can be called concurrently. - */ -TORCH_API bool GlobalInit(int* pargc, char*** argv); - -/** - * @brief Initialize the global environment without command line arguments - * - * This is a version of the GlobalInit where no argument is passed in. - * On mobile devices, use this global init, since we cannot pass the - * command line options to caffe2, no arguments are passed. 
- */ -TORCH_API bool GlobalInit(); -} // namespace caffe2 -#endif // CAFFE2_CORE_INIT_H_ diff --git a/caffe2/core/init_test.cc b/caffe2/core/init_test.cc deleted file mode 100644 index b94d610f5a91..000000000000 --- a/caffe2/core/init_test.cc +++ /dev/null @@ -1,72 +0,0 @@ -#include -#include - -#include -#include "caffe2/core/init.h" -#include "caffe2/core/logging.h" - -namespace caffe2 { -namespace { -bool gTestInitFunctionHasBeenRun = false; -bool gTestFailInitFunctionHasBeenRun = false; - -bool TestInitFunction(int*, char***) { - gTestInitFunctionHasBeenRun = true; - return true; -} - -bool TestFailInitFunction(int*, char***) { - gTestFailInitFunctionHasBeenRun = true; - return false; -} - -REGISTER_CAFFE2_INIT_FUNCTION( - TestInitFunction, - &TestInitFunction, - "Just a test to see if GlobalInit invokes " - "registered functions correctly."); - -int dummy_argc = 1; -const char* dummy_name = "foo"; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-pro-type-const-cast) -char** dummy_argv = const_cast(&dummy_name); -} // namespace - -TEST(InitTest, TestInitFunctionHasRun) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - EXPECT_TRUE(gTestInitFunctionHasBeenRun); - EXPECT_FALSE(gTestFailInitFunctionHasBeenRun); -} - -TEST(InitTest, CanRerunGlobalInit) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - EXPECT_TRUE(caffe2::GlobalInit(&dummy_argc, &dummy_argv)); -} - -void LateRegisterInitFunction() { - ::caffe2::InitRegisterer testInitFunc( - TestInitFunction, false, "This should succeed but warn"); -} - -void LateRegisterEarlyInitFunction() { - ::caffe2::InitRegisterer testSecondInitFunc( - TestInitFunction, true, "This should fail for early init"); -} - -void LateRegisterFailInitFunction() { - ::caffe2::InitRegisterer testSecondInitFunc( - TestFailInitFunction, false, "This should fail for failed init"); -} - -TEST(InitTest, FailLateRegisterInitFunction) { - caffe2::GlobalInit(&dummy_argc, &dummy_argv); - LateRegisterInitFunction(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(LateRegisterEarlyInitFunction(), ::c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(LateRegisterFailInitFunction(), ::c10::Error); - EXPECT_TRUE(gTestInitFunctionHasBeenRun); - EXPECT_TRUE(gTestFailInitFunctionHasBeenRun); -} - -} // namespace caffe2 diff --git a/caffe2/core/logging.h b/caffe2/core/logging.h deleted file mode 100644 index f47c0581b855..000000000000 --- a/caffe2/core/logging.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once -#include "c10/util/Logging.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/module_test.cc b/caffe2/core/module_test.cc deleted file mode 100644 index 585451d23b10..000000000000 --- a/caffe2/core/module_test.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include -#include - -#include "caffe2/core/module.h" -#include "caffe2/core/operator.h" -#include -#include "caffe2/core/logging.h" - -// An explicitly defined module, testing correctness when we statically link a -// module -CAFFE2_MODULE(caffe2_module_test_static, "Static module for testing."); - -namespace caffe2 { - -class Caffe2ModuleTestStaticDummyOp : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual string type() { - return "base"; - } -}; - -REGISTER_CPU_OPERATOR( - Caffe2ModuleTestStaticDummy, Caffe2ModuleTestStaticDummyOp); -OPERATOR_SCHEMA(Caffe2ModuleTestStaticDummy); - -TEST(ModuleTest, 
StaticModule) { - const string name = "caffe2_module_test_static"; - const auto& modules = CurrentModules(); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // LoadModule should not raise an error, since the module is already present. - LoadModule(name); - // Even a non-existing path should not cause error. - LoadModule(name, "/does/not/exist.so"); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // The module will then introduce the Caffe2ModuleTestStaticDummyOp. - OperatorDef op_def; - Workspace ws; - op_def.set_type("Caffe2ModuleTestStaticDummy"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); -} - -#ifdef CAFFE2_BUILD_SHARED_LIBS -TEST(ModuleTest, DynamicModule) { - const string name = "caffe2_module_test_dynamic"; - const auto& modules = CurrentModules(); - EXPECT_EQ(modules.count(name), 0); - EXPECT_FALSE(HasModule(name)); - - // Before loading, we should not be able to create the op. - OperatorDef op_def; - Workspace ws; - op_def.set_type("Caffe2ModuleTestDynamicDummy"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - EXPECT_THROW( - CreateOperator(op_def, &ws), - EnforceNotMet); - - // LoadModule should load the proper module. - LoadModule(name); - EXPECT_EQ(modules.count(name), 1); - EXPECT_TRUE(HasModule(name)); - - // The module will then introduce the Caffe2ModuleTestDynamicDummyOp. - unique_ptr op_after_load = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op_after_load.get()); -} -#endif - -} // namespace caffe2 diff --git a/caffe2/core/net.h b/caffe2/core/net.h deleted file mode 100644 index 0726d8e8c6c9..000000000000 --- a/caffe2/core/net.h +++ /dev/null @@ -1,175 +0,0 @@ -#ifndef CAFFE2_CORE_NET_H_ -#define CAFFE2_CORE_NET_H_ - -#include -#include -#include -#include // NOLINT -#include -#include -#include - -#include "c10/core/thread_pool.h" -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include "caffe2/core/observer.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/core/tensor.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/simple_queue.h" - -C10_DECLARE_string(caffe2_override_executor); - -namespace caffe2 { - -class NetBase; -typedef ObserverBase NetObserver; -typedef std::function(NetBase*)> - NetObserverCreator; - -class OperatorBase; -class Workspace; - -// Net is a thin struct that owns all the operators together with the operator -// contexts. -class TORCH_API NetBase : public Observable { - public: - NetBase(const std::shared_ptr& net_def, Workspace* ws); - virtual ~NetBase() noexcept {} - - virtual bool SupportsAsync() = 0; - inline const vector& events() const { - return events_; - } - - virtual void Wait() { - // by default just wait till all events are finished - for (const auto& event : events_) { - event->Finish(); - } - } - - virtual bool Run() { - if (!RunAsync()) { - LOG(ERROR) << "Failed to execute async run"; - return false; - } - Wait(); - return handleRunError(); - } - - virtual bool RunAsync(); - - virtual void Cancel(); - - /* Benchmarks a network for one individual run so that we can feed new - * inputs on additional calls. - * This function returns the number of microseconds spent - * during the benchmark - */ - virtual float TEST_Benchmark_One_Run(); - - /** - * Benchmarks a network. - * - * This function returns a vector of float recording the number of milli- - * seconds spent during the benchmark. 
The 0-th item is the time spent per - * each network run, and if a net instantiation supports run_individual, - * the remainder of the vector returns the number of milliseconds spent per - * operator. - */ - virtual vector TEST_Benchmark( - const int /*warmup_runs*/, - const int /*main_runs*/, - const bool /*run_individual*/); - - inline const vector& external_output() const { - return external_output_; - } - - inline const vector& external_input() const { - return external_input_; - } - - /* Used to attach Observers to operators of a Net - * - * Returns pointers to objects owned with unique_ptrs. - * Use with caution. - */ - virtual vector GetOperators() const = 0; - - const string& Name() const { - return name_; - } - - inline const NetDef& debug_def() const { - CAFFE_ENFORCE(has_debug_def(), "net_def was null!"); - return *net_def_; - } - - inline bool has_debug_def() const { - return net_def_ != nullptr; - } - - protected: - virtual bool DoRunAsync() { - CAFFE_THROW("Not implemented"); - }; - - virtual bool handleRunError() { - for (const Event* event : events_) { - if (event->Query() != EventStatus::EVENT_SUCCESS) { - CAFFE_THROW(event->ErrorMessage()); - } - } - return true; - } - - vector external_input_; - vector external_output_; - string name_; - vector events_; - std::shared_ptr net_def_; - C10_DISABLE_COPY_AND_ASSIGN(NetBase); -}; - -class TORCH_API ExecutorHelper { - public: - ExecutorHelper() {} - virtual TaskThreadPoolBase* GetPool(const DeviceOption& option) const; - virtual std::vector GetOperators() const; - virtual int GetNumWorkers() const; - virtual ~ExecutorHelper() {} -}; - -C10_DECLARE_REGISTRY( - NetRegistry, - NetBase, - const std::shared_ptr&, - Workspace*); -#define REGISTER_NET_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__) -#define REGISTER_NET(name, ...) \ - C10_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__) - -/** - * @brief Creates a network, accessing / creating blobs in the given workspace. - * - * Note that this is different from Workspace::CreateNet. The latter adds the - * created net object to the workspace's net map, while this function returns - * a standalone net object. - */ -TORCH_API unique_ptr CreateNet(const NetDef& net_def, Workspace* ws); -TORCH_API unique_ptr CreateNet( - const std::shared_ptr& net_def, - Workspace* ws); - -TORCH_API void AddGlobalNetObserverCreator(NetObserverCreator creator); - -TORCH_API void ClearGlobalNetObservers(); - -} // namespace caffe2 - -#endif // CAFFE2_CORE_NET_H_ diff --git a/caffe2/core/net_async_tracing_test.cc b/caffe2/core/net_async_tracing_test.cc deleted file mode 100644 index 10a81ada9255..000000000000 --- a/caffe2/core/net_async_tracing_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include "caffe2/core/net_async_tracing.h" - -namespace caffe2 { - -namespace tracing { - -void testExtractShardId(const string& name, int expectedId) { - EXPECT_EQ(extractShardId(name), expectedId); -} - -TEST(NetAsyncTracingTest, ExtractShardId) { - testExtractShardId("ABCDEFshard:1705!!A", 1705); - // Should use the last one - testExtractShardId("ABCDEFshard:4324!!Ashard:01220b", 1220); - // Nothing to extract - testExtractShardId("ABCDEFsha:222", -1); - // Regular cases - testExtractShardId("FC:shard:0", 0); - testExtractShardId("FC:shard:10", 10); - testExtractShardId("FC:shard:15", 15); -} - -TEST(NetAsyncTracingTest, EveryKIteration) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - arg { - name: "enable_tracing" - i: 1 - } - arg { - name: "tracing_mode" - s: "EVERY_K_ITERATIONS" - } - arg { - name: "tracing_filepath" - s: "/tmp" - } - arg { - name: "trace_every_nth_batch" - i: 1 - } - arg { - name: "dump_every_nth_batch" - i: 1 - } - op { - output: "out" - type: "UniformFill" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); -} - -TEST(NetAsyncTracingTest, GlobalTimeSlice) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - arg { - name: "enable_tracing" - i: 1 - } - arg { - name: "tracing_filepath" - s: "/tmp" - } - arg { - name: "trace_for_n_ms" - i: 1 - } - arg { - name: "trace_every_n_ms" - i: 1 - } - op { - output: "out" - type: "UniformFill" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); -} - -} // namespace tracing - -} // namespace caffe2 diff --git a/caffe2/core/net_dag_utils_test.cc b/caffe2/core/net_dag_utils_test.cc deleted file mode 100644 index dfbb56614301..000000000000 --- a/caffe2/core/net_dag_utils_test.cc +++ /dev/null @@ -1,296 +0,0 @@ -#include -#include "caffe2/core/net_dag_utils.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -namespace { -class DummySyncOp final : public Operator { - public: - DummySyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - return true; - } -}; - -class DummyAsyncOp final : public Operator { - public: - DummyAsyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - return true; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(DagUtilTestDummySync, DummySyncOp); -REGISTER_CPU_OPERATOR(DagUtilTestDummyAsync, DummyAsyncOp); - -OPERATOR_SCHEMA(DagUtilTestDummySync) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX); -OPERATOR_SCHEMA(DagUtilTestDummyAsync) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX); - -class DagUtilTestContext { - public: - DagUtilTestContext(const std::string& spec, Workspace* ws) { - net_def_ = std::make_shared(); - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, net_def_.get())); - operator_nodes_ = dag_utils::prepareOperatorNodes(net_def_, ws); - } - - dag_utils::ExecutionChains computeChains() { - return dag_utils::computeGroups(operator_nodes_); - } - - private: - std::shared_ptr net_def_{nullptr}; - std::vector operator_nodes_; -}; - -void PrintChains(const dag_utils::ExecutionChains& chains) { - for (const auto& kv : chains) { - std::stringstream ss; - ss << kv.first << ": "; - for (const auto& v 
: kv.second) { - ss << v << ", "; - } - LOG(INFO) << ss.str(); - } -} -} // namespace - -TEST(DagUtilTest, Empty) { - const auto spec = R"DOC( - name: "test0" - type: "async_scheduling" - )DOC"; - Workspace ws; - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - EXPECT_TRUE(chains.empty()); -} - -// 4 sync ops forming a diamond -TEST(DagUtilTest, AllSync) { - const auto spec = R"DOC( - name: "test1" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummySync" - } - op { - input: "n2" - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1, 2, 3}}}; - EXPECT_EQ(chains, expected); -} - -// 3 async ops forming an L shape -TEST(DagUtilTest, AllAsync) { - const auto spec = R"DOC( - name: "test2" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummyAsync" - } - op { - input: "in1" - output: "n2" - type: "DagUtilTestDummyAsync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummyAsync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0}}, {1, {1}}, {2, {2}}}; - EXPECT_EQ(chains, expected); -} - -// 3 sync ops and 1 async op (#2) forming a diamond -TEST(DagUtilTest, Mixed0) { - const auto spec = R"DOC( - name: "test3" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - output: "n3" - type: "DagUtilTestDummyAsync" - } - op { - input: "n2" - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1}}, {2, {2}}, {3, {3}}}; - EXPECT_EQ(chains, expected); -} - -// 3 sync ops and 1 async op (#2) forming a Y shape -TEST(DagUtilTest, Mixed1) { - const auto spec = R"DOC( - name: "test3" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummySync" - } - op { - input: "in1" - output: "n2" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - input: "n2" - output: "n3" - type: "DagUtilTestDummyAsync" - } - op { - input: "n3" - output: "out" - type: "DagUtilTestDummySync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{{0, {0, 1}}, {2, {2}}, {3, {3}}}; - EXPECT_EQ(chains, expected); -} -// More complicated mixed case. 
* means async -// 0* -> 1* -> 2 -// | -// 3 -> 4 -> 5 -// | | -// | 6 -// - -> 8* -// 7* -/ -TEST(DagUtilTest, Mixed2) { - const auto spec = R"DOC( - name: "test4" - type: "async_scheduling" - external_input: "in0" - external_input: "in1" - external_input: "in2" - op { - input: "in0" - output: "n1" - type: "DagUtilTestDummyAsync" - } - op { - input: "n1" - output: "n2" - type: "DagUtilTestDummyAsync" - } - op { - input: "n2" - output: "out0" - type: "DagUtilTestDummySync" - } - op { - input: "in1" - output: "n3" - type: "DagUtilTestDummySync" - } - op { - input: "n1" - input: "n3" - output: "n4" - type: "DagUtilTestDummySync" - } - op { - input: "n4" - output: "out1" - type: "DagUtilTestDummySync" - } - op { - input: "n3" - output: "out2" - type: "DagUtilTestDummySync" - } - op { - input: "in2" - output: "n7" - type: "DagUtilTestDummyAsync" - } - op { - input: "n3" - input: "n7" - output: "out3" - type: "DagUtilTestDummyAsync" - } - )DOC"; - Workspace ws; - ws.CreateBlob("in0"); - ws.CreateBlob("in1"); - ws.CreateBlob("in2"); - DagUtilTestContext t(spec, &ws); - auto chains = t.computeChains(); - dag_utils::ExecutionChains expected{ - {0, {0}}, {1, {1}}, {3, {3, 6}}, {4, {4, 2, 5}}, {7, {7}}, {8, {8}}}; - EXPECT_EQ(chains, expected); -} -} // namespace caffe2 diff --git a/caffe2/core/net_gpu_test.cc b/caffe2/core/net_gpu_test.cc deleted file mode 100644 index 1eb6fa513a23..000000000000 --- a/caffe2/core/net_gpu_test.cc +++ /dev/null @@ -1,130 +0,0 @@ -#include -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_base.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -namespace caffe2 { - -namespace { - -static std::atomic counter; - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. 
-class NetTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - NetTestDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", false)) {} - - bool Run(int /* unused */ /*stream_id*/) override { - if (fail_) { - return false; - } - counter.fetch_add(1); - return true; - } - - // Simulate CUDA operator behavior - bool HasAsyncPart() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - bool SupportsAsyncScheduling() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - protected: - const bool fail_; -}; - -REGISTER_CPU_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CPU_OPERATOR(NetTestDummy2, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy2, NetTestDummyOp); - -OPERATOR_SCHEMA(NetTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); -OPERATOR_SCHEMA(NetTestDummy2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{1, 0}}); - -} // namespace - -void testExecution(std::unique_ptr& net, int num_ops) { - // Run 100 times - for (int i = 0; i < 100; i++) { - counter.exchange(0); - net.get()->Run(); - ASSERT_EQ(num_ops, counter.load()); - } -} - -void checkChainingAndRun( - const char* spec, - const dag_utils::ExecutionChains& expected) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_EQ(chains, expected); - testExecution(net, net_def.op().size()); - } -} - -TEST(NetTest, DISABLED_ChainingForDifferentDevices) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - device_option { - device_type: 1 - } - } - op { - input: "out" - output: "out2" - type: "NetTestDummy" - device_option { - device_type: 1 - } - } - op { - input: "out2" - output: "out3" - type: "NetTestDummy" - device_option { - device_type: 1 - device_id: 1 - } - } -)DOC"; - if (HasCudaGPU() && NumCudaDevices() >= 2) { - checkChainingAndRun(spec, {{0, {0, 1, 2}}, {3, {3}}}); - } -} - -} // namespace caffe2 diff --git a/caffe2/core/net_simple_refcount_test.cc b/caffe2/core/net_simple_refcount_test.cc deleted file mode 100644 index 14acf998064a..000000000000 --- a/caffe2/core/net_simple_refcount_test.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include -#include "c10/util/StringUtil.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_scheduling.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -#include - -namespace caffe2 { - -namespace { - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. 
-class NetSimpleRefCountTestOp final : public Operator { - public: - NetSimpleRefCountTestOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - USE_OPERATOR_FUNCTIONS(CPUContext); - - bool RunOnDevice() override { - const int32_t& input = OperatorBase::Input(0); - int32_t* output = OperatorBase::Output(0); - *output = input + 1; - return true; - } -}; - -REGISTER_CPU_OPERATOR(NetSimpleRefCountTest, NetSimpleRefCountTestOp); - -OPERATOR_SCHEMA(NetSimpleRefCountTest).NumInputs(1).NumOutputs(1); - -TEST(NetSimpleRefCountTest, TestCorrectness) { - Workspace ws; - *(ws.CreateBlob("a")->GetMutable()) = 1; - NetDef net_def; - net_def.set_type("simple_refcount"); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"a"}, {"b"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"c"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"b"}, {"d"})); - net_def.add_op()->CopyFrom( - CreateOperatorDef("NetSimpleRefCountTest", "", {"c"}, {"e"})); - // After execution, what should look like is: - // a = 1 - // b = deallocated - // c = deallocated - // d = 3 - // e = 4 - std::unique_ptr net(CreateNet(net_def, &ws)); - net->Run(); - // Note on ASSERT vs EXPECT: ASSERT will quit directly if condition not - // met, which is why we guard IsType<> calls with ASSERT so that the - // subsequent Get() calls do not product an exception. - ASSERT_TRUE(ws.GetBlob("a")->IsType()); - EXPECT_EQ(ws.GetBlob("a")->Get(), 1); - EXPECT_EQ(ws.GetBlob("b")->GetRaw(), nullptr); - EXPECT_EQ(ws.GetBlob("c")->GetRaw(), nullptr); - ASSERT_TRUE(ws.GetBlob("d")->IsType()); - EXPECT_EQ(ws.GetBlob("d")->Get(), 3); - ASSERT_TRUE(ws.GetBlob("e")->IsType()); - EXPECT_EQ(ws.GetBlob("e")->Get(), 4); -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/net_test.cc b/caffe2/core/net_test.cc deleted file mode 100644 index a1c80eca6790..000000000000 --- a/caffe2/core/net_test.cc +++ /dev/null @@ -1,1122 +0,0 @@ -#include -#include "c10/util/StringUtil.h" -#include "caffe2/core/net.h" -#include "caffe2/core/net_async_scheduling.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/scope_guard.h" - -#include - -namespace caffe2 { - -namespace { - -static std::atomic counter; - -// A net test dummy op that does nothing but scaffolding. Here, we -// inherit from OperatorBase because we instantiate on both CPU and -// GPU. In general, you want to only inherit from Operator. 
-class NetTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - NetTestDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", false)) {} - - bool Run(int /* unused */ /*stream_id*/) override { - if (fail_) { - return false; - } - counter.fetch_add(1); - return true; - } - - // Simulate CUDA operator behavior - bool HasAsyncPart() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - bool SupportsAsyncScheduling() const override { - return debug_def().device_option().device_type() == PROTO_CUDA; - } - - protected: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - const bool fail_; -}; - -REGISTER_CPU_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy, NetTestDummyOp); -REGISTER_CPU_OPERATOR(NetTestDummy2, NetTestDummyOp); -REGISTER_CUDA_OPERATOR(NetTestDummy2, NetTestDummyOp); - -OPERATOR_SCHEMA(NetTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); -OPERATOR_SCHEMA(NetTestDummy2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{1, 0}}); - -unique_ptr CreateNetTestHelper( - Workspace* ws, - const vector& input, - const vector& output) { - NetDef net_def; - { - auto& op = *(net_def.add_op()); - op.set_type("NetTestDummy"); - op.add_input("in"); - op.add_output("hidden"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("NetTestDummy"); - op.add_input("hidden"); - op.add_output("out"); - } - - for (const auto& name : input) { - net_def.add_external_input(name); - } - for (const auto& name : output) { - net_def.add_external_output(name); - } - return CreateNet(net_def, ws); -} - -} // namespace - -TEST(NetTest, ConstructionNoDeclaredInputOutput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector(), vector())); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, ConstructionDeclaredInput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector{"in"}, vector())); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, ConstructionDeclaredOutput) { - Workspace ws; - ws.CreateBlob("in"); - unique_ptr net( - CreateNetTestHelper(&ws, vector(), vector{"out"})); - EXPECT_TRUE(net.get() != nullptr); -} - -TEST(NetTest, DeclaredInputInsufficient) { - Workspace ws; - ws.CreateBlob("in"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW( - CreateNetTestHelper(&ws, vector{"unuseful_in"}, vector()), - EnforceNotMet); -} - -TEST(NetDeathTest, DeclaredOutputNotMet) { - Workspace ws; - ws.CreateBlob("in"); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW( - CreateNetTestHelper( - &ws, vector(), vector{"unproduced_out"}), - EnforceNotMet); -} - -void testExecution(std::unique_ptr& net, int num_ops) { - // Run 100 times - for (int i = 0; i < 100; i++) { - counter.exchange(0); - net.get()->Run(); - ASSERT_EQ(num_ops, counter.load()); - } -} - -void checkChainingAndRun( - const char* spec, - const dag_utils::ExecutionChains& expected) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_TRUE(chains == 
expected); - testExecution(net, net_def.op().size()); - } -} - -void checkNumChainsAndRun(const char* spec, const int expected_num_chains) { - Workspace ws; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - net_def.set_num_workers(4); - - // Create all external inputs - // NOLINTNEXTLINE(performance-for-range-copy) - for (auto inp : net_def.external_input()) { - ws.CreateBlob(inp); - } - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - auto* dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(dag); - const auto& chains = dag->TEST_execution_chains(); - EXPECT_EQ(expected_num_chains, chains.size()); - testExecution(net, net_def.op().size()); - } -} - -TEST(NetTest, DISABLED_ChainingForLinearModel) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0, 1}}}); -} - -TEST(NetTest, DISABLED_ChainingForFork) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out1" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2}}}); -} - -// TEST(NetTest, ChainingForJoinWithAncestor) { -// const auto spec = R"DOC( -// name: "example" -// type: "dag" -// external_input: "in" -// op { -// input: "in" -// output: "hidden" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// output: "out1" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// output: "out2" -// type: "NetTestDummy" -// } -// op { -// input: "hidden" -// input: "out2" -// type: "NetTestDummy" -// } -// )DOC"; -// checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2, 3}}}); -// } - -TEST(NetTest, DISABLED_ChainingForForkJoin) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden1" - type: "NetTestDummy" - } - op { - input: "in" - output: "hidden2" - type: "NetTestDummy" - } - op { - input: "hidden1" - input: "hidden2" - output: "out" - type: "NetTestDummy" - } - op { - input: "out" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkChainingAndRun(spec, {{0, {0}}, {1, {1}}, {2, {2, 3}}}); -} - -TEST(NetTest, DISABLED_ChainingForwardBackward) { - const auto spec = R"DOC( - name: "gpu_0" - type: "dag" - op { - input: "in" - input: "fc_0_w" - input: "fc_0_b" - output: "fc_0" - name: "0" - type: "NetTestDummy" - } - op { - input: "fc_0" - output: "fc_0" - name: "1" - type: "NetTestDummy" - } - op { - input: "fc_0" - input: "fc_1_w" - input: "fc_1_b" - output: "fc_1" - name: "2" - type: "NetTestDummy" - } - op { - input: "fc_1" - output: "fc_1" - name: "3" - type: "NetTestDummy" - } - op { - input: "fc_1" - input: "fc_2_w" - input: "fc_2_b" - output: "fc_2" - name: "4" - type: "NetTestDummy" - } - op { - input: "fc_2" - output: "fc_2" - name: "5" - type: "NetTestDummy" - } - op { - input: "fc_2" - input: "fc_3_w" - input: "fc_3_b" - output: "fc_3" - name: "6" - type: "NetTestDummy" - } - op { - input: "fc_3" - output: "fc_3" - name: "7" - type: "NetTestDummy" - } - op { - input: "fc_3" - input: "fc_4_w" - input: "fc_4_b" - output: "fc_4" - name: "8" - type: "NetTestDummy" - } - op { - input: "fc_4" - output: "fc_4" - name: "9" - type: "NetTestDummy" - } - op { - 
input: "fc_4" - input: "in2" - output: "LabelCrossEntropy" - name: "10" - type: "NetTestDummy" - } - op { - input: "LabelCrossEntropy" - output: "AveragedLoss" - name: "11" - type: "NetTestDummy" - } - op { - input: "AveragedLoss" - output: "AveragedLoss_autogen_grad" - name: "12" - type: "NetTestDummy" - } - op { - input: "LabelCrossEntropy" - input: "AveragedLoss_autogen_grad" - output: "LabelCrossEntropy_grad" - name: "13" - type: "NetTestDummy" - } - op { - input: "fc_4" - input: "label" - input: "LabelCrossEntropy_grad" - output: "fc_4_grad" - name: "14" - type: "NetTestDummy2" - } - op { - input: "fc_4" - input: "fc_4_grad" - output: "fc_4_grad" - name: "15" - type: "NetTestDummy2" - } - op { - input: "fc_3" - input: "fc_4_w" - input: "fc_4_grad" - output: "fc_4_w_grad" - output: "fc_4_b_grad" - output: "fc_3_grad" - name: "16" - type: "NetTestDummy" - } - op { - input: "fc_3" - input: "fc_3_grad" - output: "fc_3_grad" - name: "17" - type: "NetTestDummy2" - } - op { - input: "fc_2" - input: "fc_3_w" - input: "fc_3_grad" - output: "fc_3_w_grad" - output: "fc_3_b_grad" - output: "fc_2_grad" - name: "18" - type: "NetTestDummy" - } - op { - input: "fc_2" - input: "fc_2_grad" - output: "fc_2_grad" - name: "19" - type: "NetTestDummy2" - } - op { - input: "fc_1" - input: "fc_2_w" - input: "fc_2_grad" - output: "fc_2_w_grad" - output: "fc_2_b_grad" - output: "fc_1_grad" - name: "20" - type: "NetTestDummy" - } - op { - input: "fc_1" - input: "fc_1_grad" - output: "fc_1_grad" - name: "21" - type: "NetTestDummy2" - } - op { - input: "fc_0" - input: "fc_1_w" - input: "fc_1_grad" - output: "fc_1_w_grad" - output: "fc_1_b_grad" - output: "fc_0_grad" - name: "22" - type: "NetTestDummy" - } - op { - input: "fc_0" - input: "fc_0_grad" - output: "fc_0_grad" - name: "23" - type: "NetTestDummy2" - } - op { - input: "in" - input: "fc_0_w" - input: "fc_0_grad" - output: "fc_0_w_grad" - output: "fc_0_b_grad" - output: "data_grad" - name: "24" - type: "NetTestDummy" - } - external_input: "in" - external_input: "in2" - external_input: "LR" - external_input: "fc_0_w" - external_input: "fc_0_b" - external_input: "fc_1_w" - external_input: "fc_1_b" - external_input: "fc_2_w" - external_input: "fc_2_b" - external_input: "fc_3_w" - external_input: "fc_3_b" - external_input: "fc_4_w" - external_input: "fc_4_b" - external_input: "label" - )DOC"; - checkNumChainsAndRun(spec, 1); -} - -TEST(NetTest, DISABLED_ChainingForHogwildModel) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden1" - type: "NetTestDummy" - } - op { - input: "hidden1" - output: "mid1" - type: "NetTestDummy" - } - op { - input: "mid1" - output: "out1" - type: "NetTestDummy" - } - op { - input: "in" - output: "hidden2" - type: "NetTestDummy" - } - op { - input: "hidden2" - output: "mid2" - type: "NetTestDummy" - } - op { - input: "mid2" - output: "out2" - type: "NetTestDummy" - } -)DOC"; - checkNumChainsAndRun(spec, 2); -} - -TEST(NetTest, DISABLED_FailingOperator) { - const auto spec = R"DOC( - name: "example" - type: "dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "NetTestDummy" - } - op { - input: "hidden" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - net_def.set_num_workers(4); - std::unique_ptr net(CreateNet(net_def, &ws)); - for (int i = 0; i < 10; i++) { - counter.exchange(0); 
- bool run_result = false; - try { - run_result = net->Run(); - } catch (const std::exception&) { - // async_scheduling would throw - } - ASSERT_FALSE(run_result); - - ASSERT_EQ(1, counter.load()); - } - } -} - -const int kTestPoolSize = 4; - -class ExecutorHelperDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - - ExecutorHelperDummyOp(const OperatorDef& operator_def, Workspace* ws) - : OperatorBase(operator_def, ws) {} - - bool Run(int /* unused */ /*stream_id*/) override { - auto helper = GetExecutorHelper(); - CAFFE_ENFORCE(helper); - auto pool = helper->GetPool(device_option()); - CAFFE_ENFORCE(pool); - auto pool_size = pool->size(); - CAFFE_ENFORCE_EQ(pool_size, kTestPoolSize); - return true; - } -}; - -REGISTER_CPU_OPERATOR(ExecutorHelperDummy, ExecutorHelperDummyOp); - -OPERATOR_SCHEMA(ExecutorHelperDummy); - -TEST(NetTest, OperatorWithExecutorHelper) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - type: "ExecutorHelperDummy" - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - net_def.set_num_workers(kTestPoolSize); - std::unique_ptr net(CreateNet(net_def, &ws)); - ASSERT_TRUE(net->Run()); -} - -TEST(NetTest, DISABLED_OperatorWithDisabledEvent) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - external_input: "in" - op { - input: "in" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - net->GetOperators()[0]->DisableEvent(); - // async_scheduling propagates exception - bool caught_exception = false; - try { - net->Run(); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_TRUE(caught_exception); - } -} - -TEST(NetTest, ExecutorOverride) { - const auto spec = R"DOC( - name: "example" - type: "dag" - )DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - Workspace ws; - auto old = FLAGS_caffe2_override_executor; - auto g = MakeGuard([&]() { FLAGS_caffe2_override_executor = old; }); - FLAGS_caffe2_override_executor = "dag,async_scheduling"; - - std::unique_ptr net(CreateNet(net_def, &ws)); - auto async_net = - caffe2::dynamic_cast_if_rtti(net.get()); - ASSERT_TRUE(async_net != nullptr); - } -} - -TEST(NetTest, AsyncEmptyNet) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - )DOC"; - - Workspace ws; - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - bool caught_exception = false; - try { - ASSERT_TRUE(net->Run()); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_FALSE(caught_exception); - } -} - -TEST(NetTest, DISABLED_RunAsyncFailure) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - input: "in" - output: "out" - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } - )DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - - bool caught_exception = false; - try { - ASSERT_FALSE(net->Run()); - } catch (const std::exception& e) { - caught_exception = true; - } - ASSERT_TRUE(caught_exception); - } -} - -TEST(NetTest, NoTypeNet) { - const auto spec = R"DOC( 
- name: "no_type_net" - )DOC"; - - Workspace ws; - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - { - std::unique_ptr net(CreateNet(net_def, &ws)); - ASSERT_TRUE(net); - } -} - -class NotFinishingOp final : public Operator { - public: - NotFinishingOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // never calls SetFinished - return true; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(NotFinishingOp, NotFinishingOp); - -OPERATOR_SCHEMA(NotFinishingOp); - -TEST(NetTest, PendingOpsAndNetFailure) { - const auto spec = R"DOC( - name: "example" - type: "async_scheduling" - op { - type: "NotFinishingOp" - } - op { - type: "NetTestDummy" - arg { - name: "fail" - i: 1 - } - } -)DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - - Workspace ws; - std::unique_ptr net(CreateNet(net_def, &ws)); - - try { - // net is not stuck and returns false - ASSERT_FALSE(net->Run()); - } catch (const caffe2::AsyncNetCancelled&) { - // Cancellation exception is fine since if the ops run concurrently the - // NotFinishingOp may be cancelled with an exception. - } -} - -class AsyncErrorOp final : public Operator { - public: - AsyncErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - throw_(OperatorBase::GetSingleArgument("throw", false)), - fail_in_sync_( - OperatorBase::GetSingleArgument("fail_in_sync", false)), - sleep_time_s_(OperatorBase::GetSingleArgument("sleep_time", 1)), - error_msg_(OperatorBase::GetSingleArgument( - "error_msg", - "Error")) {} - - bool RunOnDevice() override { - if (fail_in_sync_) { - if (throw_) { - throw std::logic_error(error_msg_); - } else { - return false; - } - } else { - if (thread_) { - thread_->join(); - } - thread_ = std::make_unique([this]() { - try { - std::this_thread::sleep_for(std::chrono::seconds(sleep_time_s_)); - if (throw_) { - throw std::logic_error(error_msg_); - } else { - if (!cancel_.test_and_set()) { - event().SetFinished(error_msg_.c_str()); - } - } - } catch (...) { - if (!cancel_.test_and_set()) { - event().SetFinishedWithException(error_msg_.c_str()); - } - } - }); - return true; - } - } - - bool HasAsyncPart() const override { - return true; - } - - void CancelAsyncCallback() override { - cancel_.test_and_set(); - } - - ~AsyncErrorOp() override { - if (thread_) { - thread_->join(); - } - } - - private: - std::unique_ptr thread_; - bool throw_; - bool fail_in_sync_; - int sleep_time_s_; - std::string error_msg_; - std::atomic_flag cancel_ = ATOMIC_FLAG_INIT; -}; - -REGISTER_CPU_OPERATOR(AsyncErrorOp, AsyncErrorOp); -OPERATOR_SCHEMA(AsyncErrorOp); - -std::unique_ptr AsyncErrorNet( - Workspace* ws, - const std::string& net_name, - bool throw_, - bool fail_in_sync) { - std::string spec_template = R"DOC( - name: "" - type: "async_scheduling" - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: - } - arg { - name: "fail_in_sync" - i: - } - } - )DOC"; - - std::string spec = spec_template; - ReplaceAll(spec, "", net_name.c_str()); - ReplaceAll(spec, "", throw_ ? "1" : "0"); - ReplaceAll(spec, "", fail_in_sync ? 
"1" : "0"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - return CreateNet(net_def, ws); -} - -TEST(NetTest, AsyncErrorOpTest) { - Workspace ws; - - // Throw in sync part - auto net = AsyncErrorNet(&ws, "net1", /*throw_*/ true, /*fail_in_sync*/ true); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - // Return false in sync part - net = AsyncErrorNet(&ws, "net2", /*throw_*/ false, /*fail_in_sync*/ true); - ASSERT_FALSE(net->Run()); - - // SetFinishedWithException in async part - net = AsyncErrorNet(&ws, "net3", /*throw_*/ true, /*fail_in_sync*/ false); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - // SetFinished(err) in async part - net = AsyncErrorNet(&ws, "net4", /*throw_*/ false, /*fail_in_sync*/ false); - ASSERT_FALSE(net->Run()); -} - -TEST(NetTest, AsyncErrorTimingsTest) { - Workspace ws; - std::string spec = R"DOC( - name: "net" - type: "async_scheduling" - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: 1 - } - arg { - name: "fail_in_sync" - i: 0 - } - arg { - name: "sleep_time" - i: 2 - } - arg { - name: "error_msg" - s: "Error1" - } - } - op { - type: "AsyncErrorOp" - arg { - name: "throw" - i: 1 - } - arg { - name: "fail_in_sync" - i: 0 - } - arg { - name: "sleep_time" - i: 1 - } - arg { - name: "error_msg" - s: "Error2" - } - } - )DOC"; - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - auto net = CreateNet(net_def, &ws); - - try { - net->Run(); - } catch (const std::logic_error& e) { - ASSERT_TRUE(std::string(e.what()) == "Error2"); - } catch (...) { - FAIL() << "Expected std::logic_error thrown"; - } -} - -class SyncErrorOp final : public Operator { - public: - SyncErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - fail_(OperatorBase::GetSingleArgument("fail", true)), - throw_(OperatorBase::GetSingleArgument("throw", false)) {} - - bool RunOnDevice() override { - if (fail_) { - if (throw_) { - throw std::logic_error("Error"); - } else { - return false; - } - } else { - return true; - } - } - - // NOLINTNEXTLINE(modernize-use-equals-default) - ~SyncErrorOp() override {} - - private: - bool fail_; - bool throw_; -}; - -REGISTER_CPU_OPERATOR(SyncErrorOp, SyncErrorOp); -OPERATOR_SCHEMA(SyncErrorOp); - -std::unique_ptr -ChainErrorNet(Workspace* ws, const std::string& net_name, bool throw_) { - std::string spec_template = R"DOC( - name: "" - type: "async_scheduling" - op { - type: "SyncErrorOp" - arg { - name: "fail" - i: 1 - } - arg { - name: "throw" - i: - } - } - op { - type: "SyncErrorOp" - arg { - name: "fail" - i: 0 - } - } - )DOC"; - - std::string spec = spec_template; - ReplaceAll(spec, "", net_name.c_str()); - ReplaceAll(spec, "", throw_ ? 
"1" : "0"); - - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(spec, &net_def)); - return CreateNet(net_def, ws); -} - -TEST(NetTest, ChainErrorTest) { - Workspace ws; - - auto net = ChainErrorNet(&ws, "net1", /*throw_*/ true); -#ifdef CAFFE2_USE_EXCEPTION_PTR - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(net->Run(), std::logic_error); -#endif - - net = ChainErrorNet(&ws, "net2", /*throw_*/ false); - ASSERT_FALSE(net->Run()); -} - -void testProfDAGNetErrorCase(bool test_error) { - std::string spec_template = R"DOC( - name: "prof_dag_error_test_net" - type: "prof_dag" - external_input: "in" - op { - input: "in" - output: "hidden" - type: "SyncErrorOp" - arg { - name: "fail" - i: - } - arg { - name: "throw" - i: 0 - } - } - op { - input: "hidden" - output: "out" - type: "SyncErrorOp" - arg { - name: "fail" - i: 0 - } - } - )DOC"; - - Workspace ws; - ws.CreateBlob("in"); - - NetDef net_def; - std::string net_spec = spec_template; - ReplaceAll(net_spec, "", test_error ? "1" : "0"); - CAFFE_ENFORCE(TextFormat::ParseFromString(net_spec, &net_def)); - auto net = CreateNet(net_def, &ws); - - // with failing op - net runs return false, without - true - for (auto num_runs = 0; num_runs < 10; ++num_runs) { - auto ret = net->Run(); - ASSERT_TRUE(test_error ? !ret : ret); - } - - // with failing op - prof_dag handles invalid runs and returns empty stats, - // without - returns stats for each op - auto* prof_dag = dynamic_cast_if_rtti(net.get()); - TORCH_CHECK_NOTNULL(prof_dag); - auto stats_proto = prof_dag->GetPerOperatorCost(); - ASSERT_EQ( - stats_proto.stats_size(), test_error ? 0 : net->GetOperators().size()); -} - -TEST(NetTest, ProfDAGNetErrorTest) { - testProfDAGNetErrorCase(/*test_error=*/false); - testProfDAGNetErrorCase(/*test_error=*/true); -} - -} // namespace caffe2 diff --git a/caffe2/core/numa.h b/caffe2/core/numa.h deleted file mode 100644 index 8424d544fa38..000000000000 --- a/caffe2/core/numa.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once -#include "c10/util/numa.h" -#include "caffe2/core/common.h" diff --git a/caffe2/core/observer.h b/caffe2/core/observer.h deleted file mode 100644 index 3897bb76b52a..000000000000 --- a/caffe2/core/observer.h +++ /dev/null @@ -1,164 +0,0 @@ -#pragma once - -#include -#include - -#include "caffe2/core/logging.h" - -namespace caffe2 { - -/** - * Use this to implement a Observer using the Observer Pattern template. - */ - -template -class ObserverBase { - public: - explicit ObserverBase(T* subject) : subject_(subject) {} - - virtual void Start() {} - virtual void Stop() {} - - virtual std::string debugInfo() { - return "Not implemented."; - } - - virtual ~ObserverBase() noexcept {} - - T* subject() const { - return subject_; - } - - virtual std::unique_ptr> rnnCopy(T* subject, int rnn_order) - const { - return nullptr; - } - - protected: - T* subject_; -}; - -/** - * Inherit to make your class observable. - */ -template -class Observable { - public: - Observable() = default; - - Observable(Observable&&) = default; - Observable& operator =(Observable&&) = default; - - virtual ~Observable() = default; - - C10_DISABLE_COPY_AND_ASSIGN(Observable); - - using Observer = ObserverBase; - - /* Returns a reference to the observer after addition. 
 */
-  const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
-    CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
-    std::unordered_set<const Observer*> observers;
-    for (auto& ob : observers_list_) {
-      observers.insert(ob.get());
-    }
-
-    const auto* observer_ptr = observer.get();
-    if (observers.count(observer_ptr)) {
-      return observer_ptr;
-    }
-    observers_list_.push_back(std::move(observer));
-    UpdateCache();
-
-    return observer_ptr;
-  }
-
-  /**
-   * Returns a unique_ptr to the removed observer. If not found, return a
-   * nullptr
-   */
-  std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
-    for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
-      if (it->get() == observer_ptr) {
-        auto res = std::move(*it);
-        observers_list_.erase(it);
-        UpdateCache();
-        return res;
-      }
-    }
-    return nullptr;
-  }
-
-  virtual size_t NumObservers() {
-    return num_observers_;
-  }
-
- private:
-  inline static void StartObserver(Observer* observer) {
-    try {
-      observer->Start();
-    } catch (const std::exception& e) {
-      LOG(ERROR) << "Exception from observer: " << e.what();
-    } catch (...) {
-      LOG(ERROR) << "Exception from observer: unknown";
-    }
-  }
-
-  inline static void StopObserver(Observer* observer) {
-    try {
-      observer->Stop();
-    } catch (const std::exception& e) {
-      LOG(ERROR) << "Exception from observer: " << e.what();
-    } catch (...) {
-      LOG(ERROR) << "Exception from observer: unknown";
-    }
-  }
-
-  void UpdateCache() {
-    num_observers_ = observers_list_.size();
-    if (num_observers_ != 1) {
-      // we cannot take advantage of the cache
-      return;
-    }
-    observer_cache_ = observers_list_[0].get();
-  }
-
- public:
-  void StartAllObservers() {
-    // do not access observers_list_ unless necessary
-    if (num_observers_ == 0) {
-      return;
-    } else if (num_observers_ == 1) {
-      StartObserver(observer_cache_);
-    } else {
-      for (auto& observer : observers_list_) {
-        StartObserver(observer.get());
-      }
-    }
-  }
-
-  void StopAllObservers() {
-    // do not access observers_list_ unless necessary
-    if (num_observers_ == 0) {
-      return;
-    } else if (num_observers_ == 1) {
-      StopObserver(observer_cache_);
-    } else {
-      for (auto& observer : observers_list_) {
-        StopObserver(observer.get());
-      }
-    }
-  }
-
- private:
-  // an on-stack cache for fast iteration;
-  // ideally, inside StartAllObservers and StopAllObservers,
-  // we should never access observers_list_
-  Observer* observer_cache_;
-  size_t num_observers_ = 0;
-
- protected:
-  std::vector<std::unique_ptr<Observer>> observers_list_;
-};
-
-} // namespace caffe2
diff --git a/caffe2/core/observer_test.cc b/caffe2/core/observer_test.cc
deleted file mode 100644
index 50faf92e8414..000000000000
--- a/caffe2/core/observer_test.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-#include <gtest/gtest.h>
-#include "c10/util/Registry.h"
-#include "caffe2/core/common.h"
-#include "caffe2/core/net.h"
-#include "caffe2/core/net_simple.h"
-#include "caffe2/core/observer.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/core/scope_guard.h"
-
-namespace caffe2 {
-
-namespace {
-
-static std::atomic<int> counter;
-
-template <class T>
-class DummyObserver final : public ObserverBase<T> {
- public:
-  explicit DummyObserver(T* subject_) : ObserverBase<T>(subject_) {}
-  void Start() override;
-  void Stop() override;
-
-  // NOLINTNEXTLINE(modernize-use-equals-default)
-  ~DummyObserver() override {}
-};
-
-template <>
-void DummyObserver<NetBase>::Start() {
-  vector<OperatorBase*> operators = subject_->GetOperators();
-  for (auto& op : operators) {
-    op->AttachObserver(std::make_unique<DummyObserver<OperatorBase>>(op));
-  }
-  counter.fetch_add(1000);
-}
-
-template <>
-void
DummyObserver::Start() { - counter.fetch_add(100); -} - -template <> -void DummyObserver::Stop() { - counter.fetch_add(10); -} - -template <> -void DummyObserver::Stop() { - counter.fetch_add(1); -} - -class ObsTestDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - StartAllObservers(); - StopAllObservers(); - return true; - } -}; - -REGISTER_CPU_OPERATOR(ObsTestDummy, ObsTestDummyOp); -REGISTER_CUDA_OPERATOR(ObsTestDummy, ObsTestDummyOp); - -OPERATOR_SCHEMA(ObsTestDummy) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -unique_ptr CreateNetTestHelper(Workspace* ws, bool isDAG = false) { - NetDef net_def; - if (isDAG) { - net_def.set_type("dag"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("ObsTestDummy"); - op.add_input("in"); - op.add_output("hidden"); - } - { - auto& op = *(net_def.add_op()); - op.set_type("ObsTestDummy"); - op.add_input("hidden"); - op.add_output("out"); - } - net_def.add_external_input("in"); - net_def.add_external_output("out"); - - return CreateNet(net_def, ws); -} -} - -TEST(ObserverTest, TestNotify) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - unique_ptr> net_ob = - make_unique>(net.get()); - net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -TEST(ObserverTest, TestUniqueMap) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - unique_ptr> net_ob = - make_unique>(net.get()); - auto* ref = net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - unique_ptr::Observer> test = - net.get()->DetachObserver(ref); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -TEST(ObserverTest, TestNotifyAfterDetach) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws)); - unique_ptr> net_ob = - make_unique>(net.get()); - auto* ob = net.get()->AttachObserver(std::move(net_ob)); - net.get()->DetachObserver(ob); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(0, count_after - count_before); -} - -TEST(ObserverTest, TestDAGNetBase) { - auto count_before = counter.load(); - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws, true)); - unique_ptr> net_ob = - make_unique>(net.get()); - net.get()->AttachObserver(std::move(net_ob)); - net.get()->Run(); - auto count_after = counter.load(); - EXPECT_EQ(1212, count_after - count_before); -} - -#if 0 -// This test intermittently segfaults, -// see https://github.com/pytorch/pytorch/issues/9137 -TEST(ObserverTest, TestMultipleNetBase) { - Workspace ws; - ws.CreateBlob("in"); - NetDef net_def; - unique_ptr net(CreateNetTestHelper(&ws, true)); - EXPECT_EQ(caffe2::dynamic_cast_if_rtti(net.get()), net.get()); - - // There may be some default observers - const size_t prev_num = net.get()->NumObservers(); - const int num_tests = 100; - vector::Observer*> observers; - for (int i = 0; i < num_tests; ++i) { - unique_ptr> net_ob = - make_unique>(net.get()); - 
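// Illustrative sketch, not part of this diff: the ObserverBase/Observable
// classes being removed above implement a plain observer pattern with an
// attach/detach handle API. A minimal standalone analogue (all names below
// are invented for illustration and are not caffe2 API):
#include <memory>
#include <vector>

class MiniObserver {
 public:
  virtual ~MiniObserver() = default;
  virtual void Start() {}  // called before the observed subject runs
  virtual void Stop() {}   // called after the observed subject runs
};

class MiniObservable {
 public:
  // Takes ownership; the returned raw pointer serves as a detach handle.
  const MiniObserver* Attach(std::unique_ptr<MiniObserver> ob) {
    observers_.push_back(std::move(ob));
    return observers_.back().get();
  }

  // Returns ownership of the detached observer, or nullptr if not found.
  std::unique_ptr<MiniObserver> Detach(const MiniObserver* handle) {
    for (auto it = observers_.begin(); it != observers_.end(); ++it) {
      if (it->get() == handle) {
        auto res = std::move(*it);
        observers_.erase(it);
        return res;
      }
    }
    return nullptr;
  }

  void StartAll() { for (auto& ob : observers_) ob->Start(); }
  void StopAll()  { for (auto& ob : observers_) ob->Stop(); }

 private:
  std::vector<std::unique_ptr<MiniObserver>> observers_;
};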
observers.emplace_back(net.get()->AttachObserver(std::move(net_ob))); - } - - net.get()->Run(); - - for (const auto& observer : observers) { - net.get()->DetachObserver(observer); - } - - EXPECT_EQ(net.get()->NumObservers(), prev_num); -} -#endif -} // namespace caffe2 diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h deleted file mode 100644 index 3277357b4f34..000000000000 --- a/caffe2/core/operator.h +++ /dev/null @@ -1,1600 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_H_ -#define CAFFE2_CORE_OPERATOR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/common.h" -#include "caffe2/core/net.h" -#include "caffe2/core/observer.h" -#include "caffe2/core/operator_gradient.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/core/tensor.h" -#include "caffe2/core/tensor_int8.h" -#include "caffe2/core/types.h" -#include "caffe2/core/workspace.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -#include -#include -#include -#endif - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - -C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions); -C10_DECLARE_bool(caffe2_operator_throw_if_fp_overflow_exceptions); -#ifdef __GNU_LIBRARY__ -C10_DECLARE_bool(caffe2_operator_throw_on_first_occurrence_if_fp_exceptions); -#endif - -namespace c10 { -struct FunctionSchema; -} - -namespace caffe2 { - -class TORCH_API OperatorBase; -typedef ObserverBase OperatorObserver; - -class TORCH_API OperatorBase : public Observable { - public: - explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws); - - /* - * Notes: All outputs ivalues must be tensors. Input ivalue list must start - * with all tensors ("inputs" in caffe2 terminology), - * followed by non-tensors ("arguments" in caffe2 terminology). - * Alternatively, inputs can be one tensor list ivalue followed by non-tensors - * to represent operators with a variable number of inputs. - */ -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit OperatorBase( - const c10::FunctionSchema& schema, - std::vector inputs, - std::vector outputs); -#endif - - virtual ~OperatorBase() noexcept; - - /** @brief Return true if the operator was instantiated with OperatorDef - * New operators should be instantiated with FunctionSchema - */ - bool isLegacyOperator() const { -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return !fn_schema_; -#else - return true; -#endif - } - - const c10::FunctionSchema& getFunctionSchema() const { - CAFFE_ENFORCE(!isLegacyOperator()); -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return *fn_schema_.get(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - /** @brief Checks if the operator has an argument of the given name. - */ - inline bool HasArgument(c10::string_view name) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::HasArgument(*operator_def_, name); - } - return argumentIndexWithName(name).has_value(); - } - - // Functions that deal with arguments. 
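// Illustrative sketch, not part of this diff: the GetSingleArgument<T>(name,
// default) helpers declared below reduce to "look up a named argument and
// fall back to a default when it is absent". A standalone analogue over a
// plain map (the function name is invented for illustration):
#include <map>
#include <string>

template <typename T>
T GetSingleArgumentOr(const std::map<std::string, T>& args,
                      const std::string& name,
                      const T& default_value) {
  auto it = args.find(name);
  return it == args.end() ? default_value : it->second;
}

// e.g. GetSingleArgumentOr<int>({{"fail", 1}}, "throw", 0) yields 0, much
// like the test operators above reading "fail"/"throw" with defaults.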
Basically, this allows us to map an - // argument name to a specific type of argument that we are trying to access. - template - inline T GetSingleArgument(c10::string_view name, const T& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetSingleArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - return value.template to(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - template - inline bool HasSingleArgumentOfType(c10::string_view name) const { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::HasSingleArgumentOfType( - *operator_def_, name); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - template - inline vector GetVectorFromIValueList(const c10::IValue& value) const { - return value.template to>().vec(); - } -#endif - - template - inline vector GetRepeatedArgument( - c10::string_view name, - const vector& default_value = {}) const; - - // Get the inputs and outputs as specific types. - template - inline const T& Input(int idx) { - static_assert( - !std::is_same::value, - "You should use Input(int, DeviceType) for " - "Tensor."); - TORCH_DCHECK_LT((size_t)idx, inputs_.size()); - try { - return inputs_.at(idx)->template Get(); - } catch (::caffe2::EnforceNotMet& enf) { - if (has_debug_def()) { - TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), "."); - } - throw enf; - } - } - - // TODO(jerryzh): Remove template - // and the type argument? - // This is to keep the API changes minimal and make refactoring - // a bit easier - template - inline const T& Input(int idx, DeviceType type) { - if (isLegacyOperator()) { - static_assert( - std::is_same::value, - "Input(int, DeviceType) is only available for Tensor"); - TORCH_DCHECK_LT((size_t)idx, inputs_.size()); - try { - // TODO(jerryzh): We'll need to check device type in Get() later - // Get() -> Get(type) - const auto& tensor = inputs_.at(idx)->template Get(); - return tensor; - } catch (::caffe2::EnforceNotMet& enf) { - if (has_debug_def()) { - TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), "."); - } - throw enf; - } - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - TORCH_DCHECK_LT(0U, newstyle_inputs_.size()); - IValue ival; - if (newstyle_inputs_[0].isTensorList()) { - // if the first input is a tensor list, we get input tensors by indexing - // into that list. currently, this means that only tensors from that list - // are accessible as inputs. any hypothetical input tensors that come - // after the list are not accessible. - auto tensorList = newstyle_inputs_[0].toTensorVector(); - TORCH_DCHECK_LT((size_t)idx, tensorList.size()); - ival = tensorList[idx]; - } else { - // if the first input is not a tensor list, we get input tensors by - // indexing into the inputs. 
- TORCH_DCHECK_LT((size_t)idx, newstyle_inputs_.size()); - ival = newstyle_inputs_[idx]; - } - CAFFE_ENFORCE( - ival.isTensor(), - "Input(int, DeviceType) is only available for IValues that store Tensors"); - auto t = ival.toTensor(); - if (!t.is_contiguous()) { - t = t.contiguous(); - } - Tensor tensor = caffe2::Tensor(std::move(t)); - CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type); - input_tensors_[idx] = std::move(tensor); - return input_tensors_[idx]; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - template - inline T* Output(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "Output(idx) not supported for operators exported to c10. Please use XOutput instead."); - - static_assert( - !std::is_same::value, - "You should use Output(int, DeviceType) for " - "Tensor."); - return outputs_.at(idx)->template GetMutable(); - } - - // TODO(jerryzh): Remove this template - template - inline T* Output(int idx, DeviceType type) { - if (isLegacyOperator()) { - static_assert( - std::is_same::value, - "Output(int, DeviceType) is only available for Tensor"); - // When you get a Tensor here it is not fully initialized - return BlobGetMutableTensor(outputs_.at(idx), type); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto &output = output_tensors_[idx]; - if (!output.defined() || output.GetDeviceType() != type) { - // Fix tensor type - output = Tensor(type); - } - return &output; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - inline Tensor - XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in option."); - if (isLegacyOperator()) { - return XBlobGetMutableTensor(outputs_.at(idx), dims, options); - } - - return OutputTensor(idx, dims, options)->UnsafeSharedInstance(); - } - - void SetOutputTensor(int idx, Tensor tensor) { - if (!isLegacyOperator()) { -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - output_tensors_[idx] = std::move(tensor); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } else { - // update the tensor in the workspace - BlobSetTensor(outputs_.at(idx), std::move(tensor)); - } - } - - Tensor OutputTensorOrUndefined(int idx) { - if (isLegacyOperator()) { - return BlobGetTensorOrUndefined(*outputs_.at(idx)); - } - return output_tensors_[idx].UnsafeSharedInstance(); - } - - inline Tensor* - OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) { - if (isLegacyOperator()) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in options."); - return BlobGetMutableTensor(outputs_.at(idx), dims, options); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto &output = output_tensors_[idx]; - output = output.defined() - ? 
GetSizedTensorWithOptions(std::move(output), dims, options) - : caffe2::empty(dims, options); - - return &output; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - - // Get output Tensor of the operator and CopyFrom the given Tensor - Tensor* OutputTensorCopyFrom( - int idx, - at::TensorOptions options, - const Tensor& src, - bool async = false) { - CAFFE_ENFORCE_WITH_CALLER( - options.device_opt() != c10::nullopt, - "device must be provided in options."); - // Ouptut Tensor will always have the same data type as `src` - if (!options.has_dtype()) { - options = options.dtype(src.dtype()); - } - CAFFE_ENFORCE_WITH_CALLER( - options.dtype() == src.dtype(), - "We don't allow change of src data type in OutputTensorCopyFrom"); - Tensor* t = OutputTensor(idx, src.sizes(), options); - t->CopyFrom(src, async); - return t; - } - - Tensor* OutputTensorAlias(int idx, const Tensor& src) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputTensorAlias(idx, src) not (yet) supported for operators exported to c10."); - return BlobSetTensor(OutputBlob(idx), src.Alias()); - } - - template - inline T* Output(int idx, T* allocated) { - CAFFE_ENFORCE( - isLegacyOperator(), - "Output(idx, allocated) not supported for operators exported to c10. Please use XOutput."); - outputs_.at(idx)->Reset(allocated); - return allocated; - } - - inline const Blob& InputBlob(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputBlob(idx) not (yet) supported for operators exported to c10."); - return *inputs_.at(idx); - } - - inline Blob* OutputBlob(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputBlob(idx) not (yet) supported for operators exported to c10."); - return outputs_.at(idx); - } - - // Check whether output j is an alias of input i by comparing Blob pointers, - // note this does not check if the two Blobs points to the same Tensor, or if - // the Tensor pointers point to the same TensorImpl, or if the Storages alias - inline bool IsInputOutputAlias(int i, int j) { - CAFFE_ENFORCE( - isLegacyOperator(), - "IsInputOutputAlias(i, j) not (yet) supported for operators exported to c10."); - return inputs_.at(i) == outputs_.at(j); - } - - template - inline bool InputIsType(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputIsType(idx) not (yet) supported for operators exported to c10."); - static_assert( - !std::is_same::value, - "You should use InputIsTensorType(int, DeviceType) for " - "Tensor."); - return inputs_.at(idx)->template IsType(); - } - - inline bool InputIsTensorType(int idx, DeviceType device_type) { - CAFFE_ENFORCE( - isLegacyOperator(), - "InputIsTensorType(idx, device_type) not (yet) supported for operators exported to c10."); - return BlobIsTensorType(*inputs_.at(idx), device_type); - } - - template - inline bool OutputIsType(int idx) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputIsType(idx) not (yet) supported for operators exported to c10."); - static_assert( - !std::is_same::value, - "You should use OutputIsTensorType(int, DeviceType) for " - "Tensor."); - return outputs_.at(idx)->template IsType(); - } - - inline bool OutputIsTensorType(int idx, DeviceType type) { - CAFFE_ENFORCE( - isLegacyOperator(), - "OutputIsTensorType(idx, type) not (yet) supported for operators exported to c10."); - return BlobIsTensorType(*outputs_.at(idx), type); - } - - inline int InputSize() const { - return input_size_; - } - - inline int OutputSize() const { - if (isLegacyOperator()) { - return outputs_.size(); - } -#if defined(EXPOSE_C2_OPS) || \ - 
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - return output_tensors_.size(); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif - } - inline const vector& Inputs() const { - CAFFE_ENFORCE( - isLegacyOperator(), - "Inputs() not supported for operators exported to c10."); - return inputs_; - } - inline const vector& Outputs() { - CAFFE_ENFORCE( - isLegacyOperator(), - "Outputs() not supported for operators exported to c10."); - return outputs_; - } - vector InputTensorShapes() const; - - virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) { - ev.Finish(); - } - - inline void Wait(const OperatorBase& other, int stream_id = -1) { - if (!other.IsEventDisabled()) { - WaitEvent(other.event(), stream_id); - } - } - - virtual void WaitEvents( - const std::vector& events, - int /*stream_id*/ = -1) { - for (const auto& ev : events) { - ev->Finish(); - } - } - - virtual void Finish() { - if (event_) { - event_->Finish(); - } - } - - virtual bool Run(int /* unused */ /*stream_id*/ = 0) { - CAFFE_NOT_IMPLEMENTED; - } - - virtual bool HasAsyncPart() const { - return false; - } - - virtual bool SupportsAsyncScheduling() const { - return false; - } - - virtual void CancelAsyncCallback() {} - - virtual void Cancel() {} - - // RunAsync, if implemented by the specific operators, will schedule the - // computation on the corresponding context and record the event in its - // event_ member object. If the specific operator does not support RunAsync, - // it will simply be synchronous as a fallback. - virtual bool RunAsync(int stream_id = 0); - - virtual void AddRelatedBlobInfo(EnforceNotMet* err); - - virtual std::string debug_info_string() const { - return ""; - } - - inline const OperatorDef& debug_def() const { - CAFFE_ENFORCE(has_debug_def(), "operator_def was null!"); - return *operator_def_; - } - - inline void set_debug_def( - const std::shared_ptr& operator_def) { - operator_def_ = operator_def; - } - - inline bool has_debug_def() const { - return operator_def_ != nullptr; - } - - public: - void RecordLastFailedOpNetPosition() { - if (net_position_ != kNoNetPositionSet) { - VLOG(1) << "Operator with id " << net_position_ << " failed"; - operator_ws_->last_failed_op_net_position = net_position_; - } else { - VLOG(1) << "Failed operator doesn't have id set"; - } - } - - int net_position() const { - return net_position_; - } - - void set_net_position(int idx) { - net_position_ = idx; - } - - const DeviceOption& device_option() const { - return device_option_; - } - - const Event& event() const { - CAFFE_ENFORCE(event_, "Event is disabled"); - return *event_; - } - - Event& event() { - CAFFE_ENFORCE(event_, "Event is disabled"); - return *event_; - } - - void ResetEvent() { - if (event_) { - event_->Reset(); - } - } - - void DisableEvent() { - event_ = nullptr; - } - - bool IsEventDisabled() const { - return !event_; - } - - // Internal API invoked by observers. Normal callers shouldn't invoke it. - virtual void SyncDeviceBarrierForObservers() { - CAFFE_NOT_IMPLEMENTED; - } - - // Checks whether stream is ready to execute new computation, - // used in stream allocation optimization to skip stream that is currently - // busy. 
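// Illustrative sketch, not part of this diff: the RunAsync contract above
// (record an event, then mark it finished or failed when the async part
// completes) is loosely analogous to completing a std::promise with either a
// value or an exception. A standalone analogy, not the caffe2 Event API:
#include <future>
#include <thread>

int main() {
  std::promise<void> done;                 // plays the role of the op's event
  std::future<void> waiter = done.get_future();

  std::thread async_part([&done] {
    try {
      // ... asynchronous device work would happen here ...
      done.set_value();                    // ~ Event::SetFinished()
    } catch (...) {
      done.set_exception(std::current_exception());  // ~ SetFinishedWithException()
    }
  });

  waiter.get();                            // ~ event().Finish(); rethrows failures
  async_part.join();
  return 0;
}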
Depends on context and operator's device, returns true by default - virtual bool IsStreamFree(int /* unused */) const { - return true; - } - - const std::string& type() const { - return type_; - } - - void annotate_engine(const std::string& engine) { - engine_ = engine; - } - - const std::string& engine() const { - return engine_; - } - - void SetExecutorHelper(ExecutorHelper* helper) { - helper_ = helper; - } - - ExecutorHelper* GetExecutorHelper() const { - return helper_; - } - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - std::vector move_output_tensors() && { - return std::move(output_tensors_); - } -#endif - - public: - static const int kNoNetPositionSet = -1; - - private: - Workspace* operator_ws_; - std::shared_ptr operator_def_; - DeviceOption device_option_; - std::string engine_; - std::string type_; - vector inputs_; - vector outputs_; - // Preferably use std::optional, but nvcc doesn't work -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - std::unique_ptr fn_schema_; - vector newstyle_inputs_; -#endif - // HACK - // We preserve the fact that Output() returns Tensor* - // by storing Tensor in a vector owned by the - // operator. - vector input_tensors_; - vector output_tensors_; - - int input_size_; - - int net_position_{kNoNetPositionSet}; - - ExecutorHelper* helper_ = nullptr; - - protected: - virtual void RecordEvent(const char* /*err_msg*/ = nullptr) { - CAFFE_NOT_IMPLEMENTED; - } - - void SetEventFinished(const char* err_msg = nullptr) { - if (event_) { - event_->SetFinished(err_msg); - } - } - - void SetEventFinishedWithException(const char* err_msg = nullptr) { - if (event_) { - event_->SetFinishedWithException(err_msg); - } - } - - std::string getErrorMsg() { - if (has_debug_def()) { - return "Error from operator: " + ProtoDebugString(debug_def()); - } else { - return "Error from operator: no op def"; - } - } - - std::optional argumentIndexWithName(c10::string_view name) const; - - // An event used by asynchronous execution. - std::unique_ptr event_; - - C10_DISABLE_COPY_AND_ASSIGN(OperatorBase); -}; - -template <> -inline NetDef OperatorBase::GetSingleArgument( - c10::string_view name, - const NetDef& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetSingleArgument( - *operator_def_, name, default_value); - } - CAFFE_THROW("Cannot get NetDefs from IValue"); - return NetDef(); -} - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto vs = value.toIntVector(); - vector out; - out.reserve(vs.size()); - for (int64_t v : vs) { - out.emplace_back(v); - } - return out; -} - -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - const auto& vs = value.toDoubleVector(); - vector out; - out.reserve(vs.size()); - for (double v : vs) { - out.emplace_back(v); - } - return out; -} - -template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto vs = value.template to>(); - vector out; - out.reserve(vs.size()); - for (string v : vs) { - out.emplace_back(v); - } - return out; -} - -// We need this specialisation because IValue based lists don't support -// int16_t. We need to load it as List and transform to int16_t. 
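// Illustrative sketch, not part of this diff: the int16_t specialisation that
// follows only has to narrow the 64-bit integers an IValue list can hold down
// to int16_t, element by element, e.g.:
#include <cstdint>
#include <vector>

std::vector<int16_t> NarrowToInt16(const std::vector<int64_t>& src) {
  std::vector<int16_t> out;
  out.reserve(src.size());
  for (int64_t v : src) {
    out.push_back(static_cast<int16_t>(v));  // out-of-range values are not checked
  }
  return out;
}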
-template <> -inline vector OperatorBase::GetVectorFromIValueList( - const c10::IValue& value) const { - auto list = value.template to>(); - std::vector result; - result.reserve(list.size()); - for (int64_t elem : list) { - result.push_back(static_cast(elem)); - } - return result; -} -#endif - -// OP_SINGLE_ARG provides a shorter initialization choice for initialization of -// member variables for the class constructors. -#define OP_SINGLE_ARG(type, name, variable, default) \ - variable(OperatorBase::GetSingleArgument(name, (default))) - -// INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the -// operator's inputs and outputs, in order to avoid confusion. For example, for -// a fully convolution layer that has input, weight and bias, you can define its -// input tags as: -// INPUT_TAGS(INPUT, WEIGHT, BIAS); -// And in the code, instead of doing -// auto& weight = Input(1); -// you can now do -// auto& weight = Input(WEIGHT); -// to make it more clear. -#define INPUT_TAGS(first_input, ...) \ - enum _InputTags { first_input = 0, __VA_ARGS__ } -#define OUTPUT_TAGS(first_input, ...) \ - enum _OutputTags { first_input = 0, __VA_ARGS__ } - -template -inline vector OperatorBase::GetRepeatedArgument( - c10::string_view name, - const vector& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetRepeatedArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - return GetVectorFromIValueList(value); -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif -} - -// We need this specialisation because IValue based lists don't support -// int16_t. We need to load it as List and transform to int16_t. -template <> -inline vector OperatorBase::GetRepeatedArgument( - c10::string_view name, - const vector& default_value) const { - if (isLegacyOperator()) { - CAFFE_ENFORCE(operator_def_, "operator_def was null!"); - return ArgumentHelper::GetRepeatedArgument( - *operator_def_, name, default_value); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - auto index = argumentIndexWithName(name); - CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); - const auto& value = newstyle_inputs_[index.value()]; - auto vec = GetVectorFromIValueList(value); - std::vector result; - result.reserve(vec.size()); - for (int64_t elem : vec) { - result.push_back(static_cast(elem)); - } - return result; -#else - CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); -#endif -} - -// Operator is the class that you usually want to derive, if your operator will -// run on different devices. You should then implement the RunOnDevice() -// function. -template -class Operator : public OperatorBase { - public: - explicit Operator(const OperatorDef& operator_def, Workspace* ws, StreamId stream = 0) - : OperatorBase(operator_def, ws), context_(operator_def.device_option()) { - // In the constructor, we switch to the device so that the child class - // constructors will run on that device. 
- context_.SwitchToDevice(stream); - } -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit Operator( - const c10::FunctionSchema& fn_schema, - std::vector inputs, - std::vector outputs, - StreamId stream = 0) - : OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) { - // In the constructor, we switch to the device so that the child class - // constructors will run on that device. - context_.SwitchToDevice(stream); - } -#endif - ~Operator() noexcept override {} - - /// Retrieve a non-owning reference to the input at position 'idx' for this - /// operator. The returned reference is valid for the duration of the - /// RunOnDevice call. The optional 'type' parameter can be used to assert a - /// required device type for the input (by default, we assert that the tensor - /// is consistent with the device type implied by the Context parameter of an - /// Operator.) - inline const Tensor& Input( - int idx, - DeviceType type = Context::GetDeviceType()) { - return OperatorBase::template Input(idx, type); - } - - /// XOutput is a modernized version of Output which returns a Tensor - /// rather than a Tensor* (the raw pointer in the latter case is - /// useless, as Tensor is a pointer type.) - Tensor XOutput(int idx, at::IntArrayRef dims, at::TensorOptions options) { - // We'll default device to the device of the current Operator Context - if (options.device_opt() == c10::nullopt) { - return OperatorBase::XOutputTensor( - idx, dims, options.device(context_.device())); - } - return OperatorBase::XOutputTensor(idx, dims, options); - } - - /// Retrieve a non-owning pointer to the output at position 'idx', - /// initializing it to have size 'dims' and properties 'options' if - /// there is no pre-existing output or the pre-existing output does - /// not have the correct options. The returned pointer is valid for - /// the duration of the RunOnDevice call. If device is not explicitly - /// specified in options, we default to allocating output on the - /// current device of the device type implied by the Context parameter - /// of this Operator. - /// - /// Note [Operator::Output what?] - /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - /// The contract of Operator::Output is somewhat complex; it is perhaps better - /// understood in terms of what was historically an idiomatic Caffe2 operator - /// implementation: - /// - /// void RunOnDevice() override { - /// auto* output = Output(0, output_size, dtype()); - /// float* output_ptr = output->data(); - /// // write into output_ptr - /// } - /// - /// In the simple case, this code does the following things: - /// - /// 1. Allocates a new tensor with size 'output_size' and dtype 'float' - /// (and device type whatever the Operator's device type is) - /// 2. "Registers" this tensor as the 0th output tensor of this operator - /// (Caffe2 operators don't "return" outputs; instead, outputs - /// are shoved into an output vector which the executor reads out.) - /// 3. Returns the tensor, so the operator implementation can write - /// the actual output data into the tensor. - /// - /// So what's this business with "pre-existing" outputs? Caffe2 - /// commonly applies an optimization whereby it reuses tensors on - /// subsequent runs of operators in a graph. It doesn't know ahead - /// of time what intermediate tensors it will need, so the first - /// time it runs a graph it has all of the operators create the outputs - /// necessary (as described above). 
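// Illustrative sketch, not part of this diff: the reuse-or-resize contract
// this note describes (keep the previously allocated output when it already
// has the right size, reallocate only when the requested shape changed),
// expressed over a plain buffer; the helper name is invented for illustration:
#include <cstddef>
#include <vector>

// Returns a pointer to exactly n floats, reusing `storage` when its size
// already matches and reallocating only when it does not.
float* GetSizedOutput(std::vector<float>& storage, std::size_t n) {
  if (storage.size() != n) {
    storage.assign(n, 0.0f);  // reallocate/resize only on shape change
  }
  return storage.data();      // the caller writes its results through this
}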
However, the second time around, - /// it will reuse all of the tensors created from the first time. - /// If they are lucky, this time the Output() call is a no-op and - /// just returns the old tensor. - /// - /// However, we cannot /guarantee/ that the output size will be the - /// same the next time the Operator is called; for example, output - /// size may be data dependent and vary between runs. In this case, - /// we have to resize it to the correct size. Resizing is still - /// helpful, as we may be able to fit the output in the same - /// space that was previously used. - /// - Tensor* Output(int idx, at::IntArrayRef dims, at::TensorOptions options) { - // We'll default device to the device of the current Operator Context - if (options.device_opt() == c10::nullopt) { - return OperatorBase::OutputTensor( - idx, dims, options.device(context_.device())); - } - return OperatorBase::OutputTensor(idx, dims, options); - } - - /// Legacy: please consider using the version of Output() which also takes - /// dtype and size as arguments. - inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) { - return OperatorBase::template Output(idx, type); - } - - /// Get the output Tensor of an operator (allocating it if it is not - /// already initialized), and copy the contents of src into it. - /// You probably don't actually want to use this function (the fact - /// that you have a Tensor to copy from is probably a mistake: - /// you should have written the output into the output tensor, - /// from Output, directly in the first place), but this method - /// is situationally useful. - Tensor* OutputTensorCopyFrom( - int idx, - at::TensorOptions options, - const Tensor& src, - bool async = false) { - if (options.device_opt() == c10::nullopt) { - return OperatorBase::OutputTensorCopyFrom( - idx, options.device(context_.device()), src, async); - } - return OperatorBase::OutputTensorCopyFrom(idx, options, src, async); - } - - void WaitEvent(const Event& ev, int stream_id = -1) final { - if (stream_id >= 0) { - context_.SwitchToDevice(stream_id); - } - context_.WaitEvent(ev); - } - - void WaitEvents(const std::vector& events, int stream_id = -1) - final { - if (stream_id >= 0) { - context_.SwitchToDevice(stream_id); - } - for (const auto& ev : events) { - context_.WaitEvent(*ev); - } - } - - // The run function of Operator switches to the device, and then carries out - // the actual computation with RunOnDevice(). You should implement RunOnDevice - // instead of Run(). - // Note: Run does not update operator's event and can be used only with - // non-async executors that do not rely on events - bool Run(int stream_id = 0) final { - try { - StartAllObservers(); - - context_.SwitchToDevice(stream_id); - - // Clear floating point exception flags before RunOnDevice. We will test - // exception flags afterwards, and raise an error if an exception has - // happened. - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - std::feclearexcept(FE_ALL_EXCEPT); - } - -#ifdef __GNU_LIBRARY__ - // If glibc is available, use feenableexcept that will raise exception - // right away. 
- int old_enabled_exceptions = 0; - if (FLAGS_caffe2_operator_throw_on_first_occurrence_if_fp_exceptions) { - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - int flag = 0; - if (FLAGS_caffe2_operator_throw_if_fp_exceptions) { - flag |= FE_DIVBYZERO | FE_INVALID; - } - if (FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - flag |= FE_OVERFLOW; - } - old_enabled_exceptions = feenableexcept(flag); - } - } -#endif - bool result = RunOnDevice(); -#ifdef __GNU_LIBRARY__ - if (FLAGS_caffe2_operator_throw_on_first_occurrence_if_fp_exceptions) { - if (FLAGS_caffe2_operator_throw_if_fp_exceptions || - FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - fedisableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); - std::feclearexcept(FE_ALL_EXCEPT); - feenableexcept(old_enabled_exceptions); - } - } -#endif - if (FLAGS_caffe2_operator_throw_if_fp_exceptions) { - CAFFE_ENFORCE( - !std::fetestexcept(FE_DIVBYZERO), - "Division by zero floating point exception (FE_DIVBYZERO) reported."); - CAFFE_ENFORCE( - !std::fetestexcept(FE_INVALID), - "Invalid floating point exception (FE_INVALID) reported."); - } - if (FLAGS_caffe2_operator_throw_if_fp_overflow_exceptions) { - CAFFE_ENFORCE( - !std::fetestexcept(FE_OVERFLOW), - "Overflow floating point exception (FE_OVERFLOW) reported."); - } - if (!result) { - this->RecordLastFailedOpNetPosition(); - } - context_.FinishDeviceComputation(); // throws on error - - StopAllObservers(); - - return result; - } catch (EnforceNotMet& err) { - if (has_debug_def()) { - err.add_context( - "Error from operator: \n" + ProtoDebugString(debug_def())); - AddRelatedBlobInfo(&err); - } - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (...) { - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } - } - - bool RunAsync(int stream_id = 0) final { - try { - StartAllObservers(); - - context_.SwitchToDevice(stream_id); - auto result = RunOnDevice(); - if (result) { - if (HasAsyncPart()) { - RecordEvent(); - } else { - // Manually set CPU operator's event status to finished, - // unless this is an async CPU operator - SetEventFinished(); - } - } else { - SetEventFinished(getErrorMsg().c_str()); - this->RecordLastFailedOpNetPosition(); - } - - StopAllObservers(); - - return result; - } catch (EnforceNotMet& err) { - if (has_debug_def()) { - err.add_context( - "Error from operator: \n" + ProtoDebugString(debug_def())); - AddRelatedBlobInfo(&err); - } - SetEventFinishedWithException(err.what()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (const std::exception& err) { - SetEventFinishedWithException(err.what()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } catch (...) { - SetEventFinishedWithException(getErrorMsg().c_str()); - this->RecordLastFailedOpNetPosition(); - StopAllObservers(); - throw; - } - } - - bool IsStreamFree(int stream_id) const override { - return context_.IsStreamFree(device_option(), stream_id); - } - - virtual bool RunOnDevice() = 0; - - // Returns whether operator has async on device part. - // CUDA operators by default have async parts, CPU operators by default - // don't have async parts and are finished after RunOnDevice call. - // Events of operators that don't have async parts are automatically set - // to finished state by RunAsync. - // Defaulting to the value from context (true for CUDA, false for CPU). 
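// Illustrative sketch, not part of this diff: Run() above clears the
// floating-point exception flags before RunOnDevice() and tests them
// afterwards. The same guard in portable form, using only standard <cfenv>
// (feenableexcept is a glibc extension and is left out here):
#include <cfenv>
#include <stdexcept>

template <typename Fn>
void RunWithFpChecks(Fn&& fn) {
  std::feclearexcept(FE_ALL_EXCEPT);  // start from a clean flag state
  fn();                               // the RunOnDevice() equivalent
  if (std::fetestexcept(FE_DIVBYZERO)) {
    throw std::runtime_error("FE_DIVBYZERO reported");
  }
  if (std::fetestexcept(FE_INVALID)) {
    throw std::runtime_error("FE_INVALID reported");
  }
  if (std::fetestexcept(FE_OVERFLOW)) {
    throw std::runtime_error("FE_OVERFLOW reported");
  }
}
// Strictly conforming code that inspects these flags would also enable
// '#pragma STDC FENV_ACCESS ON' in the translation unit.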
- // Override in case of async CPU operators - // Async CPU operators are expected to catch all exceptions in async parts - // and set Event to finished/failed state with Event::SetFinished or - // SetFinishedWithException call. - bool HasAsyncPart() const override { - return context_.HasAsyncPartDefault(); - } - - // Returns whether operator's RunOnDevice schedules async on device part and - // can be run without waiting for parent operator's async part to be finished - // on the same device. - // Note: when true, RunOnDevice must not access the content of the input blobs - // as they might not be computed yet - // Note: when true, operator's device needs to support async scheduling: - // - supports concept of streams: async ops scheduled on the same stream are - // guaranteed to be executed in the same order they were scheduled - // - provides non-blocking cross device/cross stream synchronization - // primitives - // - // By default, assuming an op with an async part can be scheduled - // asynchronously if device supports async scheduling - bool SupportsAsyncScheduling() const override { - return HasAsyncPart() && context_.SupportsAsyncScheduling(); - } - - void SyncDeviceBarrierForObservers() override { - context_.FinishDeviceComputation(); - } - - const Context* getContext() const { - return &context_; - } - Context* getContext() { - return &context_; - } - - protected: - void RecordEvent(const char* err_msg = nullptr) final { - if (event_) { - context_.Record(event_.get(), err_msg); - } - } - - Context context_; -}; - -#define USE_OPERATOR_BASE_FUNCTIONS \ - /* using override */ using OperatorBase::HasArgument; \ - /* using override */ using OperatorBase::GetSingleArgument; \ - /* using override */ using OperatorBase::HasSingleArgumentOfType; \ - /* using override */ using OperatorBase::GetRepeatedArgument; \ - /* using override */ using OperatorBase::InputIsType; \ - /* using override */ using OperatorBase::InputSize; \ - /* using override */ using OperatorBase::Output; \ - /* using override */ using OperatorBase::Input; \ - /* using override */ using OperatorBase::OutputSize; \ - /* using override */ using OperatorBase::IsInputOutputAlias; \ - /* using override */ using OperatorBase::OutputTensorAlias - -#define USE_OPERATOR_FUNCTIONS(context) \ - USE_OPERATOR_BASE_FUNCTIONS; \ - /* using override */ using Operator::context_; \ - /* using override */ using Operator::Input; \ - /* using override */ using Operator::InputBlob; \ - /* using override */ using Operator::Output; \ - /* using override */ using Operator::OutputBlob; \ - /* using override */ using Operator::OutputTensorCopyFrom - -#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context) - -#define USE_SIMPLE_CTOR_DTOR(name) \ - template \ - explicit name(Args&&... args) \ - : Operator(std::forward(args)...) {} \ - virtual ~name() noexcept override {} - -// Helpers to implement runtime op polymorphism. Often it's convenient to make -// an op work on different input types (e.g. i32 vs i64 indices) or special-case -// it for particular input size (e.g. ScatterWeightedSum for block size of 1 -// doesn't need to call Eigen). -// -// DispatchHelper provides compile-time generation of nested "if" statements, -// e.g. 
`DispatchHelper>::call(this, block_size);` -// unrolls into: -// if (block_size == 1) { -// return DoRunWithValue<1>(); -// } else if (block_size = 4) { -// return DoRunWithValue<4>(); -// } else { -// return DoRunWithValue<-1>(); -// }` -// -// DoRunWithValue implementation can use template arguments to do "if" -// statements -// or proxy to functions in math.h which often provide fixed size -// implementation. -// -// Similarly `TensorTypes(this, Input(0))` provides branching -// based on type of the first input and calls DoRunWithType. -// -// Note, that the same instance of Op class is used as the method, not class is -// templated. We might consider adding static class-level polymorphism later. -// -// Convenient macro USE_DISPATCH_HELPER is provided for declaring friendship in -// case DoRunWithValue or DoRunWithType are declared non-public. - -#define USE_DISPATCH_HELPER \ - template \ - friend struct DispatchHelper - -template -struct FixedValues {}; - -template -struct TensorTypes {}; - -// Special tag that can be listed in TensorTypes to denote that a special -// implementation in 'RunWithOtherType' needs to be called instead of failing -// Obviously this needs to be the last item in lists, e.g. -// TensorTypes -struct GenericTensorImplementation {}; - -// Same as TensorTypes but call DoRunWithType2 -template -struct TensorTypes2 {}; - -template -struct DispatchHelper; - -template -struct DispatchHelper, ExtraArgs...> { - template - static bool call(Op* op, int value) { - if (FirstVal == value) { - return op->template DoRunWithValue(); - } - return DispatchHelper, ExtraArgs...>::template call< - Op>(op, value); - } -}; - -template -struct DispatchHelper, ExtraArgs...> { - template - static bool call(Op* op, int64_t /*size*/) { - return op->template DoRunWithValue(); - } -}; - -#define C10_DEFINE_TENSOR_TYPES_DISPATCHER( \ - TensorTypes, DoRunWithType, DoRunWithOtherType) \ - template \ - struct DispatchHelper, ExtraArgs...> { \ - template \ - static bool call(Op* op, const TypeMeta meta) { \ - static_assert( \ - !std::is_same::value, \ - "GenericTensorImplementation must be the last in TensorTypes list"); \ - if (meta.Match()) { \ - return op->template DoRunWithType(); \ - } \ - return DispatchHelper, ExtraArgs...>:: \ - template call(op, meta); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; \ - \ - template \ - struct DispatchHelper, ExtraArgs...> { \ - template \ - static bool call(Op* /* unused */, const TypeMeta meta) { \ - CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; \ - \ - template \ - struct DispatchHelper< \ - TensorTypes, \ - ExtraArgs...> { \ - template \ - static bool call(Op* op, const TypeMeta) { \ - return op->template DoRunWithOtherType(); \ - } \ - template \ - static bool call(Op* op, const Tensor& tensor) { \ - return call(op, tensor.dtype()); \ - } \ - template \ - static bool call(Op* op, const Blob& blob) { \ - return call(op, blob.meta()); \ - } \ - }; -C10_DEFINE_TENSOR_TYPES_DISPATCHER( - TensorTypes, - DoRunWithType, - DoRunWithOtherType) -C10_DEFINE_TENSOR_TYPES_DISPATCHER( - TensorTypes2, - DoRunWithType2, - DoRunWithOtherType2) -#undef 
C10_DEFINE_TENSOR_TYPES_DISPATCHER - -// The device type registry. This works in two phases: -// (1) gDeviceTypeRegistry() maps the device types values to the actual operator -// registry function. -// (2) Then, one can call the operator registry function to further create the -// operators. -typedef c10::Registry< - std::string, - std::unique_ptr, - const OperatorDef&, - Workspace*> - OperatorRegistry; -typedef c10::Registry< - std::string, - std::unique_ptr, - const OperatorDef&, - Workspace*>* (*RegistryFunction)(); -TORCH_API std::map* gDeviceTypeRegistry(); - -struct TORCH_API DeviceTypeRegisterer { - explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func); -}; - -#if defined(_MSC_VER) -#define IMPORT_IF_NOT_MSVC -#else -#define IMPORT_IF_NOT_MSVC C10_IMPORT -#endif - -#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \ - namespace { \ - static DeviceTypeRegisterer C10_ANONYMOUS_VARIABLE( \ - DeviceType)(type, ®istry_function); \ - } - -// The operator registry. Since we are not expecting a great number of devices, -// we will simply have an if-then type command and allocate the actual -// generation to device-specific registerers. -// Note that although we have CUDA and CUDNN here, the registerers themselves do -// not depend on specific cuda or cudnn libraries. This means that we will be -// able to compile it even when there is no cuda available - we simply do not -// link any cuda or cudnn operators. -C10_DECLARE_REGISTRY( - CPUOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_CPU_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -// Use these macros to register gradient operators. They can be automatically -// excluded from builds that don't need them (e.g., mobile). -#ifdef CAFFE2_NO_GRADIENT_OPS -#define REGISTER_CPU_GRADIENT_OPERATOR(...) /* No gradients. */ -#else -#define REGISTER_CPU_GRADIENT_OPERATOR(...) \ - C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__)) -#endif - -#ifdef CAFFE2_NO_GRADIENT_OPS -#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) /* No gradients. */ -#else -#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) \ - C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR_WITH_ENGINE(__VA_ARGS__)) -#endif - -C10_DECLARE_REGISTRY( - CUDAOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_CUDA_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_CUDA_OPERATOR_STR(str_name, ...) 
\ - C10_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -// Macros for cudnn since we use it often -#define REGISTER_CUDNN_OPERATOR(name, ...) \ - REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) - -// Macros for HIP operators -C10_DECLARE_REGISTRY( - HIPOperatorRegistry, - OperatorBase, - const OperatorDef&, - Workspace*); -#define REGISTER_HIP_OPERATOR_CREATOR(key, ...) \ - C10_REGISTER_CREATOR(HIPOperatorRegistry, key, __VA_ARGS__) -#define REGISTER_HIP_OPERATOR(name, ...) \ - IMPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_HIP##name() { \ - CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ - } \ - C10_REGISTER_CLASS(HIPOperatorRegistry, name, __VA_ARGS__) -#define REGISTER_HIP_OPERATOR_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(HIPOperatorRegistry, str_name, __VA_ARGS__) - -#define REGISTER_HIP_OPERATOR_WITH_ENGINE(name, engine, ...) \ - C10_REGISTER_CLASS(HIPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) - -#define REGISTER_MIOPEN_OPERATOR(name, ...) \ - REGISTER_HIP_OPERATOR_WITH_ENGINE(name, MIOPEN, __VA_ARGS__) \ - REGISTER_HIP_OPERATOR_WITH_ENGINE( \ - name, CUDNN, __VA_ARGS__) // Make CUDNN an alias of MIOPEN for HIP ops - -// StaticLinkingProtector is a helper class that ensures that the Caffe2 -// library is linked correctly with whole archives (in the case of static -// linking). What happens is that when CreateOperator is called for the first -// time, it instantiates an OperatorLinkingProtector object to check if the -// operator registry is empty. If it is empty, this means that we are not -// properly linking the library. -// -// You should not need to use this class. -struct StaticLinkingProtector { - StaticLinkingProtector() { - const auto registered_ops = CPUOperatorRegistry()->Keys().size(); - // Note: this is a check failure instead of an exception, because if - // the linking is wrong, Caffe2 won't be able to run properly anyway, - // so it's better to fail loud. - // If Caffe2 is properly linked with whole archive, there should be more - // than zero registered ops. - if (registered_ops == 0) { - LOG(FATAL) - << "You might have made a build error: the Caffe2 library does not seem " - "to be linked with whole-static library option. To do so, use " - "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the " - "Caffe2 library."; - } - } -}; - -// An exception that can be thrown by an operator constructor that notifies -// that it does not support the given setting. This can be usually used for -// specific engines that only implement a subset of the features required by -// the original operator schema. -// TODO(jiayq): make more feature-complete exception message. -class TORCH_API UnsupportedOperatorFeature : public std::exception { - public: - UnsupportedOperatorFeature(const string& msg) : msg_(msg) {} - const char* what() const noexcept override { - return msg_.c_str(); - } - - private: - string msg_; -}; - -// A helper macro that should ONLY be used in the operator constructor to check -// if needed features are met. If not, throws the UnsupportedOperatorFeature -// exception with the given message. -#define OPERATOR_NEEDS_FEATURE(condition, ...) 
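A sketch of how an engine-specific implementation typically gates itself with OPERATOR_NEEDS_FEATURE in its constructor and registers under an engine name (names hypothetical). When the feature check fails, the resulting UnsupportedOperatorFeature makes operator creation fall back to the default registration, as the engine-fallback tests later in this diff demonstrate.

#include "caffe2/core/operator.h"

namespace caffe2 {

// Hypothetical engine-specific variant that only supports NCHW layout.
class NoOpSketchFastEngine final : public Operator<CPUContext> {
 public:
  NoOpSketchFastEngine(const OperatorDef& def, Workspace* ws)
      : Operator<CPUContext>(def, ws) {
    // Throws UnsupportedOperatorFeature if the condition is false.
    OPERATOR_NEEDS_FEATURE(
        OperatorBase::GetSingleArgument<std::string>("order", "NCHW") == "NCHW",
        "FAST engine only supports NCHW order.");
  }
  bool RunOnDevice() override {
    return true;
  }
};

// Selected when an OperatorDef sets engine("FAST") for NoOpSketch.
REGISTER_CPU_OPERATOR_WITH_ENGINE(NoOpSketch, FAST, NoOpSketchFastEngine);

} // namespace caffe2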
\ - if (!(condition)) { \ - throw UnsupportedOperatorFeature(::c10::str(__VA_ARGS__)); \ - } - -// Creates an operator with the given operator definition. -// Throws on error and never returns nullptr -TORCH_API unique_ptr CreateOperator( - const OperatorDef& operator_def, - Workspace* ws, - int net_position = OperatorBase::kNoNetPositionSet); - -TORCH_API const std::string OpRegistryKey( - const std::string& op_type, - const std::string& engine = ""); - -// User can set the preferred engines as a list of engine names, in -// descending order of preference. -using EnginePrefType = std::vector; -// {device_type -> {operator_name -> EnginePrefType}} -using PerOpEnginePrefType = - CaffeMap>; -// {device_type -> EnginePrefType} -using GlobalEnginePrefType = CaffeMap; -TORCH_API void SetPerOpEnginePref( - const PerOpEnginePrefType& per_op_engine_pref); -TORCH_API void SetGlobalEnginePref( - const GlobalEnginePrefType& global_engine_pref); -TORCH_API void SetEnginePref( - const PerOpEnginePrefType& per_op_engine_pref, - const GlobalEnginePrefType& global_engine_pref); -TORCH_API void SetOpEnginePref( - const std::string& op_type, - const CaffeMap& op_pref); - -TORCH_API void LoadInt8TensorInfoOfBlob( - std::vector* scale, - std::vector* offset, - uint32_t* axis, - const Blob* b); - -TORCH_API TensorShape GetTensorShapeOfBlob(const Blob* b); - -TORCH_API TensorShapes InferBlobShapesAndTypes( - CaffeMap& blob_desc, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromWorkspace( - Workspace* ws, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( - const CaffeMap>& blob_dimensions, - const vector& nets); - -TORCH_API TensorShapes InferBlobShapesAndTypesFromMap( - const CaffeMap>& blob_dimensions, - const CaffeMap& blob_types, - const vector& nets); - -TORCH_API std::map> -ValidateTensorDevices(OperatorBase& op, const OperatorDef& op_def); - -// Get a set of registered operator names -TORCH_API std::set GetRegisteredOperators(); - -// Operator logging capabilities -TORCH_API void SetOperatorLogger( - std::function tracer); -std::function GetOperatorLogger(); - -#ifndef C10_MOBILE -// This is for transferring tensor data between C2 and backends. -struct ExternalTensorDescriptor { - uint64_t dataType; - uint32_t dimensions; - const uint64_t* shape; - uint8_t isOffline = 0; - uint32_t quantizationAxis; - uint64_t quantizationParams; - const float* scales; - const int32_t* biases; - uint64_t buffer; -}; - -class ExternalTensorFunctionsBase { - public: - explicit ExternalTensorFunctionsBase() {} - virtual ~ExternalTensorFunctionsBase() {} - virtual bool isQuantized() const = 0; - virtual bool IsSameMetaType(TypeIdentifier id) = 0; - virtual void SetupExternalTensorDescriptor( - const Blob* blob, - std::vector>* shapes, - std::vector>* all_scales, - std::vector>* all_offsets, - ExternalTensorDescriptor* desc) = 0; - virtual void LoadInfoOfBlob( - const Blob* blob, - std::vector* scale, - std::vector* offset, - uint32_t* axis) = 0; - virtual TypeIdentifier GetTypeMetaId() = 0; - virtual TypeMeta GetExternalTensorType(const void* c) = 0; - virtual vector GetExternalTensorInfo( - const void* c, - size_t* capacity, - DeviceOption* device) = 0; -}; - -C10_DECLARE_TYPED_REGISTRY( - ExternalTensorFunctionsBaseRegistry, - TypeIdentifier, - ExternalTensorFunctionsBase, - std::unique_ptr); - -#define REGISTER_EXTERNAL_TENSOR_FUNCTIONS(id, ...) 
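A small usage sketch of the creation and engine-preference API declared above (blob and op names hypothetical); the operator_test.cc hunks later in this diff exercise the same calls.

#include "caffe2/core/operator.h"

namespace caffe2 {

void EnginePrefSketch() {
  Workspace ws;
  OperatorDef def;
  def.set_type("NoOpSketch");

  // Prefer the FAST engine for every CPU op, then let the per-op entry
  // take precedence for this particular operator type.
  SetGlobalEnginePref({{CPU, {"FAST"}}});
  SetPerOpEnginePref({{CPU, {{"NoOpSketch", {"FAST"}}}}});

  // Throws on error and never returns nullptr.
  std::unique_ptr<OperatorBase> op = CreateOperator(def, &ws);

  // Clear the preferences so other code is unaffected.
  SetPerOpEnginePref({});
  SetGlobalEnginePref({});
}

} // namespace caffe2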
\ - C10_REGISTER_TYPED_CLASS(ExternalTensorFunctionsBaseRegistry, id, __VA_ARGS__) -inline unique_ptr CreateExternalTensorFunctions( - TypeIdentifier id) { - return ExternalTensorFunctionsBaseRegistry()->Create(id); -} -#endif // C10_MOBILE - -} // namespace caffe2 - -C10_CLANG_DIAGNOSTIC_POP() - -#endif // CAFFE2_CORE_OPERATOR_H_ diff --git a/caffe2/core/operator_gpu_test.cc b/caffe2/core/operator_gpu_test.cc deleted file mode 100644 index 80c58b7a3c75..000000000000 --- a/caffe2/core/operator_gpu_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include - -#include -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -class JustTest : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual std::string type() { - return "BASE"; - } -}; - -class JustTestCUDA : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - std::string type() override { - return "CUDA"; - } -}; - -class JustTestCUDNN : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - std::string type() override { - return "CUDNN"; - } -}; - -OPERATOR_SCHEMA(JustTest).NumInputs(0, 1).NumOutputs(0, 1); -REGISTER_CUDA_OPERATOR(JustTest, JustTestCUDA); -REGISTER_CUDNN_OPERATOR(JustTest, JustTestCUDNN); - -TEST(EnginePrefTest, GPUDeviceDefaultPreferredEngines) { - if (!HasCudaGPU()) - return; - OperatorDef op_def; - Workspace ws; - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - op_def.set_type("JustTest"); - - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // CUDNN should be taken as it's in the default global preferred engines - // list - EXPECT_EQ(static_cast(op.get())->type(), "CUDNN"); - } -} - -} // namespace caffe2 diff --git a/caffe2/core/operator_gradient.h b/caffe2/core/operator_gradient.h deleted file mode 100644 index 5c8d97a38fd2..000000000000 --- a/caffe2/core/operator_gradient.h +++ /dev/null @@ -1,337 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_ -#define CAFFE2_CORE_OPERATOR_GRADIENT_H_ - -#include "c10/util/Registry.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/proto_utils.h" - -namespace caffe2 { - -/* @brief A struct that abstracts on top of dense and sparse blobs. - * - * For a dense blob, its gradient name should be written into dense_, and for - * a sparse blob, its gradient name should be written into indice_ for - * the sparse indices and value_ for the values. - */ -struct TORCH_API GradientWrapper { - string dense_; - string indices_; - string values_; - - inline bool IsDense() const { - return (dense_.size() != 0); - } - inline bool IsSparse() const { - return (indices_.size() != 0 || values_.size() != 0); - } - inline bool IsEmpty() const { - return (!IsDense() && !IsSparse()); - } -}; - -/** - * A struct that holds the gradient operators and related gradient maps. 
- */ -struct TORCH_API GradientOpsMeta { - vector ops_; - vector g_input_; - - GradientOpsMeta() {} - GradientOpsMeta( - const vector& ops, - const vector& v) - : ops_(ops), g_input_(v) {} -}; - -class TORCH_API GradientMakerBase { - public: - GradientMakerBase( - const OperatorDef& def, - const vector& g_output) - : def_(def), g_output_(g_output), g_input_(def.input_size()){}; - virtual ~GradientMakerBase() {} - virtual bool CopyDeviceOption() const { - return true; - } - virtual bool CopyEngine() const { - return true; - } - virtual bool CopyArguments() const { - return true; - } - - virtual void VerifyOp() const { - auto* schema = OpSchemaRegistry::Schema(def_.type()); - if (schema) { - CAFFE_ENFORCE( - schema->Verify(def_), - "(GradientMaker) Operator def did not pass schema checking: ", - ProtoDebugString(def_)); - } - } - - /** - * @brief Returns the gradient ops meta. - * - * If your gradient op generator only use standard input and output - * manipulations, you can simply implement GetGradientDefs() that - * returns vector. In that, you can call GI, GI_V and GI_I - * that will automatically create the gradient registration for you. - * - * If you need to do custom gradient name registration, overload this - * function directly. - */ - virtual GradientOpsMeta Get() { - VerifyOp(); - vector new_defs = GetGradientDefs(); - for (auto& opdef : new_defs) { - opdef.set_is_gradient_op(true); - } - return GradientOpsMeta(new_defs, g_input_); - }; - - const OperatorDef& Def() const { - return def_; - } - - protected: - virtual vector GetGradientDefs() { - CAFFE_NOT_IMPLEMENTED; - } - - // Helper functions to return names for the gradient computation. - // I(idx), O(idx): return the input and output names. - // GO(idx): return the name of the gradient for output idx. - // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for - // input idx, and also registers that name into the gradient - // registry to be returned. - string I(const int i) { - CAFFE_ENFORCE((i >= 0) && (i < def_.input().size())); - return def_.input(i); - } - string O(const int i) { - CAFFE_ENFORCE((i >= 0) && (i < def_.output().size())); - return def_.output(i); - } - string GI(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsSparse(), - "Input ", - def_.input(i), - " already set to sparse."); - g_input_.at(i).dense_ = GradientName(def_.input(i)); - return GradientName(def_.input(i)); - } - string GI_I(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i)); - return GradientSliceIndices(def_.input(i)); - } - string GI_V(const int i) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).values_ = GradientSliceValues(def_.input(i)); - return GradientSliceValues(def_.input(i)); - } - string GO(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsDense(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsSparse() ? " is sparse (expected dense)." - : " is not provided!")); - return g_output_.at(i).dense_; - } - string GO_I(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsSparse(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsDense() ? " is dense (expected sparse)." - : " is not provided!")); - return g_output_.at(i).indices_; - } - string GO_V(const int i) { - CAFFE_ENFORCE( - g_output_.at(i).IsSparse(), - "Gradient of output ", - def_.output(i), - (g_output_.at(i).IsDense() ? 
" is dense (expected sparse)." - : " is not provided!")); - return g_output_.at(i).values_; - } - const GradientWrapper& GradOut(int i) { - return g_output_.at(i); - } - - // Function to add a gradient pair to map. - void SetDense(const int i, const string& name) { - CAFFE_ENFORCE( - !g_input_.at(i).IsSparse(), - "Input ", - def_.input(i), - " already set to sparse."); - g_input_.at(i).dense_ = name; - } - void SetSparse(const int i, const string& indices, const string& values) { - CAFFE_ENFORCE( - !g_input_.at(i).IsDense(), - "Input ", - def_.input(i), - " already set to dense."); - g_input_.at(i).indices_ = indices; - g_input_.at(i).values_ = values; - } - - /** - * @brief a helper function to allow one to create one single operator - * def, which is usually the case for many simple operators. - */ - template - inline static vector SingleGradientDef(const Args&... args) { - return vector{CreateOperatorDef(args...)}; - } - - public: - /** - * Returns map that returns the parameters that the gradients are for. - */ - static CaffeMap MatchGradsToParams(const OperatorDef& op) { - // NOTE: how to go beyond string-matching? - CaffeMap m; - for (auto& out : op.output()) { - if (IsGradientBlob(out)) { - m[out] = out.substr(0, out.length() - 5); - } - } - return m; - } - - private: - // Utility functions for gradient name computation. We don't expose them - // in order to discourage the use of such names explicitly. - static string GradientName(const string& name) { - return name + "_grad"; - } - - static bool IsGradientBlob(const string& name) { - return name.length() > 5 && name.find("_grad") == name.length() - 5; - } - - static string GradientNameToParam(const string& name) { - CHECK(IsGradientBlob(name)); - return name.substr(0, name.length() - 5); - } - - static string GradientSliceIndices(const string& name) { - return name + "_grad_indices"; - } - - static string GradientSliceValues(const string& name) { - return name + "_grad_values"; - } - - protected: - // We make the member variables protected in case someone wants to write - // a fully custom Get() function. - const OperatorDef& def_; - const vector& g_output_; - vector g_input_; -}; - -/** - * @brief A helper class to indicate that the operator does not need gradient - * computation. - * - * Use the macro NO_GRADIENT to register operators that do not have gradients. - * Note that this is different fron SHOULD_NOT_DO_GRADIENT: the latter means - * that the gradient computation should not flow through it at all, and throws - * an error if it is called. - */ -class TORCH_API NoGradient : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { - return vector(); - } -}; - -/** - * @brief A helper class to indicate that the operator should have no gradient. - * - * This is used when the operator definition is designed to not have a gradient. - * Calling a gradient on this operator def will cause Caffe2 to quit. - */ -struct ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - GradientOpsMeta Get() override { - CAFFE_THROW("One should not call gradient for operator ", def_.type(), "."); - } -}; - -/** - * @brief A helper class to indicate that the gradient mechanism is not ready. - * - * This should only be used sparsely when the gradient does exist, but we have - * not implemented it yet and are using this as a lazy excuse. Eventually, a - * gradient operator should be implemented. 
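To illustrate the helpers above, a sketch of a gradient maker for a hypothetical pointwise op whose gradient is a single op consuming the output gradient GO(0) and producing the input gradient GI(0), built with the SingleGradientDef convenience wrapper.

#include "caffe2/core/operator_gradient.h"

namespace caffe2 {

class GetDoubleOpSketchGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    // GO(0) names the gradient blob of output 0; GI(0) registers and returns
    // the gradient blob name for input 0 (typically "<input>_grad").
    return SingleGradientDef(
        "DoubleOpSketchGradient",
        "",
        std::vector<std::string>{GO(0)},
        std::vector<std::string>{GI(0)});
  }
};

} // namespace caffe2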
- */ -struct GradientNotImplementedYet : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - GradientOpsMeta Get() override { - CAFFE_THROW( - "Operator ", - def_.type(), - " should have a gradient but is not implemented yet."); - } -}; - -C10_DECLARE_REGISTRY( - GradientRegistry, - GradientMakerBase, - const OperatorDef&, - const vector&); - -#ifdef CAFFE2_NO_GRADIENT_OPS - -#define REGISTER_GRADIENT(name, ...) /* No gradients. */ -#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */ - -#else - -#define REGISTER_GRADIENT(name, ...) \ - C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) -#define REGISTER_GRADIENT_STR(str_name, ...) \ - C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__) - -#endif - -// NO_GRADIENT means that the operator does not need any gradient computation. -#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient) - -// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have -// gradient operators. If you attempt to call the gradient, a log fatal will -// occur. -#define SHOULD_NOT_DO_GRADIENT(name) \ - REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled) - -#define GRADIENT_NOT_IMPLEMENTED_YET(name) \ - REGISTER_GRADIENT(name, GradientNotImplementedYet) - -/** - * @brief Gets the GradientOpsMeta for the given operator def. - */ -TORCH_API GradientOpsMeta GetGradientForOp( - const OperatorDef& def, - const vector& g_output); - -} // namespace caffe2 - -#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_ diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h deleted file mode 100644 index f5b9d0dc09a2..000000000000 --- a/caffe2/core/operator_schema.h +++ /dev/null @@ -1,612 +0,0 @@ -#ifndef CAFFE2_CORE_OPERATOR_SCHEMA_H_ -#define CAFFE2_CORE_OPERATOR_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace caffe2 { - -// A const value returned by OpSchema::CalculateOutput() if the number of -// output cannot be determined. -constexpr int kCannotComputeNumOutputs = -1; - -/** - * @brief A class to record the schema of an op. - * - * OpSchema records the common interface of an op specified by its name. This - * is optional for each operator implemented in Caffe2 but is strongly - * recommended. - * - * To register an OpSchema, one can use the macro OPERATOR_SCHEMA(name) and - * then append the various functions in the class. For example, for an op - * that takes in two inputs, one output, and the first input and output - * could be in-place, can be written as - * - * OPERATOR_SCHEMA(name) - * .NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}}); - */ -class TORCH_API OpSchema { - public: - OpSchema() : OpSchema("unknown", "unknown", 0) {} - OpSchema(const string& type, const string& file, const int line); - - /** - * @brief Returns the file that the op schema is registered from. - */ - inline const string& file() const { - return file_; - } - - /** - * @brief Returns the line in file that the op schema is registered from. - */ - inline int line() const { - return line_; - } - - /** - * @brief Returns the docstring of the op schema. - */ - inline const char* doc() const { - return doc_.empty() ? nullptr : doc_.c_str(); - } - - /** - * @brief Verifies if an operator definition protobuf matches the pattern - * specified in the schema. - */ - bool Verify(const OperatorDef& def) const; - - // Functions to set the property of the operator schemas. 
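A short sketch of how the gradient registration macros and GetGradientForOp are typically used (op names hypothetical; the maker sketched earlier is reused here).

#include "caffe2/core/operator_gradient.h"

namespace caffe2 {

// Attach the maker sketched above to its hypothetical forward op, and mark
// two other hypothetical ops as gradient-free / gradient-forbidden.
REGISTER_GRADIENT(DoubleOpSketch, GetDoubleOpSketchGradient);
NO_GRADIENT(ShapeQuerySketch);
SHOULD_NOT_DO_GRADIENT(AccuracyLikeSketch);

// Materialize the gradient ops for a forward def whose output 0 has a dense
// gradient blob named "out_grad".
std::vector<OperatorDef> GradientOpsSketch(const OperatorDef& fwd) {
  std::vector<GradientWrapper> g_output(1);
  g_output[0].dense_ = "out_grad";
  GradientOpsMeta meta = GetGradientForOp(fwd, g_output);
  return meta.ops_;
}

} // namespace caffe2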
- // Sets the number of inputs, either a fixed number or a min and a max. - - /** - * @brief A single input. - */ - OpSchema& NumInputs(int n); - /** - * @brief Input could be in range [min, max], inclusive. - */ - OpSchema& NumInputs(int min, int max); - /** - * @brief Input could be one of the values specified in allowed_input_nums. - */ - OpSchema& NumInputs(set allowed_input_nums); - /** - * @brief Input is checked with a specified function. - */ - OpSchema& NumInputs(std::function func); - - // Sets the number of outputs, either a fixed number, a min and a max, - // or a function that takes in the input number and produces an output - // number. Use only one function in the set below. - /** - * @brief A single output. - */ - OpSchema& NumOutputs(int n); - /** - * @brief Output could be in range [min, max], inclusive. - */ - OpSchema& NumOutputs(int min, int max); - /** - * @brief Output could be one of the values specified in allowed_output_nums. - */ - OpSchema& NumOutputs(set allowed_output_nums); - /** - * @brief Output is checked with a specified function. - */ - OpSchema& NumOutputs(std::function func); - - /** - * @brief Relationship between inputs and outputs is checked with a specified - * function. - */ - OpSchema& NumInputsOutputs(std::function func); - - // Set the function that can calculate the number of output based on the - // number of input. Use only one function in the set below. - /** - * @brief Set the output calculator to a user-defined function. - */ - OpSchema& OutputCalculator(std::function calc); - /** - * @brief Set the number of outputs to be the same as the number of inputs. - */ - OpSchema& SameNumberOfOutput(); - - // Sets the rule to allow optional in-place operation. - OpSchema& AllowInplace(std::function inplace); - OpSchema& AllowInplace(set> inplace); - OpSchema& AllowOneToOneInplace(); - // Sets the rule to enforce in-place operation. - OpSchema& EnforceInplace(std::function inplace); - OpSchema& EnforceInplace(set> inplace); - OpSchema& EnforceOneToOneInplace(); - - // Functions to deal with type and shape inference. Basically, this registers - // a function that takes in an OperatorDef and a series of input type and - // shape specified by TensorProto objects (whose data fields are empty), and - // produces a series of output type and shape. - typedef std::function< - vector(const OperatorDef&, const vector&)> - TensorInferenceFunctionType; - - /** - * @brief Sets the tensor inference function, which is a std::function object - * defined in operator_schema.h. - */ - OpSchema& TensorInferenceFunction(TensorInferenceFunctionType function); - - /** - * A wrapper that makes an infer tensor function to return unknown - * shape for all outputs if any one of the inputs has unknown shape - */ - - static TensorInferenceFunctionType NeedsAllInputShapes( - TensorInferenceFunctionType f); - - /** - * @brief Sets the corresponding onnx schema name - */ - OpSchema& InheritOnnxSchema(const std::string& onnx_schema_name); - - /** - * @brief Shortcut to InheritOnnxSchema(type_) - */ - OpSchema& InheritOnnxSchema() { - return InheritOnnxSchema(type_); - } - - /** - * @brief Sets the tensor inference function to produce the same output as - * the input. 
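The schema test hunks later in this diff exercise these setters; as a compact sketch with a hypothetical op name, combining a count relation, an output calculator, and an in-place rule:

// Accepts n inputs and either n or n+1 outputs, tells callers to expect n+1
// outputs by default, and allows (but does not require) in-place on pair 0/0.
OPERATOR_SCHEMA(ConcatLikeSketch)
    .NumInputs(1, 8)
    .NumInputsOutputs([](int in, int out) { return out == in || out == in + 1; })
    .OutputCalculator([](int in) { return in + 1; })
    .AllowInplace({{0, 0}});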
- */ - OpSchema& IdenticalTypeAndShape(); - OpSchema& IdenticalTypeAndShapeOfInput(int idx); - OpSchema& IdenticalTypeAndShapeOfInputDim(int idx, int dim); - OpSchema& IdenticalTypeAndShapeOfMultipleInputs(const vector& indices); - OpSchema& ScalarType(::caffe2::TensorProto_DataType dt); - - /** - * @brief A function to allow one to infer the type and shape from the op - * schema. - */ - inline vector InferTensor( - const OperatorDef& def, - const vector& input_type_shape) const { - CAFFE_ENFORCE( - Verify(def), - "(InferTensor) Operator def did not pass schema checking: ", - ProtoDebugString(def)); - return tensor_inference_function_(def, input_type_shape); - } - - /* - * @brief A struct to store various cost information about - * an operator such as FLOPs, total memory use and parameters. - */ - struct Cost { - uint64_t flops{0}; // Floating point operations. - uint64_t bytes_read{0}; // Total memory read. - uint64_t bytes_written{0}; // Total memory written. - uint64_t params_bytes{0}; // Memory read for parameters. - }; - /** - * @brief Registers a function that takes in an OperatorDef - * and a series of input shapes and returns the total "cost" - * required to run the operator via struct by value. - */ - typedef std::function< - struct Cost(const OperatorDef&, const vector&)> - CostInferenceFunctionType; - - /** - * @brief Register the Cost inference function. - */ - OpSchema& CostInferenceFunction(CostInferenceFunctionType function); - -#if 0 // def _MSC_VER - /** - * @brief Register the Cost inference function via a pointer. - */ - template :value - >:type> - inline OpSchema& CostInferenceFunction(T func) { - // Note: This is here in order to resolve an MSVC compiler issue: it - // does not automatically convert a function pointer to a std::function, - // and needs an explicit conversion. - return CostInferenceFunction(CostInferenceFunctionType(func)); - } -#endif // _MSC_VER - - bool HasCostInferenceFunction() const { - return !!cost_inference_function_; - } - - inline struct Cost InferCost( - const OperatorDef& def, - const vector& input_tensor_shape) const { - CAFFE_ENFORCE( - cost_inference_function_, "Cost inference function not defined."); - return (*cost_inference_function_)(def, input_tensor_shape); - } - - // Functions to do documentation for the operator schema. - OpSchema& SetDoc(const string& doc); - - struct Argument { - Argument(const char* name, const char* description, bool required) - : name_{name}, description_{description}, required_{required} {} - - const char* name() const { - return name_; - } - - const char* description() const { - return description_; - } - - bool is_required() const { - return required_; - } - - private: - const char* name_; - const char* description_; - const bool required_; - }; - - OpSchema& - Arg(const char* name, const char* description, bool required = false); - -#define DECLARE_STANDARD_ARG(name, str) \ - static const char* Arg_##name; \ - OpSchema& Arg##name(const char* description); - - DECLARE_STANDARD_ARG(IsTest, is_test) - -#undef DECLARE_STANDARD_ARG - - OpSchema& Input(const int n, const char* name, const char* description); - OpSchema& Output(const int n, const char* name, const char* description); - // Calls the passed function with `this` as an argument. Useful for - // adding docs for templated/macro ops. 
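Putting the documentation and inference setters together in one sketch (hypothetical op; type and shape inference here simply propagates the first input):

OPERATOR_SCHEMA(ScaleLikeSketch)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0)
    .SetDoc(R"DOC(Multiplies the input by a constant factor.)DOC")
    .Arg("factor", "(float) multiplier applied elementwise", true /* required */)
    .Input(0, "X", "input tensor")
    .Output(0, "Y", "scaled tensor, same type and shape as X");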
- OpSchema& FillUsing(std::function populator); - - // Remove from documentation - OpSchema& Private(); - - // This op can pass data across devices - OpSchema& InputsCanCrossDevices(); - - /** - * @brief A function to allow one to get the number of outputs based on the - * number of inputs, if this schema supports it. - */ - int CalculateOutput(int num_input) const; - - const std::string& onnx_schema() const { - return onnx_schema_; - } - - int min_input() const { - return min_input_; - } - - int max_input() const { - return max_input_; - } - - int min_output() const { - return min_output_; - } - - int max_output() const { - return max_output_; - } - - bool num_inputs_allowed(int x) const { - return num_inputs_allowed_(x); - } - - bool num_outputs_allowed(int x) const { - return num_outputs_allowed_(x); - } - - bool num_inputs_outputs_allowed(int x, int y) const { - return num_inputs_outputs_allowed_(x, y); - } - - int inf() const { - return std::numeric_limits::max(); - } - - bool inplace_enforced(int x, int y) const { - return inplace_enforced_(x, y); - } - - TORCH_API friend std::ostream& operator<<( - std::ostream& out, - const OpSchema& schema); - - const std::vector& args() const { - return args_; - } - - const std::vector>& input_desc() const { - return input_desc_; - } - const std::vector>& output_desc() const { - return output_desc_; - } - bool private_op() { - return private_; - } - bool inputs_can_cross_devices() const { - return inputs_can_cross_devices_; - } - - /** - * @brief Returns the required device location of inputs and outputs. - */ - using DeviceInferenceFunctionType = std::function< - std::pair, std::vector>( - const OperatorDef& def)>; - - OpSchema& DeviceInferenceFunction(DeviceInferenceFunctionType function); - - /** - * @brief Infer required device location of an op's inputs and outputs - */ - inline std::pair, std::vector> - InferDevice(const OperatorDef& def) const { - return device_inference_function_(def); - } - - // The helper is build sparse input with values, keys, weights and lengths; - // e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // keys = [0, 1, 4, 0, 1, 2, 5, 1, 2] - // weights = [1, 2, 3, 4, 5, 6, 7, 8, 9] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& WeightedValueKeyLengthInputFillers( - size_t value_index, - size_t key_index, - size_t length_index, - size_t weight_index); - - // The helper is build sparse input with values, keys, weights and lengths; - // e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // keys = [0, 1, 4, 0, 1, 2, 5, 1, 2] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& ValueKeyLengthInputFillers( - size_t value_index, - size_t key_index, - size_t length_index); - - // The helper is build sparse input with values and lengths; e.g.: - // values = [1, 2, 3, 2, 4, 6, 7, 3, 6] - // \_____/ \________/ \__/ - // lengths = [3, 4, 2] - OpSchema& ValueLengthInputFillers(size_t value_index, size_t length_index); - - OpSchema& DisallowInputFillers(); - - std::vector InputFillers( - const std::vector>& shapes) const; - - private: - std::vector SupplyDenseFillers( - const std::vector>& shapes); - - private: - string type_; - string file_; - string doc_; - string onnx_schema_; - std::vector args_{}; - std::vector> input_desc_{}; - std::vector> output_desc_{}; - int line_ = 0; - int min_input_ = 0; - int max_input_ = std::numeric_limits::max(); - int min_output_ = 0; - int max_output_ = std::numeric_limits::max(); - bool private_ = false; - bool inputs_can_cross_devices_ = false; - std::function 
num_inputs_allowed_ = [](int) { return true; }; - std::function num_outputs_allowed_ = [](int) { return true; }; - std::function num_inputs_outputs_allowed_ = [](int, int) { - return true; - }; - std::function calculate_output_; - // In default, any in-place operation is neither allowed nor enforced. - std::function inplace_allowed_ = [](int, int) { - return false; - }; - std::function inplace_enforced_ = [](int, int) { - return false; - }; - TensorInferenceFunctionType tensor_inference_function_; - std::unique_ptr cost_inference_function_ = nullptr; - DeviceInferenceFunctionType device_inference_function_; - - std::function( - const std::vector>&)> - filler_supplier_ = - [this](const std::vector>& shapes) { - return SupplyDenseFillers(shapes); - }; -}; - -/** - * @brief A registry to hold all the operator schemas. - */ -class TORCH_API OpSchemaRegistry { - public: - static OpSchema& - NewSchema(const string& key, const string& file, const int line); - - static const OpSchema* Schema(const string& key) { - auto& m = map(); - auto it = m.find(key); - if (it != m.end()) { - return &it->second; - } else { - return nullptr; - } - } - - private: - // OpSchemaRegistry should not need to be instantiated. - OpSchemaRegistry() = delete; - - /** - * @brief Returns the underlying string to OpSchema map. - * - * You should not manually manipulate the map object returned. Instead, use - * the macros defined such as OPERATOR_SCHEMA to register your operator - * schema. - * - * We wrap it inside a function to avoid the static initialization order - * fiasco. - */ - static CaffeMap& map(); -}; - -// Helper function for creating simple tensorproto with dimension and type -template -inline TensorShape CreateTensorShape( - vector dims, - ::caffe2::TensorProto_DataType dt) { - TensorShape ts; - for (T_I d : dims) { - ts.add_dims(d); - } - ts.set_data_type(dt); - return ts; -} - -// Helper function -inline vector GetDimsVector(const TensorShape& shape) { - vector dims; - for (auto d : shape.dims()) { - dims.push_back(d); - } - return dims; -} - -// Helper function -inline uint64_t nElemFromDim(const TensorShape& X, int dim = 0) { - CAFFE_ENFORCE_GE(dim, 0, "Invalid maximum index specified"); - - uint64_t nElem = 1; - for (const auto i : c10::irange(dim, X.dims_size())) { - nElem *= X.dims(i); - } - return nElem; -} - -// Helper function -inline uint64_t nElemBetweenDim(const TensorShape& X, int start, int stop) { - CAFFE_ENFORCE_GE(start, 0, "Invalid maximum index specified"); - CAFFE_ENFORCE_LE(stop, X.dims_size(), "Invalid maximum index specified"); - - uint64_t nElem = 1; - for (const auto i : c10::irange(start, stop)) { - nElem *= X.dims(i); - } - return nElem; -} - -// Helper function for infer op inputs and outputs device information. 
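A worked example of the shape helpers above; the values follow directly from their definitions (function name hypothetical).

#include "caffe2/core/operator_schema.h"

namespace caffe2 {

void ShapeHelperSketch() {
  // A 2x3x4 float tensor shape.
  TensorShape ts =
      CreateTensorShape(std::vector<int>{2, 3, 4}, TensorProto::FLOAT);

  auto dims = GetDimsVector(ts);             // {2, 3, 4}
  uint64_t all = nElemFromDim(ts);           // 2 * 3 * 4 = 24
  uint64_t tail = nElemFromDim(ts, 1);       // 3 * 4 = 12
  uint64_t mid = nElemBetweenDim(ts, 1, 2);  // just dims(1) = 3
  (void)dims; (void)all; (void)tail; (void)mid;
}

} // namespace caffe2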
-inline std::pair, std::vector> -InferOpInputOutputDevice(const OperatorDef& op) { - auto op_schema = OpSchemaRegistry::Schema(op.type()); - if (op_schema) { - // op_schema found - return op_schema->InferDevice(op); - - } else { - // No schema for op.type registered - auto temp_schema = OpSchema(); - return temp_schema.InferDevice(op); - } -} - -template -OpSchema::Cost PointwiseCostInference( - const OperatorDef& /* unused */, - const vector& inputs) { - struct OpSchema::Cost c; - const TensorShape X = inputs[0]; - uint64_t nElemX = nElemFromDim(X); - uint64_t nElemRead = 0; - for (const auto i : c10::irange(inputs.size())) { - nElemRead += nElemFromDim(inputs[i]); - } - - c.flops = nElemX * OpsPerPoint; - auto const& X_element_size_byte = - DataTypeToTypeMeta(X.data_type()).itemsize(); - c.bytes_read = nElemRead * X_element_size_byte; - c.bytes_written = nElemX * X_element_size_byte; - return c; -} - -} // namespace caffe2 - -#if defined(_MSC_VER) -#define EXPORT_IF_NOT_MSVC -#else -#define EXPORT_IF_NOT_MSVC C10_EXPORT -#endif - -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - -#define OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#else // CAFFE2_NO_OPERATOR_SCHEMA - -#define OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#endif // CAFFE2_NO_OPERATOR_SCHEMA - -#ifdef CAFFE2_NO_GRADIENT_OPS - -#define GRADIENT_OPERATOR_SCHEMA(name) \ - EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ - static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ - 1 ? 
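PointwiseCostInference is meant to be plugged into a schema's cost inference; for an elementwise op costing one FLOP per output element, a plausible wiring (hypothetical op name, using the explicit CostInferenceFunctionType wrapper to avoid relying on implicit conversion) is:

// One FLOP per output element; PointwiseCostInference derives bytes_read and
// bytes_written from the first input's element size and the element counts.
OPERATOR_SCHEMA(NegateLikeSketch)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShape()
    .CostInferenceFunction(
        OpSchema::CostInferenceFunctionType(PointwiseCostInference<1>));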
nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) - -#else - -#define GRADIENT_OPERATOR_SCHEMA(name) OPERATOR_SCHEMA(name) - -#endif -#endif // CAFFE2_CORE_OPERATOR_SCHEMA_H_ diff --git a/caffe2/core/operator_schema_test.cc b/caffe2/core/operator_schema_test.cc deleted file mode 100644 index 5e54cf7d37dd..000000000000 --- a/caffe2/core/operator_schema_test.cc +++ /dev/null @@ -1,279 +0,0 @@ -#include "caffe2/core/logging.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/operator_schema.h" -#include "caffe2/utils/proto_utils.h" - -#include - -namespace caffe2 { - -OPERATOR_SCHEMA(OpSchemaTestOp) - .NumInputs(1).NumOutputs(1) - .SetDoc(R"DOC(Test Documentation)DOC") - .Input(0, "in0", "dummy input.") - .Output(0, "out0", "dummy output."); - -TEST(OperatorSchemaTest, BasicSchema) { - const OpSchema* schema = OpSchemaRegistry::Schema("OpSchemaTestOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - EXPECT_TRUE(schema->doc() != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in1", "in2"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaTestOp", "", - vector{"in"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaSpecifiedInputOutputOp) - .NumInputs({2, 4}).NumOutputs({1, 3}); - -TEST(OperatorSchemaTest, SpecifiedInputOutput) { - const OpSchema* schema - = OpSchemaRegistry::Schema("OpSchemaSpecifiedInputOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in1", "in2"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaSpecifiedInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaInputOutputRelationOp) - .NumInputsOutputs([](int in, int out) { - return out == in || out == in * 2; - }); - -TEST(OperatorSchemaTest, InputOutputRelation) { - const OpSchema* schema - = OpSchemaRegistry::Schema("OpSchemaInputOutputRelationOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - EXPECT_TRUE(schema != nullptr); - OperatorDef def1 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in"}, vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in"}, vector{"out1", "out2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaInputOutputRelationOp", "", - vector{"in1", "in2", "in3"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaSameInputOutputOp) - .SameNumberOfOutput(); - -TEST(OperatorSchemaTest, SameInputOutput) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaSameInputOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in"}, 
vector{"out"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaSameInputOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2", "out3"}); - EXPECT_FALSE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaCalculateOutputOp) - .NumInputs(1, 5).NumOutputs(2, 6) - .OutputCalculator([](int n) { return n + 1; }); - -TEST(OperatorSchemaTest, CalculateOutput) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaCalculateOutputOp"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in"}, vector{"out"}); - EXPECT_FALSE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaCalculateOutputOp", "", - vector{"in1", "in2"}, vector{"out1", "out2", "out3"}); - EXPECT_TRUE(schema->Verify(def3)); -} - -OPERATOR_SCHEMA(OpSchemaInplace) - .NumInputs(2).NumOutputs(2) - .AllowInplace({{0, 0}}) - .EnforceInplace({{1, 1}}); - -TEST(OperatorSchemaTest, Inplace) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaInplace"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def1 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"out1", "in2"}); - EXPECT_TRUE(schema->Verify(def1)); - OperatorDef def2 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"in1", "in2"}); - EXPECT_TRUE(schema->Verify(def2)); - OperatorDef def3 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"in1", "out2"}); - EXPECT_FALSE(schema->Verify(def3)); - OperatorDef def4 = CreateOperatorDef( - "OpSchemaInplace", "", - vector{"in1", "in2"}, vector{"out1", "out2"}); - EXPECT_FALSE(schema->Verify(def4)); -} - -OPERATOR_SCHEMA(OpSchemaSameInputOutputTensorInference).IdenticalTypeAndShape(); - -TEST(OperatorSchemaTest, TensorInferenceIdentical) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaSameInputOutputTensorInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def = CreateOperatorDef( - "OpSchemaSameInputOutputTensorInference", - "", - vector{"in"}, - vector{"out"}); - vector shapes(1); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(1); - shapes[0].add_dims(2); - shapes[0].add_dims(3); - vector out = schema->InferTensor(def, shapes); - EXPECT_EQ(out.size(), 1); - EXPECT_EQ(out[0].SerializeAsString(), shapes[0].SerializeAsString()); -} - -OPERATOR_SCHEMA(OpSchemaArbitraryTensorInference) - .TensorInferenceFunction( - [](const OperatorDef&, const vector&) { - vector shapes(1); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(1701); - return shapes; - }); - -TEST(OperatorSchemaTest, TensorInferenceArbitrary) { - const OpSchema* schema = - OpSchemaRegistry::Schema("OpSchemaArbitraryTensorInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - OperatorDef def = CreateOperatorDef( - "OpSchemaArbitraryTensorInference", - "", - vector{"in"}, - vector{"out"}); - vector out = schema->InferTensor(def, vector()); - EXPECT_EQ(out.size(), 1); - 
EXPECT_EQ(out[0].data_type(), TensorProto::FLOAT); - EXPECT_EQ(out[0].dims_size(), 1); - EXPECT_EQ(out[0].dims(0), 1701); -} - -TEST(OperatorSchemaTest, TestCastSchema) { - // This tests a use case of the schema: the Cast op takes in the def and - // deduces the - // schema from the "to" argument. - const OpSchema* schema = OpSchemaRegistry::Schema("Cast"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - if (!schema) { - // Compiled without the Cast op. - return; - } - OperatorDef def = CreateOperatorDef( - "Cast", - "", - vector{"in"}, - vector{"out"}, - vector{MakeArgument("to", TensorProto::UINT8)}); - vector out = schema->InferTensor(def, vector(1)); - EXPECT_EQ(out.size(), 1); - // Data type should be inferred. - EXPECT_EQ(out[0].data_type(), TensorProto::UINT8); - // Dim should not be set (same as input); - EXPECT_EQ(out[0].dims_size(), 0); -} - -OPERATOR_SCHEMA(OpSchemaCostInference) - .NumInputs(2) - .NumOutputs(2) - .CostInferenceFunction([](const OperatorDef& /*def*/, - const vector& inputs) { - struct OpSchema::Cost c; - c.flops = 2 * inputs[0].dims(0) * inputs[0].dims(1) * inputs[1].dims(1); - return c; - }); - -TEST(OperatorSchemaTest, TestCostInference) { - const OpSchema* schema = OpSchemaRegistry::Schema("OpSchemaCostInference"); -#ifdef CAFFE2_NO_OPERATOR_SCHEMA - EXPECT_TRUE(schema == nullptr); - return; -#endif - if (!schema) { - return; - } - OperatorDef def = CreateOperatorDef( - "OpSchemaCostInference", "", vector{"in"}, vector{"out"}); - vector shapes(2); - shapes[0].set_data_type(TensorProto::FLOAT); - shapes[0].add_dims(10); - shapes[0].add_dims(10); - shapes[1].set_data_type(TensorProto::FLOAT); - shapes[1].add_dims(10); - shapes[1].add_dims(10); - EXPECT_EQ(2000, schema->InferCost(def, shapes).flops); -} - -} // namespace caffe2 diff --git a/caffe2/core/operator_test.cc b/caffe2/core/operator_test.cc deleted file mode 100644 index afebacc71dc3..000000000000 --- a/caffe2/core/operator_test.cc +++ /dev/null @@ -1,634 +0,0 @@ -#include - -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" -#include - -namespace caffe2 { - -// Since we instantiate this on CPU and GPU (but don't want a -// CUDAContext dependency, we use OperatorBase. In general, you only -// want to inherit from Operator in your code. 
-class JustTest : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - virtual string type() { - return "base"; - } -}; - -class JustTestAndNeverConstructs : public JustTest { - public: - JustTestAndNeverConstructs(const OperatorDef& def, Workspace* ws) - : JustTest(def, ws) { - throw UnsupportedOperatorFeature("I just don't construct."); - } - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "FOO"; - } -}; - -class JustTestAndDoesConstruct : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "BAR"; - } -}; - -class JustTestWithSomeOutput : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - *OperatorBase::Output(0) = 5; - return true; - } - string type() override { - return "SETTING_SOME_OUTPUT"; - } -}; - -OPERATOR_SCHEMA(JustTest).NumInputs(0, 1).NumOutputs(0, 1); -OPERATOR_SCHEMA(JustTestCPUOnly).NumInputs(0, 1).NumOutputs(0, 1); -OPERATOR_SCHEMA(JustTestWithSomeOutput); - -REGISTER_CPU_OPERATOR(JustTest, JustTest); -REGISTER_CPU_OPERATOR(JustTestCPUOnly, JustTest); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, FOO, JustTestAndNeverConstructs); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, BAR, JustTestAndDoesConstruct); -REGISTER_CPU_OPERATOR_WITH_ENGINE(JustTest, BAZ, JustTestAndDoesConstruct); -REGISTER_CUDA_OPERATOR(JustTest, JustTest); -REGISTER_CPU_OPERATOR(JustTestWithSomeOutput, JustTestWithSomeOutput); - -TEST(OperatorTest, DeviceTypeRegistryWorks) { - EXPECT_EQ(gDeviceTypeRegistry()->count(CPU), 1); -} - -TEST(OperatorTest, RegistryWorks) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // After introducing events, CUDA operator creation has to have CUDA compiled - // as it needs to instantiate an Event object with CUDAContext. Thus we will - // guard this test below. - if (HasCudaRuntime()) { - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - } -} - -TEST(OperatorTest, RegistryWrongDevice) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTypeCPUOnly"); - op_def.mutable_device_option()->set_device_type(PROTO_CUDA); - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception " << e.what(); - } -} - -TEST(OperatorTest, ExceptionWorks) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("ThrowException"); - unique_ptr op = CreateOperator(op_def, &ws); - // Note: we do not do ASSERT_THROW in order to print out - // the error message for inspection. - try { - op->Run(); - // This should not happen - exception should throw above. - LOG(FATAL) << "This should not happen."; - } catch (const EnforceNotMet& err) { - LOG(INFO) << err.what(); - } - try { - op->RunAsync(); - // This should not happen - exception should throw above. 
- LOG(FATAL) << "This should not happen."; - } catch (const EnforceNotMet& err) { - LOG(INFO) << err.what(); - } -} - -TEST(OperatorTest, FallbackIfEngineDoesNotBuild) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("FOO"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "base"); -} - -TEST(OperatorTest, MultipleEngineChoices) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("FOO,BAR"); - unique_ptr op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); -} - -TEST(OperatorTest, CannotUseUninitializedBlob) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(CreateOperator(op_def, &ws), EnforceNotMet); -} - -TEST(OperatorTest, TestParameterAccess) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument("arg0", 0.1, &op_def); - AddArgument>("arg1", vector{1, 2}, &op_def); - AddArgument("arg2", "argstring", &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); - vector i = op.GetRepeatedArgument("arg1"); - EXPECT_EQ(i.size(), 2); - EXPECT_EQ(i[0], 1); - EXPECT_EQ(i[1], 2); - EXPECT_EQ(op.GetSingleArgument("arg2", "default"), "argstring"); - auto default1 = op.GetRepeatedArgument("arg3", {2, 3}); - EXPECT_EQ(default1.size(), 2); - EXPECT_EQ(default1[0], 2); - EXPECT_EQ(default1[1], 3); - auto default2 = op.GetRepeatedArgument("arg4"); - EXPECT_EQ(default2.size(), 0); -} - -TEST(OperatorTest, CannotAccessParameterWithWrongType) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument("arg0", 0.1f, &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(op.GetSingleArgument("arg0", 0), EnforceNotMet); -} - -#if GTEST_HAS_DEATH_TEST -TEST(OperatorDeathTest, DISABLED_CannotAccessRepeatedParameterWithWrongType) { - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - AddArgument>("arg0", vector{0.1f}, &op_def); - EXPECT_NE(ws.CreateBlob("input"), nullptr); - OperatorBase op(op_def, &ws); - auto args = op.GetRepeatedArgument("arg0"); - EXPECT_EQ(args.size(), 1); - EXPECT_FLOAT_EQ(args[0], 0.1f); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_DEATH(op.GetRepeatedArgument("arg0"), - "Argument does not have the right field: expected ints"); -} -#endif - -TEST(OperatorTest, TestDefaultValue) { - OperatorDef op_def; - Workspace ws; - OperatorBase op(op_def, &ws); - EXPECT_FLOAT_EQ(op.GetSingleArgument("arg-nonexisting", 0.5f), 0.5f); -} - -TEST(OperatorTest, TestSetUp) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("output"); - EXPECT_NE(nullptr, 
ws.CreateBlob("input")); - unique_ptr op(CreateOperator(op_def, &ws)); - EXPECT_NE(nullptr, op.get()); - EXPECT_TRUE(ws.HasBlob("output")); -} - -TEST(OperatorTest, TestSetUpInputOutputCount) { - Workspace ws; - OperatorDef op_def; - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_input("input2"); - op_def.add_output("output"); - EXPECT_NE(nullptr, ws.CreateBlob("input")); - EXPECT_NE(nullptr, ws.CreateBlob("input2")); -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - // JustTest will only accept one single input. - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(CreateOperator(op_def, &ws)); -#endif - - op_def.clear_input(); - op_def.add_input("input"); - op_def.add_output("output2"); -#ifndef CAFFE2_NO_OPERATOR_SCHEMA - // JustTest will only produce one single output. - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(CreateOperator(op_def, &ws)); -#endif -} - -TEST(OperatorTest, TestOutputValues) { - NetDef net_def; - net_def.set_name("NetForTest"); - OperatorDef op_def; - Workspace ws; - op_def.set_name("JustTest1"); - op_def.set_type("JustTestWithSomeOutput"); - op_def.add_output("output"); - // JustTest will only produce one single output. - net_def.add_op()->CopyFrom(op_def); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_TRUE(net->Run()); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_EQ(ws.GetBlob("output")->Get(), 5); -} - -NetDef GetNetDefForTest() { - NetDef net_def; - OperatorDef op_def; - net_def.set_name("NetForTest"); - op_def.set_name("JustTest0"); - op_def.set_type("JustTest"); - op_def.add_input("input"); - op_def.add_output("hidden"); - net_def.add_op()->CopyFrom(op_def); - op_def.set_name("JustTest1"); - op_def.set_input(0, "hidden"); - op_def.set_output(0, "output"); - net_def.add_op()->CopyFrom(op_def); - return net_def; -} - -TEST(NetTest, TestScaffoldingSimpleNet) { - NetDef net_def = GetNetDefForTest(); - net_def.set_type("simple"); - Workspace ws; - EXPECT_NE(nullptr, ws.CreateBlob("input")); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_NE(nullptr, net.get()); - EXPECT_TRUE(ws.HasBlob("input")); - EXPECT_TRUE(ws.HasBlob("hidden")); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_TRUE(net->Run()); -} - -TEST(NetTest, TestScaffoldingDAGNet) { - NetDef net_def = GetNetDefForTest(); - net_def.set_type("dag"); - net_def.set_num_workers(1); - Workspace ws; - EXPECT_NE(nullptr, ws.CreateBlob("input")); - unique_ptr net(CreateNet(net_def, &ws)); - EXPECT_NE(nullptr, net.get()); - EXPECT_TRUE(ws.HasBlob("input")); - EXPECT_TRUE(ws.HasBlob("hidden")); - EXPECT_TRUE(ws.HasBlob("output")); - EXPECT_TRUE(net->Run()); -} - -class FooGradientOp : public JustTest { - public: - using JustTest::JustTest; - string type() override { - return "FooGradient"; - } -}; - -class FooGradientDummyEngineOp : public JustTest { - public: - using JustTest::JustTest; - string type() override { - return "FooGradientDummyEngine"; - } -}; - -class GetFooGradient : public GradientMakerBase { - using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { - return vector{ - CreateOperatorDef( - "FooGradient", "", - std::vector{GO(0)}, - std::vector{GI(0)})}; - } -}; - -GRADIENT_OPERATOR_SCHEMA(FooGradient).NumInputs(1).NumOutputs(1); -REGISTER_CPU_GRADIENT_OPERATOR(FooGradient, FooGradientOp) -REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE( - FooGradient, - DUMMY_ENGINE, - FooGradientDummyEngineOp) -REGISTER_GRADIENT(Foo, GetFooGradient); - 
-TEST(OperatorGradientRegistryTest, GradientSimple) { - Argument arg = MakeArgument("arg", 1); - DeviceOption option; - option.set_device_type(PROTO_CPU); - OperatorDef def = CreateOperatorDef( - "Foo", "", std::vector{"in"}, std::vector{"out"}, - std::vector{arg}, option, "DUMMY_ENGINE"); - vector g_output(1); - g_output[0].dense_ = "out_grad"; - GradientOpsMeta meta = GetGradientForOp(def, g_output); - // Check the names, input and output. - EXPECT_EQ(meta.ops_.size(), 1); - const OperatorDef& grad_op_def = meta.ops_[0]; - EXPECT_EQ(grad_op_def.type(), "FooGradient"); - EXPECT_EQ(grad_op_def.name(), ""); - EXPECT_EQ(grad_op_def.input_size(), 1); - EXPECT_EQ(grad_op_def.output_size(), 1); - EXPECT_EQ(grad_op_def.input(0), "out_grad"); - EXPECT_EQ(grad_op_def.output(0), "in_grad"); - // Checks the engine, device option and arguments. - EXPECT_EQ(grad_op_def.engine(), "DUMMY_ENGINE"); - EXPECT_EQ(grad_op_def.device_option().device_type(), PROTO_CPU); - EXPECT_EQ(grad_op_def.arg_size(), 1); - EXPECT_EQ( - grad_op_def.arg(0).SerializeAsString(), - MakeArgument("arg", 1).SerializeAsString()); - // Checks the gradient name for input. - EXPECT_EQ(meta.g_input_.size(), 1); - EXPECT_TRUE(meta.g_input_[0].IsDense()); - EXPECT_EQ(meta.g_input_[0].dense_, "in_grad"); - - Workspace ws; - EXPECT_NE(ws.CreateBlob("out_grad"), nullptr); - unique_ptr grad_op = CreateOperator(grad_op_def, &ws); - EXPECT_NE(nullptr, grad_op.get()); - EXPECT_EQ( - static_cast(grad_op.get())->type(), "FooGradientDummyEngine"); -} - -TEST(EnginePrefTest, PerOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAR"}}}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - - // Invalid operator type - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW( - SetPerOpEnginePref({{CPU, {{"NO_EXIST", {"BAR"}}}}}), EnforceNotMet); -} - -TEST(EnginePrefTest, GlobalEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetGlobalEnginePref({{CPU, {"FOO", "BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetGlobalEnginePref({}); - - SetGlobalEnginePref({{CPU, {"FOO"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ(static_cast(op.get())->type(), "base"); - } - // clear - SetGlobalEnginePref({}); - - // Invalid device type - // This check is no longer necessary with the enum class - // ASSERT_THROW(SetGlobalEnginePref({{8888, {"FOO"}}}), EnforceNotMet); -} - -TEST(EnginePrefTest, GlobalEnginePrefAndPerOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAR"}}}}}); - SetGlobalEnginePref({{CPU, {"BAZ"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // per op pref takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, GlobalEnginePrefAndPerOpEnginePrefAndOpDef) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - op_def.set_engine("BAR"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAZ"}}}}}); - SetGlobalEnginePref({{CPU, {"BAZ"}}}); - { - const auto op = CreateOperator(op_def, 
&ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, SetOpEnginePref) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"BAZ"}}}}}); - SetOpEnginePref("JustTest", {{CPU, {"BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "BAR"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -TEST(EnginePrefTest, SetDefaultEngine) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTest"); - - SetPerOpEnginePref({{CPU, {{"JustTest", {"DEFAULT"}}}}}); - SetGlobalEnginePref({{CPU, {"BAR"}}}); - { - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - // operator_def takes precedence - EXPECT_EQ(static_cast(op.get())->type(), "base"); - } - // clear - SetPerOpEnginePref({}); - SetGlobalEnginePref({}); -} - -class JustTestWithRequiredArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithRequiredArg"; - } -}; - -REGISTER_CPU_OPERATOR(JustTestWithRequiredArg, JustTestWithRequiredArg); -OPERATOR_SCHEMA(JustTestWithRequiredArg) - .NumInputs(0, 1) - .NumOutputs(0, 1) - .Arg("test_arg", "this arg is required", true); - -TEST(RequiredArg, Basic) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithRequiredArg"); - - { - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception thrown (expected): " << e.what(); - } - } - - { - op_def.add_arg()->CopyFrom(MakeArgument("test_arg", 1)); - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), "JustTestWithRequiredArg"); - } -} - -class JustTestWithStandardIsTestArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithStandardIsTestArg"; - } -}; - -REGISTER_CPU_OPERATOR( - JustTestWithStandardIsTestArg, - JustTestWithStandardIsTestArg); -OPERATOR_SCHEMA(JustTestWithStandardIsTestArg) - .NumInputs(0, 1) - .NumOutputs(0, 1) - .ArgIsTest("this is_test arg is required"); - -TEST(IsTestArg, standard) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithStandardIsTestArg"); - - { - try { - CreateOperator(op_def, &ws); - LOG(FATAL) << "No exception was thrown"; - } catch (const std::exception& e) { - LOG(INFO) << "Exception thrown (expected): " << e.what(); - } - } - - { - op_def.add_arg()->CopyFrom(MakeArgument(OpSchema::Arg_IsTest, 1)); - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), - "JustTestWithStandardIsTestArg"); - } -} - -class JustTestWithNonStandardIsTestArg : public JustTest { - public: - using JustTest::JustTest; - bool Run(int /* unused */ /*stream_id*/) override { - return true; - } - string type() override { - return "JustTestWithNonStandardIsTestArg"; - } -}; - -REGISTER_CPU_OPERATOR( - JustTestWithNonStandardIsTestArg, - JustTestWithNonStandardIsTestArg); -OPERATOR_SCHEMA(JustTestWithNonStandardIsTestArg) - .NumInputs(0, 1) - 
.NumOutputs(0, 1) - .Arg(OpSchema::Arg_IsTest, "this is_test arg is not required"); - -TEST(IsTestArg, non_standard) { - OperatorDef op_def; - Workspace ws; - op_def.set_type("JustTestWithNonStandardIsTestArg"); - - const auto op = CreateOperator(op_def, &ws); - EXPECT_NE(nullptr, op.get()); - EXPECT_EQ( - static_cast(op.get())->type(), - "JustTestWithNonStandardIsTestArg"); -} - -} // namespace caffe2 diff --git a/caffe2/core/parallel_net_test.cc b/caffe2/core/parallel_net_test.cc deleted file mode 100644 index 7b17faba3150..000000000000 --- a/caffe2/core/parallel_net_test.cc +++ /dev/null @@ -1,322 +0,0 @@ -#include // NOLINT -#include // NOLINT - -#include -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -// When measuring time, we relax the measured time by +- 40ms. -#ifndef _WIN32 -const int kTimeThreshold = 40; -#else -// Even more so on Windows -const int kTimeThreshold = 50; -#endif - -// SleepOp basically sleeps for a given number of seconds. -// We allow arbitrary inputs and at most one output so that we can -// test scaffolding of networks. If the output is 1, it will be filled with -// vector with two elements: start time and end time. -class SleepOp final : public Operator { - public: - SleepOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws), - ms_(OperatorBase::GetSingleArgument("ms", 1000)) { - TORCH_DCHECK_GT(ms_, 0); - TORCH_DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?"; - } - - bool RunOnDevice() override { - auto start = std::chrono::high_resolution_clock::now(); - std::this_thread::sleep_for(std::chrono::milliseconds(ms_)); - auto end = std::chrono::high_resolution_clock::now(); - if (OperatorBase::OutputSize()) { - vector* output = OperatorBase::Output>(0); - output->resize(2); - (*output)[0] = start.time_since_epoch().count(); - (*output)[1] = end.time_since_epoch().count(); - } - return true; - } - - private: - int ms_; -}; - -OPERATOR_SCHEMA(Sleep).NumInputs(0, INT_MAX).NumOutputs(0, 1); - -REGISTER_CPU_OPERATOR(Sleep, SleepOp); -REGISTER_CUDA_OPERATOR(Sleep, SleepOp); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefString[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep3\"" - " name: \"sleep3\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -namespace { -// Run a network and get its duration in milliseconds. -int RunNetAndGetDuration(const string& net_def_str, const string& type) { - NetDef net_def; - CAFFE_ENFORCE(TextFormat::ParseFromString(net_def_str, &net_def)); - net_def.set_type(type); - Workspace ws; - unique_ptr net(CreateNet(net_def, &ws)); - CAFFE_ENFORCE(net.get() != nullptr); - // Run once to kick in potential initialization (can be slower) - CAFFE_ENFORCE(net->Run()); - // Now run and time it - auto start_time = std::chrono::system_clock::now(); - CAFFE_ENFORCE(net->Run()); - // Inspect the time - it should be around 200 milliseconds, since sleep3 can - // run in parallel with sleep1 and sleep2. 
- auto duration = std::chrono::duration_cast( - std::chrono::system_clock::now() - start_time); - int milliseconds = duration.count(); - return milliseconds; -} -} // namespace - -TEST(DAGNetTest, TestDAGNetTiming) { - int ms = RunNetAndGetDuration(string(kSleepNetDefString), "dag"); - EXPECT_NEAR(ms, 200, kTimeThreshold); -} - -// For sanity check, we also test the sequential time - it should take 0.35 -// seconds instead since everything has to be sequential. -TEST(SimpleNetTest, TestSimpleNetTiming) { - int ms = RunNetAndGetDuration(string(kSleepNetDefString), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has two operators reading the same blob at the same time. This -// should not change anything and the DAG should still make sleep2 and sleep3 -// run in parallel. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringReadAfterRead[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep3\"" - " name: \"sleep3\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingReadAfterRead) { - int ms = RunNetAndGetDuration(string(kSleepNetDefStringReadAfterRead), "dag"); - EXPECT_NEAR(ms, 250, kTimeThreshold); -} - -// For sanity check, we also test the sequential time - it should take 0.35 -// seconds instead since everything has to be sequential. -TEST(SimpleNetTest, TestSimpleNetTimingReadAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringReadAfterRead), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has two operators writing out the sleep2 blob. As a result, the -// operator sleep2-again creates a write after write dependency and the whole -// process should be sequential. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringWriteAfterWrite[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep2\"" - " name: \"sleep2-again\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingWriteAfterWrite) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterWrite), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingWriteAfterWrite) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterWrite), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has an operator writing to sleep1 while another operator is -// accessing it. As a result, the operator sleep1-again creates a write after -// read dependency and the whole process should be sequential. 
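// Taken together, these timing cases capture the dag scheduler's dependency rules:
// read-after-read pairs may run in parallel (the ~250 ms expectation above), while
// write-after-write, write-after-read, and control_input edges all force sequential
// execution (the ~350 ms expectations below).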
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringWriteAfterRead[] = - " name: \"sleepnet\"" - " type: \"dag\"" - " num_workers: 2" - " op {" - " output: \"sleep1\"" - " name: \"sleep1\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " input: \"sleep1\"" - " output: \"sleep2\"" - " name: \"sleep2\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 100" - " }" - " }" - " op {" - " output: \"sleep1\"" - " name: \"sleep1-again\"" - " type: \"Sleep\"" - " arg {" - " name: \"ms\"" - " i: 150" - " }" - " }"; - -TEST(DAGNetTest, TestDAGNetTimingWriteAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterRead), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingWriteAfterRead) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringWriteAfterRead), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -// This network has an operator writing to sleep1 while another -// operator has a control dependency on it. As a result, the operator -// sleep1-again creates a write after read dependency and the whole -// process should be sequential. -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -const char kSleepNetDefStringControlDependency[] = R"DOC( - name: "sleepnet" - type: "dag" - num_workers: 2 - op { - output: "sleep1" - name: "sleep1" - type: "Sleep" - arg { - name: "ms" - i: 100 - } - } - op { - control_input: "sleep1" - output: "sleep2" - name: "sleep2" - type: "Sleep" - arg { - name: "ms" - i: 100 - } - } - op { - output: "sleep1" - name: "sleep1-again" - type: "Sleep" - arg { - name: "ms" - i: 150 - } - } -)DOC"; - -TEST(DAGNetTest, TestDAGNetTimingControlDependency) { - int ms = - RunNetAndGetDuration(string(kSleepNetDefStringControlDependency), "dag"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -TEST(SimpleNetTest, TestSimpleNetTimingControlDependency) { - int ms = RunNetAndGetDuration( - string(kSleepNetDefStringControlDependency), "simple"); - EXPECT_NEAR(ms, 350, kTimeThreshold); -} - -} // namespace caffe2 diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc deleted file mode 100644 index 7a54403805ec..000000000000 --- a/caffe2/core/plan_executor_test.cc +++ /dev/null @@ -1,414 +0,0 @@ -#ifndef ANDROID - -#include -#include "caffe2/core/init.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/plan_executor.h" - -namespace caffe2 { - -TEST(PlanExecutorTest, EmptyPlan) { - PlanDef plan_def; - Workspace ws; - EXPECT_TRUE(ws.RunPlan(plan_def)); -} - -namespace { -static std::atomic cancelCount{0}; -static std::atomic stuckRun{false}; -} // namespace - -class StuckBlockingOp final : public Operator { - public: - StuckBlockingOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // StuckBlockingOp runs and notifies ErrorOp. 
- stuckRun = true; - - while (!cancelled_) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - - return true; - } - - void Cancel() override { - LOG(INFO) << "cancelled StuckBlockingOp."; - cancelCount += 1; - cancelled_ = true; - } - - private: - std::atomic cancelled_{false}; -}; - -REGISTER_CPU_OPERATOR(StuckBlocking, StuckBlockingOp); -OPERATOR_SCHEMA(StuckBlocking).NumInputs(0).NumOutputs(0); - -class NoopOp final : public Operator { - public: - NoopOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // notify Error op we've ran. - stuckRun = true; - return true; - } -}; - -REGISTER_CPU_OPERATOR(Noop, NoopOp); -OPERATOR_SCHEMA(Noop).NumInputs(0).NumOutputs(0); - - -class StuckAsyncOp final : public Operator { - public: - StuckAsyncOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // notify Error op we've ran. - stuckRun = true; - // explicitly don't call SetFinished so this gets stuck - return true; - } - - void CancelAsyncCallback() override { - LOG(INFO) << "cancelled"; - cancelCount += 1; - } - - bool HasAsyncPart() const override { - return true; - } -}; - -REGISTER_CPU_OPERATOR(StuckAsync, StuckAsyncOp); -OPERATOR_SCHEMA(StuckAsync).NumInputs(0).NumOutputs(0); - -class TestError : public std::exception { - const char* what() const noexcept override { - return "test error"; - } -}; - -class ErrorOp final : public Operator { - public: - ErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // Wait for StuckAsyncOp or StuckBlockingOp to run first. - while (!stuckRun) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - throw TestError(); - return true; - } -}; - -REGISTER_CPU_OPERATOR(Error, ErrorOp); -OPERATOR_SCHEMA(Error).NumInputs(0).NumOutputs(0); - -static std::atomic blockingErrorRuns{0}; -class BlockingErrorOp final : public Operator { - public: - BlockingErrorOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) {} - - bool RunOnDevice() override { - // First n op executions should block and then start throwing errors. - if (blockingErrorRuns.fetch_sub(1) >= 1) { - LOG(INFO) << "blocking"; - while (true) { - std::this_thread::sleep_for(std::chrono::hours(10)); - } - } else { - LOG(INFO) << "throwing"; - throw TestError(); - } - } -}; - -REGISTER_CPU_OPERATOR(BlockingError, BlockingErrorOp); -OPERATOR_SCHEMA(BlockingError).NumInputs(0).NumOutputs(0); - -PlanDef parallelErrorPlan() { - PlanDef plan_def; - - auto* stuck_net = plan_def.add_network(); - stuck_net->set_name("stuck_net"); - stuck_net->set_type("async_scheduling"); - { - auto* op = stuck_net->add_op(); - op->set_type("StuckAsync"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - error_net->set_type("async_scheduling"); - { - auto op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -PlanDef parallelErrorPlanWithCancellableStuckNet() { - // Set a plan with two nets: one stuck net with blocking operator that never - // returns; one error net with error op that throws. 
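  // As with the plain parallel error plan, the corresponding test below expects
  // ws.RunPlan() to rethrow TestError and the stuck blocking op to be cancelled
  // exactly once (cancelCount == 1).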
- PlanDef plan_def; - - auto* stuck_blocking_net = plan_def.add_network(); - stuck_blocking_net->set_name("stuck_blocking_net"); - { - auto* op = stuck_blocking_net->add_op(); - op->set_type("StuckBlocking"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_blocking_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -PlanDef reporterErrorPlanWithCancellableStuckNet() { - // Set a plan with a concurrent net and a reporter net: one stuck net with - // blocking operator that never returns; one reporter net with error op - // that throws. - PlanDef plan_def; - - auto* stuck_blocking_net = plan_def.add_network(); - stuck_blocking_net->set_name("stuck_blocking_net"); - { - auto* op = stuck_blocking_net->add_op(); - op->set_type("StuckBlocking"); - } - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - substep->add_network(stuck_blocking_net->name()); - } - { - auto* substep = execution_step->add_substep(); - substep->set_run_every_ms(1); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -struct HandleExecutorThreadExceptionsGuard { - HandleExecutorThreadExceptionsGuard(int timeout = 60) { - globalInit({ - "caffe2", - "--caffe2_handle_executor_threads_exceptions=1", - "--caffe2_plan_executor_exception_timeout=" + - caffe2::to_string(timeout), - }); - } - - ~HandleExecutorThreadExceptionsGuard() { - globalInit({ - "caffe2", - }); - } - - HandleExecutorThreadExceptionsGuard( - const HandleExecutorThreadExceptionsGuard&) = delete; - void operator=(const HandleExecutorThreadExceptionsGuard&) = delete; - - private: - void globalInit(std::vector args) { - std::vector args_ptrs; - for (auto& arg : args) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,performance-inefficient-vector-operation) - args_ptrs.push_back(const_cast(arg.data())); - } - char** new_argv = args_ptrs.data(); - int new_argc = args.size(); - CAFFE_ENFORCE(GlobalInit(&new_argc, &new_argv)); - } -}; - -TEST(PlanExecutorTest, ErrorAsyncPlan) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = parallelErrorPlan(); - Workspace ws; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -// death tests not supported on mobile -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -TEST(PlanExecutorTest, BlockingErrorPlan) { - // TSAN doesn't play nicely with death tests -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) - return; -#endif -#endif - - testing::GTEST_FLAG(death_test_style) = "threadsafe"; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_DEATH( - [] { - HandleExecutorThreadExceptionsGuard guard(/*timeout=*/1); - - PlanDef plan_def; - - std::string plan_def_template = R"DOC( - network { - name: "net" - op { - type: "BlockingError" - } - } - execution_step { - num_concurrent_instances: 2 
- substep { - network: "net" - } - } - )DOC"; - - CAFFE_ENFORCE( - TextFormat::ParseFromString(plan_def_template, &plan_def)); - Workspace ws; - blockingErrorRuns = 1; - ws.RunPlan(plan_def); - FAIL() << "shouldn't have reached this point"; - }(), - "failed to stop concurrent workers after exception: test error"); -} -#endif - -TEST(PlanExecutorTest, ErrorPlanWithCancellableStuckNet) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = parallelErrorPlanWithCancellableStuckNet(); - Workspace ws; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -TEST(PlanExecutorTest, ReporterErrorPlanWithCancellableStuckNet) { - HandleExecutorThreadExceptionsGuard guard; - - cancelCount = 0; - PlanDef plan_def = reporterErrorPlanWithCancellableStuckNet(); - Workspace ws; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_EQ(cancelCount, 1); -} - -PlanDef shouldStopWithCancelPlan() { - // Set a plan with a looping net with should_stop_blob set and a concurrent - // net that throws an error. The error should cause should_stop to return - // false and end the concurrent net. - PlanDef plan_def; - - auto* should_stop_net = plan_def.add_network(); - { - auto* op = should_stop_net->add_op(); - op->set_type("Noop"); - } - should_stop_net->set_name("should_stop_net"); - should_stop_net->set_type("async_scheduling"); - - auto* error_net = plan_def.add_network(); - error_net->set_name("error_net"); - { - auto* op = error_net->add_op(); - op->set_type("Error"); - } - - auto* execution_step = plan_def.add_execution_step(); - execution_step->set_concurrent_substeps(true); - { - auto* substep = execution_step->add_substep(); - execution_step->set_concurrent_substeps(true); - substep->set_name("concurrent_should_stop"); - substep->set_should_stop_blob("should_stop_blob"); - auto* substep2 = substep->add_substep(); - substep2->set_name("should_stop_net"); - substep2->add_network(should_stop_net->name()); - substep2->set_num_iter(10); - } - { - auto* substep = execution_step->add_substep(); - substep->set_name("error_step"); - substep->add_network(error_net->name()); - } - - return plan_def; -} - -TEST(PlanExecutorTest, ShouldStopWithCancel) { - HandleExecutorThreadExceptionsGuard guard; - - stuckRun = false; - PlanDef plan_def = shouldStopWithCancelPlan(); - Workspace ws; - - Blob* blob = ws.CreateBlob("should_stop_blob"); - Tensor* tensor = BlobGetMutableTensor(blob, CPU); - const vector& shape{1}; - tensor->Resize(shape); - tensor->mutable_data()[0] = false; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_THROW(ws.RunPlan(plan_def), TestError); - ASSERT_TRUE(stuckRun); -} - -} // namespace caffe2 - -#endif diff --git a/caffe2/core/scope_guard.h b/caffe2/core/scope_guard.h deleted file mode 100644 index ee412a424de4..000000000000 --- a/caffe2/core/scope_guard.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2016 Facebook - * @author Tudor Bosman (tudorb@fb.com) - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace caffe2 { - -// Copied from folly/ScopeGuard.h - -namespace detail { - -class ScopeGuardImplBase { - public: - void dismiss() noexcept { - dismissed_ = true; - } - - protected: - ScopeGuardImplBase() noexcept : dismissed_(false) {} - - static ScopeGuardImplBase makeEmptyScopeGuard() noexcept { - return ScopeGuardImplBase{}; - } - - 
template - static const T& asConst(const T& t) noexcept { - return t; - } - - bool dismissed_; -}; - -template -class ScopeGuardImpl : public ScopeGuardImplBase { - public: - explicit ScopeGuardImpl(FunctionType& fn) noexcept( - std::is_nothrow_copy_constructible::value) - : ScopeGuardImpl( - asConst(fn), - makeFailsafe(std::is_nothrow_copy_constructible{}, - &fn)) {} - - explicit ScopeGuardImpl(const FunctionType& fn) noexcept( - std::is_nothrow_copy_constructible::value) - : ScopeGuardImpl( - fn, - makeFailsafe(std::is_nothrow_copy_constructible{}, - &fn)) {} - - explicit ScopeGuardImpl(FunctionType&& fn) noexcept( - std::is_nothrow_move_constructible::value) - : ScopeGuardImpl( - std::move_if_noexcept(fn), - makeFailsafe(std::is_nothrow_move_constructible{}, - &fn)) {} - - ScopeGuardImpl(ScopeGuardImpl&& other) noexcept( - std::is_nothrow_move_constructible::value) - : function_(std::move_if_noexcept(other.function_)) { - // If the above line attempts a copy and the copy throws, other is - // left owning the cleanup action and will execute it (or not) depending - // on the value of other.dismissed_. The following lines only execute - // if the move/copy succeeded, in which case *this assumes ownership of - // the cleanup action and dismisses other. - dismissed_ = other.dismissed_; - other.dismissed_ = true; - } - - ~ScopeGuardImpl() noexcept { - if (!dismissed_) { - execute(); - } - } - - private: - static ScopeGuardImplBase makeFailsafe(std::true_type, const void*) noexcept { - return makeEmptyScopeGuard(); - } - - template - static auto makeFailsafe(std::false_type, Fn* fn) noexcept - -> ScopeGuardImpl { - return ScopeGuardImpl{std::ref(*fn)}; - } - - template - explicit ScopeGuardImpl(Fn&& fn, ScopeGuardImplBase&& failsafe) - : ScopeGuardImplBase{}, function_(std::forward(fn)) { - failsafe.dismiss(); - } - - void* operator new(std::size_t) = delete; - - void execute() noexcept { function_(); } - - FunctionType function_; -}; - -template -using ScopeGuardImplDecay = ScopeGuardImpl::type>; - -} // namespace detail - -/** - * ScopeGuard is a general implementation of the "Initialization is - * Resource Acquisition" idiom. Basically, it guarantees that a function - * is executed upon leaving the current scope unless otherwise told. - * - * The MakeGuard() function is used to create a new ScopeGuard object. - * It can be instantiated with a lambda function, a std::function, - * a functor, or a void(*)() function pointer. - * - * - * Usage example: Add a friend to memory iff it is also added to the db. - * - * void User::addFriend(User& newFriend) { - * // add the friend to memory - * friends_.push_back(&newFriend); - * - * // If the db insertion that follows fails, we should - * // remove it from memory. - * auto guard = MakeGuard([&] { friends_.pop_back(); }); - * - * // this will throw an exception upon error, which - * // makes the ScopeGuard execute UserCont::pop_back() - * // once the Guard's destructor is called. - * db_->addFriend(GetName(), newFriend.GetName()); - * - * // an exception was not thrown, so don't execute - * // the Guard. - * guard.dismiss(); - * } - * - * Examine ScopeGuardTest.cpp for some more sample usage. 
- * - * Stolen from: - * Andrei's and Petru Marginean's CUJ article: - * http://drdobbs.com/184403758 - * and the loki library: - * http://loki-lib.sourceforge.net/index.php?n=Idioms.ScopeGuardPointer - * and triendl.kj article: - * http://www.codeproject.com/KB/cpp/scope_guard.aspx - */ -template -detail::ScopeGuardImplDecay MakeGuard(F&& f) noexcept( - noexcept(detail::ScopeGuardImplDecay(static_cast(f)))) { - return detail::ScopeGuardImplDecay(static_cast(f)); -} - -} // namespaces diff --git a/caffe2/core/serialization_test.cc b/caffe2/core/serialization_test.cc deleted file mode 100644 index 902a3e01e677..000000000000 --- a/caffe2/core/serialization_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include - -#include -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/blob_serialization.h" - -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -C10_DEFINE_bool( - caffe2_test_generate_unknown_dtype_blob, - false, - "Recompute and log the serialized blob data for the " - "TensorSerialization.TestUnknownDType test"); - -using namespace caffe2; - -namespace { - -// This data was computed by serializing a 10-element int32_t tensor, -// but with the data_type field set to 4567. This allows us to test the -// behavior of the code when deserializing data from a future version of the -// code that has new data types that our code does not understand. -constexpr c10::string_view kFutureDtypeBlob( - "\x0a\x09\x74\x65\x73\x74\x5f\x62\x6c\x6f\x62\x12\x06\x54\x65\x6e" - "\x73\x6f\x72\x1a\x28\x08\x0a\x08\x01\x10\xd7\x23\x22\x0a\x00\x01" - "\x02\x03\x04\x05\x06\x07\x08\x09\x3a\x09\x74\x65\x73\x74\x5f\x62" - "\x6c\x6f\x62\x42\x02\x08\x00\x5a\x04\x08\x00\x10\x0a", - 61); -// The same tensor with the data_type actually set to TensorProto_DataType_INT32 -constexpr c10::string_view kInt32DtypeBlob( - "\x0a\x09\x74\x65\x73\x74\x5f\x62\x6c\x6f\x62\x12\x06\x54\x65\x6e" - "\x73\x6f\x72\x1a\x27\x08\x0a\x08\x01\x10\x02\x22\x0a\x00\x01\x02" - "\x03\x04\x05\x06\x07\x08\x09\x3a\x09\x74\x65\x73\x74\x5f\x62\x6c" - "\x6f\x62\x42\x02\x08\x00\x5a\x04\x08\x00\x10\x0a", - 60); - -void logBlob(c10::string_view data) { - constexpr size_t kBytesPerLine = 16; - constexpr size_t kCharsPerEncodedByte = 4; - std::vector hexStr; - hexStr.resize((kBytesPerLine * kCharsPerEncodedByte) + 1); - hexStr[kBytesPerLine * kCharsPerEncodedByte] = '\0'; - size_t lineIdx = 0; - for (char c : data) { - snprintf( - hexStr.data() + (kCharsPerEncodedByte * lineIdx), - kCharsPerEncodedByte + 1, - "\\x%02x", - static_cast(c)); - ++lineIdx; - if (lineIdx >= kBytesPerLine) { - LOG(INFO) << " \"" << hexStr.data() << "\""; - lineIdx = 0; - } - } - if (lineIdx > 0) { - hexStr[lineIdx * kCharsPerEncodedByte] = '\0'; - LOG(INFO) << " \"" << hexStr.data() << "\""; - } -} - -} // namespace - -TEST(TensorSerialization, TestUnknownDType) { - // This code was used to generate the blob data listed above. 
- constexpr size_t kTestTensorSize = 10; - if (FLAGS_caffe2_test_generate_unknown_dtype_blob) { - Blob blob; - auto* blobTensor = BlobGetMutableTensor(&blob, CPU); - blobTensor->Resize(kTestTensorSize, 1); - auto *tensorData = blobTensor->mutable_data(); - for (unsigned n = 0; n < kTestTensorSize; ++n) { - tensorData[n] = n; - } - auto data = SerializeBlob(blob, "test_blob"); - LOG(INFO) << "test blob: size=" << data.size(); - logBlob(data); - } - - // Test deserializing the normal INT32 data, - // just to santity check that deserialization works - Blob i32Blob; - DeserializeBlob(std::string(kInt32DtypeBlob), &i32Blob); - const auto& tensor = BlobGetTensor(i32Blob, c10::DeviceType::CPU); - EXPECT_EQ(kTestTensorSize, tensor.numel()); - EXPECT_EQ(TypeMeta::Make(), tensor.dtype()); - const auto* tensor_data = tensor.template data(); - for (unsigned i = 0; i < kTestTensorSize; ++i) { - EXPECT_EQ(static_cast(i), tensor_data[i]); - } - - // Now test deserializing our blob with an unknown data type - Blob futureDtypeBlob; - try { - DeserializeBlob(std::string(kFutureDtypeBlob), &futureDtypeBlob); - FAIL() << "DeserializeBlob() should have failed"; - } catch (const std::exception& ex) { - EXPECT_STREQ( - "Cannot deserialize tensor: unrecognized data type", ex.what()); - } -} diff --git a/caffe2/core/stats_test.cc b/caffe2/core/stats_test.cc deleted file mode 100644 index ab61e7a2f84b..000000000000 --- a/caffe2/core/stats_test.cc +++ /dev/null @@ -1,151 +0,0 @@ -#include -#include -#include - -#include "caffe2/core/stats.h" -#include - -namespace caffe2 { -namespace { - -struct MyCaffeClass { - explicit MyCaffeClass(const std::string& name) : stats_(name) {} - - void tryRun(int) {} - - void run(int numRuns) { - try { - CAFFE_EVENT(stats_, num_runs, numRuns); - tryRun(numRuns); - CAFFE_EVENT(stats_, num_successes); - } catch (std::exception& e) { - CAFFE_EVENT(stats_, num_failures, 1, "arg_to_usdt", e.what()); - } - CAFFE_EVENT(stats_, usdt_only, 1, "arg_to_usdt"); - } - - private: - struct MyStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(MyStats); - CAFFE_EXPORTED_STAT(num_runs); - CAFFE_EXPORTED_STAT(num_successes); - CAFFE_EXPORTED_STAT(num_failures); - CAFFE_STAT(usdt_only); - } stats_; -}; - -ExportedStatMap filterMap( - const ExportedStatMap& map, - const ExportedStatMap& keys) { - ExportedStatMap filtered; - for (const auto& kv : map) { - if (keys.count(kv.first) > 0) { - filtered.insert(kv); - } - } - return filtered; -} - -#define EXPECT_SUBSET(map, sub) EXPECT_EQ(filterMap((map), (sub)), (sub)) - -TEST(StatsTest, StatsTestClass) { - MyCaffeClass a("first"); - MyCaffeClass b("second"); - for (int i = 0; i < 10; ++i) { - a.run(10); - b.run(5); - } - EXPECT_SUBSET( - ExportedStatMap({ - {"first/num_runs", 100}, - {"first/num_successes", 10}, - {"first/num_failures", 0}, - {"second/num_runs", 50}, - {"second/num_successes", 10}, - {"second/num_failures", 0}, - }), - toMap(StatRegistry::get().publish())); -} - -TEST(StatsTest, StatsTestDuration) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STAT(count); - CAFFE_AVG_EXPORTED_STAT(time_ns); - }; - TestStats stats("stats"); - CAFFE_DURATION(stats, time_ns) { - std::this_thread::sleep_for(std::chrono::microseconds(1)); - } - - ExportedStatList data; - StatRegistry::get().publish(data); - auto map = toMap(data); - auto countIt = map.find("stats/time_ns/count"); - auto sumIt = map.find("stats/time_ns/sum"); - EXPECT_TRUE(countIt != map.end() && sumIt != map.end()); - 
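  // CAFFE_DURATION on an exported average stat publishes both a "<name>/count"
  // and a "<name>/sum" entry, so a single timed block should yield a count of 1
  // and a strictly positive summed duration.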
EXPECT_EQ(countIt->second, 1); - EXPECT_GT(sumIt->second, 0); -} - -TEST(StatsTest, StatsTestSimple) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STAT(s1); - CAFFE_STAT(s2); - CAFFE_EXPORTED_STAT(s3); - }; - TestStats i1("i1"); - TestStats i2("i2"); - CAFFE_EVENT(i1, s1); - CAFFE_EVENT(i1, s2); - CAFFE_EVENT(i1, s3, 1); - CAFFE_EVENT(i1, s3, -1); - CAFFE_EVENT(i2, s3, 2); - - ExportedStatList data; - StatRegistry::get().publish(data); - EXPECT_SUBSET(toMap(data), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 2}})); - - StatRegistry reg2; - reg2.update(data); - reg2.update(data); - - EXPECT_SUBSET( - toMap(reg2.publish(true)), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 4}})); - EXPECT_SUBSET( - toMap(reg2.publish()), ExportedStatMap({{"i1/s3", 0}, {"i2/s3", 0}})); -} - -TEST(StatsTest, StatsTestStatic) { - struct TestStats { - // NOLINTNEXTLINE(modernize-pass-by-value) - CAFFE_STAT_CTOR(TestStats); - CAFFE_STATIC_STAT(cpuUsage); - CAFFE_STATIC_STAT(memUsage); - }; - TestStats i1("i1"); - TestStats i2("i2"); - CAFFE_EVENT(i1, cpuUsage, 95); - CAFFE_EVENT(i2, memUsage, 80); - - ExportedStatList data; - StatRegistry::get().publish(data); - EXPECT_SUBSET( - toMap(data), ExportedStatMap({{"i1/cpuUsage", 95}, {"i2/memUsage", 80}})); - - CAFFE_EVENT(i1, cpuUsage, 80); - CAFFE_EVENT(i1, memUsage, 50); - CAFFE_EVENT(i2, memUsage, 90); - - StatRegistry::get().publish(data); - EXPECT_SUBSET( - toMap(data), - ExportedStatMap( - {{"i1/cpuUsage", 80}, {"i1/memUsage", 50}, {"i2/memUsage", 90}})); -} -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/storage.h b/caffe2/core/storage.h deleted file mode 100644 index e9bd6ed60c0b..000000000000 --- a/caffe2/core/storage.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef CAFFE2_CORE_STORAGE_H_ -#define CAFFE2_CORE_STORAGE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "caffe2/core/allocator.h" -#include "caffe2/core/common.h" -#include "caffe2/core/context.h" -#include "caffe2/core/flags.h" -#include "caffe2/core/logging.h" -#include - -#include -#include -#include -#include -#include -#include - -namespace caffe2 { - -using StorageImpl = at::StorageImpl; -using Storage = at::Storage; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_STORAGE_H_ diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h deleted file mode 100644 index 1171605b9f77..000000000000 --- a/caffe2/core/tensor.h +++ /dev/null @@ -1,674 +0,0 @@ -#ifndef CAFFE2_CORE_TENSOR_H_ -#define CAFFE2_CORE_TENSOR_H_ - -#include -#include "caffe2/core/storage.h" - -#include -#include -#include -#include -#include -#include - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") -#endif - -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -namespace at { -class Tensor; -}; -#endif -namespace caffe2 { - -using at::UndefinedTensorImpl; - -/** - * @brief Tensor class holds a shared pointer to the implementation TensorImpl, - * redirects API calls to TensorImpl; - * Copying of Tensor results in sharing the same underlying implementation - * object - * - * NB: See TensorImpl for documentation on these methods. 
- */ -class TORCH_API Tensor final { - private: - enum Unsafe { IDoWantAliasing }; - Tensor(const Tensor& other, Unsafe _) : impl_(other.getIntrusivePtr()) {} - - protected: - using TensorImplPtr = c10::intrusive_ptr; - TensorImplPtr impl_; - - void enforce_invariants(); - - public: - Tensor() : impl_() {} - - Tensor(const Tensor& t) : impl_(t.impl_) {} - Tensor& operator=(const Tensor& t) { - impl_ = t.impl_; - return *this; - } - - Tensor(Tensor&&) = default; - Tensor& operator=(Tensor&&) = default; - - operator bool() const { - return impl_.defined(); - } - - TensorImpl* unsafeGetTensorImpl() const { - return impl_.get(); - } - - TensorImpl* unsafeReleaseTensorImpl() { - return impl_.release(); - } - - Tensor UnsafeSharedInstance() const { - return Tensor(*this, IDoWantAliasing); - } - - /** - * @brief Creates a tensor of the given device type. - * - * Note that the actual data allocation is not going to be carried out until - * you resize the tensor and then call mutable_data(). - */ - explicit Tensor(at::Device device) - : impl_(c10::make_intrusive( - Storage::create_legacy(device), - c10::computeDispatchKey(c10::nullopt, at::kStrided, device), - TypeMeta())) {} - - /** - * @brief Creates a tensor of the given dimension. - * - * Note that the actual data allocation is not going to be carried out until - * the first time mutable_data() is called. - */ - explicit Tensor(at::IntArrayRef dims, DeviceType type) : Tensor(type) { - // TODO: here, we create a Storage - // and immediately discard it in Resize() since - // reset_tensor will be true and FreeMemory will be called, - // we might want to avoid creating Storage twice? - Resize(dims); - } - - // we want to preserve index information - explicit Tensor(at::IntArrayRef dims, at::Device device) : Tensor(device) { - Resize(dims); - } - - // TODO: remove? - explicit Tensor(const vector& dims, DeviceType type) : Tensor(type) { - Resize(dims); - } - - /** - * @brief: Create a Tensor of at::DeviceType `type` and initialize it with - * src Tensor - */ - Tensor(const Tensor& src, DeviceType type) : Tensor(type) { - CopyFrom(src); - } - - /** - * @brief Mutual conversion with at::Tensor - * - * The tensor will share the same instance (data, strides, sizes, etc) but - * a different subset of APIs would be available - */ -#if defined(EXPOSE_C2_OPS) || \ - !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - explicit Tensor(at::Tensor tensor); - - explicit operator at::Tensor() const&; - - explicit operator at::Tensor() &&; -#endif - - bool is_same(const Tensor& other) const noexcept { - return impl_ == other.impl_; - } - - Tensor Clone() const { - Tensor x(GetDevice()); - x.CopyFrom(*this); - return x; - } - - /** - * Clone self as a Tensor that share the same Storage, - * that is, both Tensors are views on the same Storage. - * If we change the sizes or strides of one Tensor, it - * does not affect the other Tensor that it shares Storage - * with. - * A similar yet different usage is `Tensor x = y;`, this - * will make x and y pointing to the same Tensor and resizing - * one of them will resize the other as well. 
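   * Illustrative usage (names are placeholders, not from the original header):
   *   Tensor alias = t.Alias(); // shares Storage with t; resizing `alias`
   *                             // leaves t's sizes and strides untouched
   *   Tensor same = t;          // by contrast, `same` and `t` are the same Tensor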
- * - * TODO: Deduplicate this with THTensor_(newWithTensor) - * (exposed in ATen as at::alias but not otherwise available) - */ - Tensor Alias() const { - Tensor x(sizes(), GetDevice()); - if (!dtype_initialized()) { - C10_LOG_EVERY_MS(WARNING, 1000) - << "Cloning a tensor that don't have a data type (did you call mutable_data on the tensor?)"; - } - AT_ASSERTM( - storage_initialized(), - "Cloning a tensor that has no content and has size > 0"); - // set_storage already sets data_type_ of TensorImpl - x.impl_->set_storage_and_dtype(storage(), impl_->dtype()); - x.impl_->set_storage_offset(impl_->storage_offset()); - x.impl_->set_sizes_and_strides(sizes(), strides()); - return x; - } - - DeviceType GetDeviceType() const { - return impl_->device_type(); - } - - at::Device GetDevice() const { - return impl_.get()->device(); - } - - /** - * @brief Copies the data from a source tensor, with a context provided to - * carry out the underlying memcpy operation. This method respects - * caffe2_keep_on_shrink. - * - * After CopyFrom, this function guarantees that the destination tensor will - * have the same initialization state and dtype as src. This function - * preserves the DeviceType of the source tensor (so, e.g., if you allocate - * a tensor on CPU and then CopyFrom a CUDA tensor, that will to a - * CUDA-to-CPU transfer). - * - * 'async' parameter triggers async copy for CUDA tensors - */ - void CopyFrom(const Tensor& src, bool async = false); - - /** - * @brief Extend the outer-most dimension of this tensor - * to dimension of `num`. - */ - void ExtendTo(int64_t num, float growthPct) const { - CAFFE_ENFORCE_GE_WITH_CALLER(impl_->dim(), 1); - CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0); - Extend(num - impl_->size(0), growthPct); - } - - void Extend(int64_t num, float growthPct) const { - impl_.get()->Extend(num, growthPct); - } - - /** - * @brief Shrinks the outer-most dimension to given size, keeping the data. - * - * This method guarantees that no re-allocations are carried out, which means - * that the extra capacity after the end of the shrunk tensor is maintained. - * Notably, this function does NOT respect caffe2_keep_on_shrink. - */ - void ShrinkTo(int64_t outer_dim) const { - CAFFE_ENFORCE_WITH_CALLER( - impl_->is_contiguous(), - "Right now ShrinkTo is only supported on contiguous Tensor."); - CAFFE_ENFORCE_WITH_CALLER(impl_->dim() >= 1, "Tensor must be at least 1D"); - CAFFE_ENFORCE_WITH_CALLER( - outer_dim <= impl_->size(0), - "New outer dimension must be smaller than current."); - CAFFE_ENFORCE( - impl_->storage().unique(), - "Can't call ShrinkTo on shared storage, please call Resize instead."); - impl_.get()->set_size(0, outer_dim); - } - - template - void ReserveSpace(const T& outer_dim) const { - impl_.get()->ReserveSpace(outer_dim); - } - - template - void Resize(Ts... dim_source) const { - impl_.get()->Resize(dim_source...); - } - - template - void Resize(const std::vector& dim_source) const { - impl_.get()->Resize(ArrayRef(dim_source)); - } - - /** - * Resize the tensor like the source tensor. Note that this is just a - * sugar wrapper that essentially calls Resize(src_tensor.dims()). - * This method respects caffe2_keep_on_shrink. 
- */ - inline void ResizeLike(const Tensor& src_tensor) const { - CAFFE_ENFORCE_WITH_CALLER( - src_tensor.is_contiguous(), - "Right now ResizeLike is only supported for contiguous Tensor."); - if (impl_ != src_tensor.impl_) { - impl_.get()->Resize(src_tensor.sizes()); - } - } - - inline void Reshape(const vector& dims) const { - impl_.get()->Reshape(dims); - } - - inline void Reshape(const vector& dims) const { - impl_.get()->Reshape(ToVectorint64_t(dims)); - } - - inline void FreeMemory() const { - impl_.get()->FreeMemory(); - } - - /** - * A utility function to print the debug string for the tensor. Note that this - * is very slow since it involves quite some string operations, so do not use - * it in your performance-critical code. - */ - string DebugString() const { - std::stringstream ss; - ss << "A Tensor of item size " << impl_->dtype().itemsize() << " and type " - << impl_->dtype().name() << " and dimension ("; - for (int d : impl_->sizes()) { - ss << d << ","; - } - ss << ")."; - return ss.str(); - } - - // To be deprecated - void ShareData(const Tensor& src) const { - impl_.get()->ShareData(*src.impl_.get()); - } - - /** - * @brief Shares the data with an externally managed pointer. - * - * This is similar to ShareData() but the source is a pointer with an advanced - * deleter option. In default, no deletion takes place, and one needs to make - * sure that the external memory is deallocated only after the tensor finishes - * using it. If a Deleter object is passed in, when this tensor is reallocated - * or freed, the deleter function is going to be called. - */ - template - void ShareExternalPointer( - T* src, - size_t nbytes = 0, - MemoryDeleter d = nullptr) const { - ShareExternalPointer((void*)src, caffe2::TypeMeta::Make(), nbytes, d); - } - - template - void ShareExternalPointer(at::DataPtr&& data_ptr, size_t nbytes = 0) const { - ShareExternalPointer( - std::move(data_ptr), caffe2::TypeMeta::Make(), nbytes); - } - - void ShareExternalPointer( - void* src, - const TypeMeta data_type, - size_t nbytes = 0, - MemoryDeleter d = nullptr) const { - CAFFE_ENFORCE_WITH_CALLER( - impl_->is_contiguous(), - "Right now ShareExternalPointer is only supported for contiguous Tensor."); - CAFFE_ENFORCE_WITH_CALLER( - data_type != ScalarType::Undefined, - "To share with a raw external pointer you need to pass in an " - "initialized data_type(TypeMeta)."); - impl_.get()->ShareExternalPointer( - at::DataPtr(src, src, d, impl_->device_type()), data_type, nbytes); - } - - void ShareExternalPointer( - at::DataPtr&& data_ptr, - const TypeMeta data_type, - size_t nbytes) { - impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, nbytes); - } - - const c10::intrusive_ptr& getIntrusivePtr() - const { - return impl_; - } - - bool defined() const { - return impl_; - } - - /** - * Returns a raw void* pointer of the underlying storage. mutable_data() - * or raw_mutable_data() must have been called prior to this function call. - */ - inline void* raw_data() const { - return impl_->mutable_data(); - } - - template - inline T* data() const { - return impl_.get()->mutable_data_dtype_initialized(); - } - - inline void* raw_mutable_data(const TypeMeta meta) const { - return impl_.get()->raw_mutable_data(meta); - } - - /** - * Returns a mutable raw pointer of the underlying storage. This can only be - * used when you know for sure that the underlying storage of the tensor is - * already created via an earlier raw_mutable_data(meta) call or a - * mutable_data() call. 
- * - * If the existing data does not match the desired type, it will be deleted - * and a new storage will be created. - */ - inline void* raw_mutable_data() const { - const auto& data_type = impl_->dtype(); - CAFFE_ENFORCE_WITH_CALLER( - data_type != ScalarType::Undefined, - "Calling raw_mutable_data() without meta, but the current meta is " - "of unknown type."); - return raw_mutable_data(data_type); - } - - template - inline T* mutable_data() const { - return impl_.get()->mutable_data(); - } - - /** - * Returns the number of dimensions of the data. - */ - inline int dim() const { - return impl_->dim(); - } - - /** - * (To be deprecated) Returns the number of dimensions of the data. - */ - inline int ndim() const { - return impl_->dim(); - } - - /** - * (To be deprecated) Returns the size (i.e. the number of items) of the - * tensor. - */ - inline int64_t size() const { - return impl_->numel(); - } - - /** - * Returns the number of items of the tensor. - */ - inline int64_t numel() const { - return impl_->numel(); - } - - /** - * Return the number of bytes each item takes in the tensor. - */ - inline size_t itemsize() const { - return impl_->dtype().itemsize(); - } - - /** - * Returns the total number of bytes of the storage. - * - * This is equivalent to calling size() * itemsize(). - */ - inline size_t nbytes() const { - return impl_->numel() * itemsize(); - } - - inline at::IntArrayRef sizes() const { - return impl_.get()->sizes(); - } - - inline c10::SymIntArrayRef sym_sizes() const { - return impl_->sym_sizes(); - } - - inline c10::SymInt sym_numel() const { - return impl_->sym_numel(); - } - - inline c10::SymIntArrayRef sym_strides() const { - return impl_->sym_strides(); - } - - inline int64_t size_from_dim(int k) const { - return size_from_dim_(k, impl_->sizes()); - } - - inline int64_t size_to_dim(int k) const { - return size_to_dim_(k, impl_->sizes()); - } - - inline int64_t size_between_dim(int k, int l) const { - return size_between_dim_(k, l, impl_->sizes()); - } - - /** - * Returns the 'canonical' version of a (usually) user-specified axis, - * allowing for negative indexing (e.g., -1 for the last axis). - * - * @param axis_index the axis index. - * If 0 <= index < dim(), return index. - * If -ndim <= index <= -1, return (dim() - (-index)), - * e.g., the last axis index (dim() - 1) if index == -1, - * the second to last if index == -2, etc. - * Dies on out of range index. - */ - inline int canonical_axis_index(int axis_index) const { - return canonical_axis_index_(axis_index, impl_->dim()); - } - - inline int64_t stride(int64_t dim) const { - return impl_.get()->stride(dim); - } - - inline at::IntArrayRef strides() const { - return impl_.get()->strides(); - } - - inline bool is_contiguous( - at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { - return impl_.get()->is_contiguous(memory_format); - } - - /** - * Checks if the tensor content is of the given data type. - */ - template - inline bool IsType() const { - return impl_->dtype().Match(); - } - - /** - * Returns the TypeMeta object associated with the current data type. - */ - inline const TypeMeta dtype() const { - return impl_->dtype(); - } - - /** - * (To be deprecated) Returns the TypeMeta object associated with the current - * data type. - */ - inline const TypeMeta meta() const { - return impl_->dtype(); - } - - /** - * Returns the i-th dimension of the tensor in int. - * - * This function returns an int value instead of int64_t, which depending on - * the typedef could be int64. 
If you want int64 dim values, make sure you - * call dim() instead. - */ - inline int dim32(const int i) const { -#ifndef NDEBUG - CAFFE_ENFORCE_LT_WITH_CALLER( - i, static_cast(impl_->dim()), "Exceeding ndim limit"); - CAFFE_ENFORCE_GE_WITH_CALLER(i, 0, "Cannot have negative dimension index"); -#endif - // Avoid TensorImpl::size() because it is a virtual call that - // supports out-of-range indexing like Python. - auto s = impl_->sizes()[i]; - CAFFE_ENFORCE_LT_WITH_CALLER(s, std::numeric_limits::max()); - return static_cast(s); - } - - inline int64_t size(const int i) const { - return impl_->size(i); - } - - // To be deprecated - inline int64_t dim(const int i) const { - return impl_->size(i); - } - - const Storage& storage() { - return impl_->storage(); - } - - const Storage& storage() const { - return impl_->storage(); - } - - bool storage_initialized() const { - return impl_->storage_initialized(); - } - - bool dtype_initialized() const { - return impl_->dtype_initialized(); - } -}; - -/** - * Reinitialize a Tensor to given dims and options if necessary, note that - * this will not do anything if the - * Tensor already has correct size and data type - */ -TORCH_API void -ReinitializeTensor(Tensor* t, at::IntArrayRef dims, at::TensorOptions options); - -TORCH_API void ReinitializeAndCopyFrom( - Tensor* t, - at::TensorOptions options, - const Tensor& src, - bool async = false); - -using TensorCPU = Tensor; - -constexpr int k_limit_default_ = 1000; - -// TODO: the following logic can be merged into regular Tensor class methods -// after MKLMemory starts to implement Tensor interface - -// Type call registry -typedef TypeMeta (*TypeCall)(const void*); -TypeCall GetTypeCallFunction(TypeIdentifier id); -void RegisterTypeCallFunction(TypeIdentifier id, TypeCall c); - -// Shape call registry -typedef vector ( - *TensorInfoCall)(const void*, size_t* capacity, DeviceOption* device); -TensorInfoCall GetTensorInfoFunction(TypeIdentifier id); -void RegisterTensorInfoFunction(TypeIdentifier id, TensorInfoCall c); - -// resize helper function -void TensorVectorResize( - std::vector& tensors, - int size, - DeviceType type); - -// Tensor factory function -TORCH_API Tensor empty(at::IntArrayRef dims, at::TensorOptions options); - -/** - * @brief Creates a CPU tensor, and fills its contents with the given values. - * Values are copied in - */ -// TODO: can be unified with at::from_blob when Tensor is merged and string -// types are supported -template -Tensor TensorCPUFromValues(at::IntArrayRef dims, at::ArrayRef values) { - Tensor r = empty(dims, at::device(CPU).dtype()); - CAFFE_ENFORCE_EQ(values.size(), r.numel()); - CPUContext context; - context.CopyItemsFromCPU( - r.dtype(), values.size(), values.data(), r.mutable_data()); - return r; -} - -vector -GetTensorInfo(const void* c, size_t* capacity, DeviceOption* device); - -class TORCH_API TensorPrinter { - public: - explicit TensorPrinter( - const std::string& tensor_name = "", - const std::string& file_name = "", - int limit = k_limit_default_); - ~TensorPrinter(); - - template - void Print(const Tensor& tensor); - - void PrintMeta(const Tensor& tensor); - - string MetaStr(const Tensor& tensor); - - private: - bool to_file_; - int limit_; - std::unique_ptr log_file_; - std::string tensor_name_; -}; - -template -void TensorPrinter::Print(const Tensor& tensor) { - std::stringstream values_stream; - // One most likely doesn't want to print int64-number of items for visual - // inspection, so we cast down to int here. 
- int total_count = static_cast(std::min(tensor.numel(), int64_t(limit_))); - - const T* tensor_data = tensor.template data(); - for (int i = 0; i < total_count - 1; ++i) { - values_stream << tensor_data[i] << ","; - } - if (total_count) { - // We do not add a comma after the last item. - values_stream << tensor_data[total_count - 1]; - } - - if (to_file_) { - (*log_file_) << MetaStr(tensor) << values_stream.str() << std::endl; - } else { - // Log to console. - LOG(INFO) << MetaStr(tensor) << values_stream.str(); - } -} - -CAFFE_DECLARE_KNOWN_TYPE(Tensor, Caffe2Tensor) -} // namespace caffe2 - -C10_CLANG_DIAGNOSTIC_POP() - -namespace c10 { -template <> -struct ExclusivelyOwnedTraits : public c10::ExclusivelyOwnedTensorTraits {}; -} // namespace c10 -#endif // CAFFE2_CORE_TENSOR_H_ diff --git a/caffe2/core/tensor_int8.h b/caffe2/core/tensor_int8.h deleted file mode 100644 index b95b7b8d10e5..000000000000 --- a/caffe2/core/tensor_int8.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef CAFFE2_TENSOR_INT8_H_ -#define CAFFE2_TENSOR_INT8_H_ - -#include "caffe2/core/context.h" -#include "caffe2/core/tensor.h" -#include "caffe2/proto/caffe2_pb.h" - -namespace caffe2 { -namespace int8 { - -struct Int8TensorCPU { - float scale{1.0}; - int32_t zero_point{0}; - // Generally stores uint8_t data, but sometimes int32_t (e.g. bias - // parameters). - Tensor t{CPU}; -}; -} // namespace int8 -} // namespace caffe2 - -#endif // CAFFE2_TENSOR_INT8_H_ diff --git a/caffe2/core/timer_test.cc b/caffe2/core/timer_test.cc deleted file mode 100644 index 8ffb2f21af03..000000000000 --- a/caffe2/core/timer_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include -#include - -#include "caffe2/core/timer.h" -#include - -namespace caffe2 { -namespace { - -TEST(TimerTest, Test) { - Timer timer; - - // A timer auto-starts when it is constructed. - std::this_thread::sleep_for(std::chrono::microseconds(1)); - EXPECT_GT(timer.NanoSeconds(), 0); - - // Sleep for a while, and get the time. - timer.Start(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - float ns = timer.NanoSeconds(); - float us = timer.MicroSeconds(); - float ms = timer.MilliSeconds(); - - // Time should be at least accurate +- 10%. (30% on Windows) -#ifndef _WIN32 - EXPECT_NEAR(ns, 100000000, 10000000); - EXPECT_NEAR(us, 100000, 10000); - EXPECT_NEAR(ms, 100, 10); -#else - EXPECT_NEAR(ns, 100000000, 30000000); - EXPECT_NEAR(us, 100000, 30000); - EXPECT_NEAR(ms, 100, 30); -#endif - - // Test restarting the clock. 
- timer.Start(); - EXPECT_LT(timer.MicroSeconds(), 1000); -} - -TEST(TimerTest, TestLatency) { - constexpr int iter = 1000; - float latency = 0; - Timer timer; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.NanoSeconds(); - } - std::cout << "Average nanosecond latency is: " << latency / iter << std::endl; - latency = 0; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.MicroSeconds(); - } - std::cout << "Average microsecond latency is: " << latency / iter << std::endl; - latency = 0; - for (int i = 0; i < iter; ++i) { - timer.Start(); - latency += timer.MilliSeconds(); - } - std::cout << "Average millisecond latency is: " << latency / iter << std::endl; -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/core/transform_test.cc b/caffe2/core/transform_test.cc deleted file mode 100644 index 0dc6ba92c7f9..000000000000 --- a/caffe2/core/transform_test.cc +++ /dev/null @@ -1,460 +0,0 @@ -#include -#include "caffe2/core/net.h" -#include "caffe2/core/operator.h" -#include "caffe2/core/transform.h" - -namespace caffe2 { - -namespace { - -using transform::Graph; - -static std::atomic counter; - -class TransformDummyOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - counter.fetch_add(1); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformDummyOp1, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp1) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(TransformDummyOp2, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp2) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -REGISTER_CPU_OPERATOR(TransformDummyOp3, TransformDummyOp); - -OPERATOR_SCHEMA(TransformDummyOp3) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -/** - * This TransformDummy transform will find all subgraphs of shape - * (TransformDummyOp1 -> TransformDummyOp2) and replaces them with - * (TransformDummyOp3). Simple unit test. - */ -class DummyTransform : public Transform { - public: - // Finds all patterns of the form (TransformDummyOp1 -> TransformDummyOp2) - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() >= pattern_chain.size()) { - return false; - } - // which index are we trying to append the new node to? 
- auto pattern_idx = subgraph.size(); - // type doesn't match - if (g.node(idx).op.type() != pattern_chain[pattern_idx]) { - return false; - } - // not that head, and doesn't have exactly 1 parent - if (pattern_idx > 0 && g.node(idx).parents.size() != 1) { - return false; - } - // not that tail, and doesn't have exactly 1 child - if (pattern_idx < pattern_chain.size() - 1 && - g.node(idx).children.size() != 1) { - return false; - } - - return true; - } - - // Checks if the subgraph matched is (TransformDummyOp1 -> TransformDummyOp2) - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp2") { - return true; - } - } - return false; - } - - // Replaces a match of (TransformDummyOp1 -> TransformDummyOp2) with - // (TransformDummyOp3) - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - auto& g = *g_ptr; - OperatorDef new_op; - new_op.set_type("TransformDummyOp3"); - int new_idx = g.size(); - - std::map> new_op_children; - std::map> new_op_parents; - - // for each node parent in the head of the match, connect it to our new node - for (const auto& edge : g.node(match[0]).parents) { - int parent = edge.first; - for (const auto& blob : edge.second) { - g.node(parent).children[new_idx].push_back(blob); - new_op_parents[parent].push_back(blob); - } - } - for (const string& blob : g.node(match[0]).op.input()) { - new_op.add_input(blob); - } - - // for each child in the tail of the match, connect it to our new node - for (const auto& edge : g.node(match[1]).children) { - int child = edge.first; - for (const auto& blob : edge.second) { - g.node(child).parents[new_idx].push_back(blob); - new_op_children[child].push_back(blob); - } - } - for (const string& blob : g.node(match[1]).op.output()) { - new_op.add_output(blob); - } - - g.DeactivateSubgraph(match); - - g.push_node(transform::Node(new_op, true, new_op_parents, new_op_children)); - return true; - } - - private: - const std::vector pattern_chain = {"TransformDummyOp1", - "TransformDummyOp2"}; -}; - -REGISTER_TRANSFORM(TransformDummySwap, DummyTransform) - -TEST(TransformTest, TestPatternMatch) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - auto t = CreateTransform("TransformDummySwap"); - Graph g(netdef); - auto matches = t->PatternMatch(g); - - EXPECT_EQ(matches.size(), 2); - EXPECT_EQ(matches[0][0], 0); - EXPECT_EQ(matches[0][1], 1); - EXPECT_EQ(matches[1][0], 2); - EXPECT_EQ(matches[1][1], 3); -} - -TEST(TransformTest, TestReplacePattern) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - auto t = CreateTransform("TransformDummySwap"); - Graph g(netdef); - std::vector> matches = {{0, 1}, {2, 3}}; - t->ReplacePattern(matches, &g); - - EXPECT_EQ(g.size(), 6); - EXPECT_FALSE(g.is_node_active(0)); - EXPECT_FALSE(g.is_node_active(1)); - EXPECT_FALSE(g.is_node_active(2)); - EXPECT_FALSE(g.is_node_active(3)); - EXPECT_TRUE(g.is_node_active(4)); - 
EXPECT_TRUE(g.is_node_active(5)); - - EXPECT_EQ(g.node(4).children.size(), 1); - EXPECT_EQ(g.node(4).parents.size(), 0); - EXPECT_TRUE(g.node(4).children.count(5)); - - NetDef replaced_netdef = g.GetNetDef(); - - EXPECT_EQ(replaced_netdef.op().size(), 2); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); - EXPECT_EQ(replaced_netdef.op(1).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); -} - -TEST(TransformTest, TestTransformApply) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp2", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp2", {"mid3"}, {"out"}); - - NetDef replaced_netdef = ApplyTransform("TransformDummySwap", netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 2); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(0).input(0), "in"); - EXPECT_EQ(replaced_netdef.op(1).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(1).output(0), "out"); -} - -/** - * Transform with Sorted Order matching. - * Matches two operators of type TransformDummyOp1, even if disconnected. - * These operators will be given in execution order, - * but doesn't need connectivity. - * Changes them to TransformDummyOp2. - */ -class SortedDummyTransform : public Transform { - public: - SortedDummyTransform() { - SetPatternMatchType(SORTED_WRT_EXECUTION_ORDER); - } - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (g.node(idx).op.type() != "TransformDummyOp1") { - return false; - } - return true; - } - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp1") { - return true; - } - } - return false; - } - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - for (const auto& x : match) { - g_ptr->node(x).op.set_type("TransformDummyOp2"); - } - return true; - } -}; - -REGISTER_TRANSFORM(SortedTransformDummySwap, SortedDummyTransform) - -TEST(TransformTest, TestPatternMatchTypeSortedOrder) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp3", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp3", {"mid3"}, {"out"}); - - auto t = CreateTransform("SortedTransformDummySwap"); - NetDef replaced_netdef = t->ApplyTo(netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 4); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp2"); - EXPECT_EQ(replaced_netdef.op(2).type(), "TransformDummyOp2"); -} - -/** - * General subgraph transform. - * Matches a TransformDummyOp1, and a TransformDummyOp2. - * Order doesn't matter. Connectedness doesn't matter. - * Turns them into TransformDummyOp3. 
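// A minimal sketch, mirroring the tests above, of how a registered transform is
// looked up by name and applied to a NetDef. AddOp is the same test helper used
// throughout this file; nothing here is new API.
NetDef net;
AddOp(&net, "TransformDummyOp1", {"in"}, {"mid"});
AddOp(&net, "TransformDummyOp2", {"mid"}, {"out"});

// One-shot application by registered name...
NetDef fused = ApplyTransform("TransformDummySwap", net);

// ...or the same three-stage pipeline driven explicitly.
auto t = CreateTransform("TransformDummySwap");
transform::Graph g(net);
auto matches = t->PatternMatch(g);   // PatternRule + ValidatorRule
t->ReplacePattern(matches, &g);      // ReplaceRule
NetDef fused_too = g.GetNetDef();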
- */ -class GeneralDummyTransform : public Transform { - public: - GeneralDummyTransform() { - SetPatternMatchType(GENERAL); - } - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() == 0 && g.node(idx).op.type() == "TransformDummyOp1") { - return true; - } - if (subgraph.size() == 1 && g.node(idx).op.type() == "TransformDummyOp2") { - return true; - } - return false; - } - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 2) { - if (g.node(subgraph[0]).op.type() == "TransformDummyOp1" && - g.node(subgraph[1]).op.type() == "TransformDummyOp2") { - return true; - } - } - return false; - } - bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - for (const auto& x : match) { - g_ptr->node(x).op.set_type("TransformDummyOp3"); - } - return true; - } -}; - -REGISTER_TRANSFORM(GeneralTransformDummySwap, GeneralDummyTransform) - -TEST(TransformTest, TestPatternMatchTypeGeneral) { - Workspace ws; - ws.CreateBlob("in"); - NetDef netdef; - - AddOp(&netdef, "TransformDummyOp2", {"in"}, {"mid1"}); - AddOp(&netdef, "TransformDummyOp3", {"mid1"}, {"mid2"}); - AddOp(&netdef, "TransformDummyOp1", {"mid2"}, {"mid3"}); - AddOp(&netdef, "TransformDummyOp3", {"mid3"}, {"out"}); - - auto t = CreateTransform("GeneralTransformDummySwap"); - NetDef replaced_netdef = t->ApplyTo(netdef); - - EXPECT_EQ(replaced_netdef.op().size(), 4); - EXPECT_EQ(replaced_netdef.op(0).type(), "TransformDummyOp3"); - EXPECT_EQ(replaced_netdef.op(2).type(), "TransformDummyOp3"); -} - -class TransformSleepFastOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - std::this_thread::sleep_for(std::chrono::milliseconds(30)); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformSleepFastOp, TransformSleepFastOp); - -OPERATOR_SCHEMA(TransformSleepFastOp) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -class TransformSleepSlowOp final : public OperatorBase { - public: - using OperatorBase::OperatorBase; - bool Run(int /* unused */) override { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - return true; - } -}; - -REGISTER_CPU_OPERATOR(TransformSleepSlowOp, TransformSleepSlowOp); - -OPERATOR_SCHEMA(TransformSleepSlowOp) - .NumInputs(0, INT_MAX) - .NumOutputs(0, INT_MAX) - .AllowInplace({{0, 0}, {1, 1}}); - -/** - * This TransformDummy transform will find all operators of type old_type, - * and replace them with type new_type. - */ -class TypeSwapTransform : public Transform { - public: - // Determine the actual strings through inheriting from derived type. - // NOLINTNEXTLINE(modernize-pass-by-value) - explicit TypeSwapTransform(string old_type, string new_type) - : old_type(old_type), new_type(new_type) {} - - // Really simple, only accept if it's a FastSleepOp, and no match so far. - bool PatternRule(const Graph& g, const std::vector& subgraph, int idx) - override { - if (subgraph.size() == 0 && g.node(idx).op.type() == old_type) { - return true; - } - return false; - } - // Checks if the subgraph matched is a FastSleepOp - bool ValidatorRule(const Graph& g, const std::vector& subgraph) - override { - if (subgraph.size() == 1) { - if (g.node(subgraph[0]).op.type() == old_type) { - return true; - } - } - return false; - } - // Replaces op of original type to new type. 
- bool ReplaceRule(const std::vector& match, Graph* g_ptr) override { - CHECK(g_ptr); - auto& g = *g_ptr; - g.node(match[0]).op.set_type(new_type); - return true; - } - - private: - string old_type; - string new_type; -}; - -class FastToSlowTransform : public TypeSwapTransform { - public: - explicit FastToSlowTransform() - : TypeSwapTransform("TransformSleepFastOp", "TransformSleepSlowOp") {} -}; - -REGISTER_TRANSFORM(FastToSlow, FastToSlowTransform); - -class SlowToFastTransform : public TypeSwapTransform { - public: - explicit SlowToFastTransform() - : TypeSwapTransform("TransformSleepSlowOp", "TransformSleepFastOp") {} -}; - -REGISTER_TRANSFORM(SlowToFast, SlowToFastTransform); - -TEST(TransformTest, TestApplyTransformIfFasterIsFaster) { - NetDef init_netdef; - AddOp(&init_netdef, "ConstantFill", {}, {"in"}); - - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid"}); - AddOp(&netdef, "TransformSleepSlowOp", {"mid"}, {"out"}); - netdef.add_external_input("in"); // This is important for this function. - - // Make sure the transform would work normally. - auto transformed_net = ApplyTransform("SlowToFast", netdef); - EXPECT_EQ(transformed_net.op(1).type(), "TransformSleepFastOp"); - - // Should be still transform normally. - auto mystery_net = - ApplyTransformIfFaster("SlowToFast", netdef, init_netdef, 5, 10, 1.01); - EXPECT_EQ(mystery_net.op(1).type(), "TransformSleepFastOp"); -} - -TEST(TransformTest, TestApplyTransformIfFasterButSlower) { - NetDef init_netdef; - AddOp(&init_netdef, "ConstantFill", {}, {"in"}); - - NetDef netdef; - AddOp(&netdef, "TransformDummyOp1", {"in"}, {"mid"}); - AddOp(&netdef, "TransformSleepFastOp", {"mid"}, {"out"}); - netdef.add_external_input("in"); // This is important for this function. - - // Make sure the transform would work normally. - auto transformed_net = ApplyTransform("FastToSlow", netdef); - EXPECT_EQ(transformed_net.op(1).type(), "TransformSleepSlowOp"); - - // Should not actually change! 
- auto mystery_net = - ApplyTransformIfFaster("FastToSlow", netdef, init_netdef, 5, 10, 1.01); - EXPECT_EQ(mystery_net.op(1).type(), "TransformSleepFastOp"); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/core/types.cc b/caffe2/core/types.cc deleted file mode 100644 index dfba94ad06ae..000000000000 --- a/caffe2/core/types.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "caffe2/core/types.h" -#include - -#include -#include -#include - -namespace caffe2 { - -TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { - static_assert( - sizeof(int) == 4, "int in this compiler does not equal to 4 bytes."); - - // Can't use a switch because `meta_id` is not an integer type - const auto meta_id = meta.id(); - if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_FLOAT; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT32; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_STRING; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_BOOL; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT64; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_FLOAT16; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_DOUBLE; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_UINT8; - } else if (meta_id == TypeMeta::Id()) { - return TensorProto_DataType_INT32; - } else { - return TensorProto_DataType_UNDEFINED; - } -} - -const TypeMeta DataTypeToTypeMeta(const TensorProto_DataType& dt) { - switch (dt) { - case TensorProto_DataType_FLOAT: - return TypeMeta::Make(); - case TensorProto_DataType_INT32: - return TypeMeta::Make(); - case TensorProto_DataType_BYTE: - return TypeMeta::Make(); - case TensorProto_DataType_STRING: - return TypeMeta::Make(); - case TensorProto_DataType_BOOL: - return TypeMeta::Make(); - case TensorProto_DataType_UINT8: - return TypeMeta::Make(); - case TensorProto_DataType_INT8: - return TypeMeta::Make(); - case TensorProto_DataType_UINT16: - return TypeMeta::Make(); - case TensorProto_DataType_INT16: - return TypeMeta::Make(); - case TensorProto_DataType_INT64: - return TypeMeta::Make(); - case TensorProto_DataType_FLOAT16: - return TypeMeta::Make(); - case TensorProto_DataType_DOUBLE: - return TypeMeta::Make(); - default: - throw std::runtime_error("Unknown data type."); - }; -} - -} // namespace caffe2 diff --git a/caffe2/core/types.h b/caffe2/core/types.h deleted file mode 100644 index f83a58910e66..000000000000 --- a/caffe2/core/types.h +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef CAFFE2_CORE_TYPES_H_ -#define CAFFE2_CORE_TYPES_H_ - -#include -#include -#include - -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" -#include -#include "caffe2/proto/caffe2_pb.h" -#include - -namespace caffe2 { - -// Storage orders that are often used in the image applications. 
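// A minimal sketch of the TypeMeta <-> TensorProto::DataType round trip
// implemented by the two functions above, assuming the usual
// TypeMeta::Make<float>() spelling for the FLOAT case.
TypeMeta meta = TypeMeta::Make<float>();
TensorProto::DataType dt = TypeMetaToDataType(meta);  // TensorProto_DataType_FLOAT
TypeMeta back = DataTypeToTypeMeta(dt);               // float again
CHECK(back == meta);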
-enum StorageOrder { - UNKNOWN = 0, - NHWC = 1, - NCHW = 2, -}; - -inline StorageOrder StringToStorageOrder(const string& str) { - if (str == "NHWC" || str == "nhwc") { - return StorageOrder::NHWC; - } else if (str == "NCHW" || str == "nchw") { - return StorageOrder::NCHW; - } else { - LOG(ERROR) << "Unknown storage order string: " << str; - return StorageOrder::UNKNOWN; - } -} - -inline int32_t GetDimFromOrderString(const std::string& str) { - auto order = StringToStorageOrder(str); - switch (order) { - case StorageOrder::NHWC: - return 3; - case StorageOrder::NCHW: - return 1; - default: - CAFFE_THROW("Unsupported storage order: ", str); - return -1; - } -} - -inline constexpr char NameScopeSeparator() { return '/'; } - -// From TypeMeta to caffe2::DataType protobuffer enum. -TORCH_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta); - -// From caffe2::DataType protobuffer enum to TypeMeta -TORCH_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt); - -} // namespace caffe2 - -/////////////////////////////////////////////////////////////////////////////// -// at::Half is defined in c10/util/Half.h. Currently half float operators are -// mainly on CUDA gpus. -// The reason we do not directly use the cuda __half data type is because that -// requires compilation with nvcc. The float16 data type should be compatible -// with the cuda __half data type, but will allow us to refer to the data type -// without the need of cuda. -static_assert(sizeof(unsigned short) == 2, - "Short on this platform is not 16 bit."); -namespace caffe2 { -// Helpers to avoid using typeinfo with -rtti -template -inline bool fp16_type(); - -template <> -inline bool fp16_type() { - return true; -} - -template -inline bool fp16_type() { - return false; -} - -} // namespace caffe2 - -#endif // CAFFE2_CORE_TYPES_H_ diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h deleted file mode 100644 index 04fa86fe2527..000000000000 --- a/caffe2/core/workspace.h +++ /dev/null @@ -1,342 +0,0 @@ -#ifndef CAFFE2_CORE_WORKSPACE_H_ -#define CAFFE2_CORE_WORKSPACE_H_ - -#include "caffe2/core/common.h" -#include "caffe2/core/observer.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "c10/util/Registry.h" -#include "caffe2/core/blob.h" -#include "caffe2/core/net.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/signal_handler.h" -#include "caffe2/utils/threadpool/ThreadPool.h" - -C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit); - -namespace caffe2 { - -class NetBase; - -struct TORCH_API StopOnSignal { - StopOnSignal() - : handler_(std::make_shared( - SignalHandler::Action::STOP, - SignalHandler::Action::STOP)) {} - - StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {} - - bool operator()(int /*iter*/) { - return handler_->CheckForSignals() != SignalHandler::Action::STOP; - } - - std::shared_ptr handler_; -}; - -/** - * Workspace is a class that holds all the related objects created during - * runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of - * all these objects and deals with the scaffolding logistics. - */ -class TORCH_API Workspace { - public: - typedef std::function ShouldContinue; - /** - * Initializes an empty workspace. - */ - Workspace() : Workspace(".", nullptr) {} - - /** - * Initializes an empty workspace with the given root folder. 
- * - * For any operators that are going to interface with the file system, such - * as load operators, they will write things under this root folder given - * by the workspace. - */ - explicit Workspace(const string& root_folder) - : Workspace(root_folder, nullptr) {} - - /** - * Initializes a workspace with a shared workspace. - * - * When we access a Blob, we will first try to access the blob that exists - * in the local workspace, and if not, access the blob that exists in the - * shared workspace. The caller keeps the ownership of the shared workspace - * and is responsible for making sure that its lifetime is longer than the - * created workspace. - */ - explicit Workspace(const Workspace* shared) : Workspace(".", shared) {} - - /** - * Initializes workspace with parent workspace, blob name remapping - * (new name -> parent blob name), no other blobs are inherited from - * parent workspace - */ - Workspace( - const Workspace* shared, - const std::unordered_map& forwarded_blobs) - : Workspace(".", nullptr) { - CAFFE_ENFORCE(shared, "Parent workspace must be specified"); - for (const auto& forwarded : forwarded_blobs) { - CAFFE_ENFORCE( - shared->HasBlob(forwarded.second), - "Invalid parent workspace blob: ", - forwarded.second); - forwarded_blobs_[forwarded.first] = - std::make_pair(shared, forwarded.second); - } - } - - /** - * Initializes a workspace with a root folder and a shared workspace. - */ - Workspace(const string& root_folder, const Workspace* shared) - : root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) { - std::lock_guard guard(bookkeeper_->wsmutex); - bookkeeper_->workspaces.insert(this); - } - - ~Workspace() { - if (FLAGS_caffe2_print_blob_sizes_at_exit) { - PrintBlobSizes(); - } - // This is why we have a bookkeeper_ shared_ptr instead of a naked static! A - // naked static makes us vulnerable to out-of-order static destructor bugs. - std::lock_guard guard(bookkeeper_->wsmutex); - bookkeeper_->workspaces.erase(this); - } - - /** - * Adds blob mappings from workspace to the blobs from parent workspace. - * Creates blobs under possibly new names that redirect read/write operations - * to the blobs in the parent workspace. - * Arguments: - * parent - pointer to parent workspace - * forwarded_blobs - map from new blob name to blob name in parent's - * workspace skip_defined_blob - if set skips blobs with names that already - * exist in the workspace, otherwise throws exception - */ - void AddBlobMapping( - const Workspace* parent, - const std::unordered_map& forwarded_blobs, - bool skip_defined_blobs = false); - - /** - * Converts previously mapped tensor blobs to local blobs, copies values from - * parent workspace blobs into new local blobs. Ignores undefined blobs. 
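// A minimal sketch of the blob-forwarding flow documented above, following the
// BlobMapping test later in this diff: the child sees the parent's "w" only
// under the remapped name "inner_w". The CopyForwardedTensors call is hedged:
// its template parameter appears to be a Context type such as CPUContext (the
// body calls Context::GetDeviceType()), and the parent blob must already hold
// a Tensor for the copy to succeed.
Workspace parent;
parent.CreateBlob("w");  // assume this blob is later filled with a Tensor
std::unordered_map<std::string, std::string> fwd = {{"inner_w", "w"}};
Workspace child(&parent, fwd);
CHECK(child.HasBlob("inner_w"));   // resolved through the parent mapping
CHECK(!child.HasBlob("w"));        // only the remapped name is visible
// child.CopyForwardedTensors<CPUContext>({"inner_w"});  // materialize a local copy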
- */ - template - void CopyForwardedTensors(const std::unordered_set& blobs) { - for (const auto& blob : blobs) { - auto it = forwarded_blobs_.find(blob); - if (it == forwarded_blobs_.end()) { - continue; - } - const auto& ws_blob = it->second; - const auto* parent_ws = ws_blob.first; - auto* from_blob = parent_ws->GetBlob(ws_blob.second); - CAFFE_ENFORCE(from_blob); - CAFFE_ENFORCE( - from_blob->template IsType(), - "Expected blob with tensor value", - ws_blob.second); - forwarded_blobs_.erase(blob); - auto* to_blob = CreateBlob(blob); - CAFFE_ENFORCE(to_blob); - const auto& from_tensor = from_blob->template Get(); - auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType()); - to_tensor->CopyFrom(from_tensor); - } - } - - /** - * Return list of blobs owned by this Workspace, not including blobs - * shared from parent workspace. - */ - vector LocalBlobs() const; - - /** - * Return a list of blob names. This may be a bit slow since it will involve - * creation of multiple temp variables. For best performance, simply use - * HasBlob() and GetBlob(). - */ - vector Blobs() const; - - /** - * Return the root folder of the workspace. - */ - const string& RootFolder() { return root_folder_; } - /** - * Checks if a blob with the given name is present in the current workspace. - */ - inline bool HasBlob(const string& name) const { - // First, check the local workspace, - // Then, check the forwarding map, then the parent workspace - if (blob_map_.count(name)) { - return true; - } - - auto it = forwarded_blobs_.find(name); - if (it != forwarded_blobs_.end()) { - const auto parent_ws = it->second.first; - const auto& parent_name = it->second.second; - return parent_ws->HasBlob(parent_name); - } - - if (shared_) { - return shared_->HasBlob(name); - } - - return false; - } - - void PrintBlobSizes(); - - /** - * Creates a blob of the given name. The pointer to the blob is returned, but - * the workspace keeps ownership of the pointer. If a blob of the given name - * already exists, the creation is skipped and the existing blob is returned. - */ - Blob* CreateBlob(const string& name); - /** - * Similar to CreateBlob(), but it creates a blob in the local workspace even - * if another blob with the same name already exists in the parent workspace - * -- in such case the new blob hides the blob in parent workspace. If a blob - * of the given name already exists in the local workspace, the creation is - * skipped and the existing blob is returned. - */ - Blob* CreateLocalBlob(const string& name); - /** - * Remove the blob of the given name. Return true if removed and false if - * not exist. - * Will NOT remove from the shared workspace. - */ - bool RemoveBlob(const string& name); - /** - * Gets the blob with the given name as a const pointer. If the blob does not - * exist, a nullptr is returned. - */ - const Blob* GetBlob(const string& name) const; - /** - * Gets the blob with the given name as a mutable pointer. If the blob does - * not exist, a nullptr is returned. - */ - Blob* GetBlob(const string& name); - - /** - * Renames a local workspace blob. If blob is not found in the local blob list - * or if the target name is already present in local or any parent blob list - * the function will throw. - */ - Blob* RenameBlob(const string& old_name, const string& new_name); - - /** - * Creates a network with the given NetDef, and returns the pointer to the - * network. If there is anything wrong during the creation of the network, a - * nullptr is returned. 
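// A minimal sketch of the blob accessors documented above; the BlobAccess test
// later in this diff exercises the same calls. The Blob::GetMutable<T>() /
// IsType<T>() template spellings are assumed from that test.
Workspace ws;
Blob* b = ws.CreateBlob("x");     // created, or returned as-is if it already exists
int* v = b->GetMutable<int>();    // the blob now holds an int
*v = 42;
CHECK(ws.HasBlob("x"));
CHECK(b->IsType<int>());
CHECK(ws.GetBlob("x") == b);      // same pointer; the workspace keeps ownership
ws.RemoveBlob("x");               // true if removed; shared blobs are never removed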
The Workspace keeps ownership of the pointer. - * - * If there is already a net created in the workspace with the given name, - * CreateNet will overwrite it if overwrite=true is specified. Otherwise, an - * exception is thrown. - */ - NetBase* CreateNet(const NetDef& net_def, bool overwrite = false); - NetBase* CreateNet( - const std::shared_ptr& net_def, - bool overwrite = false); - /** - * Gets the pointer to a created net. The workspace keeps ownership of the - * network. - */ - NetBase* GetNet(const string& net_name); - /** - * Deletes the instantiated network with the given name. - */ - void DeleteNet(const string& net_name); - /** - * Finds and runs the instantiated network with the given name. If the network - * does not exist or there are errors running the network, the function - * returns false. - */ - bool RunNet(const string& net_name); - - /** - * Returns a list of names of the currently instantiated networks. - */ - vector Nets() const { - vector names; - for (auto& entry : net_map_) { - names.push_back(entry.first); - } - return names; - } - - /** - * Runs a plan that has multiple nets and execution steps. - */ - bool RunPlan(const PlanDef& plan_def, - ShouldContinue should_continue = StopOnSignal{}); - - /* - * Returns a CPU threadpool instance for parallel execution of - * work. The threadpool is created lazily; if no operators use it, - * then no threadpool will be created. - */ - ThreadPool* GetThreadPool(); - - // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference - // between RunNet and RunNetOnce lies in the fact that RunNet allows you to - // have a persistent net object, while RunNetOnce creates a net and discards - // it on the fly - this may make things like database read and random number - // generators repeat the same thing over multiple calls. - bool RunOperatorOnce(const OperatorDef& op_def); - bool RunNetOnce(const NetDef& net_def); - - /** - * Applies a function f on each workspace that currently exists. - * - * This function is thread safe and there is no race condition between - * workspaces being passed to f in this thread and destroyed in another. 
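// A minimal sketch of ForEach, mirroring the forEachCheck helper in
// workspace_test.cc later in this diff: count the workspaces currently alive.
int alive = 0;
Workspace::ForEach([&alive](Workspace* /*ws*/) { ++alive; });
LOG(INFO) << alive << " workspace(s) currently exist";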
- */ - template - static void ForEach(F f) { - auto bk = bookkeeper(); - std::lock_guard guard(bk->wsmutex); - for (Workspace* ws : bk->workspaces) { - f(ws); - } - } - - public: - std::atomic last_failed_op_net_position{}; - - private: - struct Bookkeeper { - std::mutex wsmutex; - std::unordered_set workspaces; - }; - - static std::shared_ptr bookkeeper(); - - std::unordered_map> blob_map_; - const string root_folder_; - const Workspace* shared_; - std::unordered_map> - forwarded_blobs_; - std::unique_ptr thread_pool_; - std::mutex thread_pool_creation_mutex_; - std::shared_ptr bookkeeper_; - std::unordered_map> net_map_; - - C10_DISABLE_COPY_AND_ASSIGN(Workspace); -}; - -} // namespace caffe2 - -#endif // CAFFE2_CORE_WORKSPACE_H_ diff --git a/caffe2/core/workspace_test.cc b/caffe2/core/workspace_test.cc deleted file mode 100644 index c3f6ff0fb48f..000000000000 --- a/caffe2/core/workspace_test.cc +++ /dev/null @@ -1,149 +0,0 @@ -#include - -#include "caffe2/core/operator.h" -#include - -namespace caffe2 { - -class WorkspaceTestFoo {}; - -CAFFE_KNOWN_TYPE(WorkspaceTestFoo); - -TEST(WorkspaceTest, BlobAccess) { - Workspace ws; - - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); - - EXPECT_EQ(ws.GetBlob("newblob"), nullptr); - EXPECT_NE(nullptr, ws.CreateBlob("newblob")); - EXPECT_NE(nullptr, ws.GetBlob("newblob")); - EXPECT_TRUE(ws.HasBlob("newblob")); - - // Different names should still be not created. - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); - - // Check if the returned Blob is OK for all operations - Blob* blob = ws.GetBlob("newblob"); - int* int_unused CAFFE2_UNUSED = blob->GetMutable(); - EXPECT_TRUE(blob->IsType()); - EXPECT_FALSE(blob->IsType()); - EXPECT_NE(&blob->Get(), nullptr); - - // Re-creating the blob does not change the content as long as it already - // exists. - EXPECT_NE(nullptr, ws.CreateBlob("newblob")); - EXPECT_TRUE(blob->IsType()); - EXPECT_FALSE(blob->IsType()); - // When not null, we should only call with the right type. - EXPECT_NE(&blob->Get(), nullptr); - - // Re-creating the blob through CreateLocalBlob does not change the content - // either. 
- EXPECT_NE(nullptr, ws.CreateLocalBlob("newblob")); - EXPECT_TRUE(blob->IsType()); - EXPECT_NE(&blob->Get(), nullptr); - - // test removing blob - EXPECT_FALSE(ws.HasBlob("nonexisting")); - EXPECT_FALSE(ws.RemoveBlob("nonexisting")); - EXPECT_TRUE(ws.HasBlob("newblob")); - EXPECT_TRUE(ws.RemoveBlob("newblob")); - EXPECT_FALSE(ws.HasBlob("newblob")); -} - -TEST(WorkspaceTest, RunEmptyPlan) { - PlanDef plan_def; - Workspace ws; - EXPECT_TRUE(ws.RunPlan(plan_def)); -} - -TEST(WorkspaceTest, Sharing) { - Workspace parent; - EXPECT_FALSE(parent.HasBlob("a")); - EXPECT_TRUE(parent.CreateBlob("a")); - EXPECT_TRUE(parent.GetBlob("a")); - { - Workspace child(&parent); - // Child can access parent blobs - EXPECT_TRUE(child.HasBlob("a")); - EXPECT_TRUE(child.GetBlob("a")); - // Child can create local blobs - EXPECT_FALSE(child.HasBlob("b")); - EXPECT_FALSE(child.GetBlob("b")); - EXPECT_TRUE(child.CreateBlob("b")); - EXPECT_TRUE(child.GetBlob("b")); - // Parent cannot access child blobs - EXPECT_FALSE(parent.GetBlob("b")); - EXPECT_FALSE(parent.HasBlob("b")); - // Parent can create duplicate names - EXPECT_TRUE(parent.CreateBlob("b")); - // But child has local overrides - EXPECT_NE(child.GetBlob("b"), parent.GetBlob("b")); - // Child can create a blob that already exists in the parent - EXPECT_TRUE(child.CreateBlob("a")); - EXPECT_EQ(child.GetBlob("a"), parent.GetBlob("a")); - // Child can create a local blob for the blob already exists in the parent - EXPECT_TRUE(child.CreateLocalBlob("a")); - // But the local blob will be different from the one in parent workspace - EXPECT_NE(child.GetBlob("a"), parent.GetBlob("a")); - } -} - -TEST(WorkspaceTest, BlobMapping) { - Workspace parent; - EXPECT_FALSE(parent.HasBlob("a")); - EXPECT_TRUE(parent.CreateBlob("a")); - EXPECT_TRUE(parent.GetBlob("a")); - { - std::unordered_map forwarded_blobs; - forwarded_blobs["inner_a"] = "a"; - Workspace child(&parent, forwarded_blobs); - EXPECT_FALSE(child.HasBlob("a")); - EXPECT_TRUE(child.HasBlob("inner_a")); - EXPECT_TRUE(child.GetBlob("inner_a")); - Workspace ws; - EXPECT_TRUE(ws.CreateBlob("b")); - forwarded_blobs.clear(); - forwarded_blobs["inner_b"] = "b"; - child.AddBlobMapping(&ws, forwarded_blobs); - EXPECT_FALSE(child.HasBlob("b")); - EXPECT_TRUE(child.HasBlob("inner_b")); - EXPECT_TRUE(child.GetBlob("inner_b")); - } -} - -/** - * Checks that Workspace::ForEach(f) applies f on the specified set of - * workspaces in any order. 
- */ -static void forEachCheck(std::initializer_list workspaces) { - std::unordered_set expected(workspaces); - std::unordered_set actual; - Workspace::ForEach([&](Workspace* ws) { - auto inserted = actual.insert(ws).second; - EXPECT_TRUE(inserted); - }); - EXPECT_EQ(actual, expected); -} - -TEST(WorkspaceTest, ForEach) { - forEachCheck({}); - - { - Workspace ws1; - forEachCheck({&ws1}); - - { - Workspace ws2; - forEachCheck({&ws1, &ws2}); - } - - forEachCheck({&ws1}); - } - - forEachCheck({}); -} - -} // namespace caffe2 diff --git a/caffe2/cuda_rtc/CMakeLists.txt b/caffe2/cuda_rtc/CMakeLists.txt deleted file mode 100644 index 6bb289b79d72..000000000000 --- a/caffe2/cuda_rtc/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -if(USE_CUDA) - set(Caffe2_CUDA_RTC_GPU_SRC - "${CMAKE_CURRENT_SOURCE_DIR}/elemenntwise_rtc_gpu.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/pool_op_rtc_gpu.cc" - ) - - set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} ${Caffe2_CUDA_RTC_GPU_SRC}) - set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE) -else() - message(STATUS "CUDA RTC operators skipped due to no CUDA support") -endif() diff --git a/caffe2/cuda_rtc/common_rtc.h b/caffe2/cuda_rtc/common_rtc.h deleted file mode 100644 index 0fa6bad7a0c4..000000000000 --- a/caffe2/cuda_rtc/common_rtc.h +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef CAFFE2_CUDA_RTC_COMMON_RTC_H_ -#define CAFFE2_CUDA_RTC_COMMON_RTC_H_ - -#include -#include - -#include -#include - -#define NVRTC_CHECK(condition) \ - do { \ - nvrtcResult result = condition; \ - if (result != NVRTC_SUCCESS) { \ - LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ - << nvrtcGetErrorString(result); \ - } \ - } while (0) - -namespace caffe2 { - -template -class CudaRTCFunction { - public: - CudaRTCFunction() : module_loaded_(false) {} - ~CudaRTCFunction() { - if (module_loaded_) { - CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_)); - } - } - - // TODO: this function is nontrivial and since CudaRTCFunction uses CRTP, it - // may potentially increase the binary size. In that case, move common parts - // into a separate function. - template - void Compile(Args... args) { - string src = static_cast(this)->GetSource(args...); - string name = static_cast(this)->KernelName(args...); - VLOG(1) << "function name: " << name; - VLOG(1) << "function src:\n" << src; - // Actually do the compiling. - nvrtcProgram prog; - NVRTC_CHECK( - nvrtcCreateProgram(&prog, src.c_str(), nullptr, 0, nullptr, nullptr)); - // Compile the program. - // TODO(Yangqing): how to find the current gpu architecture instead of hard - // coding it? - const char* nvrtc_opts[] = { - "--gpu-architecture=compute_35", "--use_fast_math"}; - nvrtcResult compile_result = nvrtcCompileProgram(prog, 2, nvrtc_opts); - if (compile_result != NVRTC_SUCCESS) { - size_t log_size; - NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size)); - std::string nvrtc_log(log_size, '\0'); - NVRTC_CHECK(nvrtcGetProgramLog(prog, &nvrtc_log[0])); - LOG(FATAL) << "Compilation failure for nvrtc(" - << nvrtcGetErrorString(compile_result) << "): \n" - << nvrtc_log; - } - size_t ptx_size; - NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size)); - vector nvrtc_ptx(ptx_size); - NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data())); - NVRTC_CHECK(nvrtcDestroyProgram(&prog)); - // After compilation, load the module. 
- if (module_loaded_) { - CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_)); - } - CUDA_DRIVERAPI_ENFORCE( - cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0)); - module_loaded_ = true; - CUDA_DRIVERAPI_ENFORCE( - cuModuleGetFunction(&kernel_, module_, name.c_str())); - } - - template - void Launch( - unsigned int gx, - unsigned int gy, - unsigned int gz, - unsigned int bx, - unsigned int by, - unsigned int bz, - unsigned int shared_mem, - cudaStream_t stream, - Args... args) { - CAFFE_ENFORCE( - module_loaded_, "Cannot call Launch before a module is loaded."); - void* args_voidp[] = {&args...}; - CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0)); - } - - void LaunchEx( - unsigned int gx, - unsigned int gy, - unsigned int gz, - unsigned int bx, - unsigned int by, - unsigned int bz, - unsigned int shared_mem, - cudaStream_t stream, - void** extra) { - CAFFE_ENFORCE( - module_loaded_, "Cannot call Launch before a module is loaded."); - CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( - kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, nullptr, extra)); - } - - private: - bool module_loaded_; - CUmodule module_; - CUfunction kernel_; -}; - -// TODO: this is in no way unique and is just a hack right now. -inline std::string GetUniqueName() { - static constexpr int len = 20; - static const char alpha[] = - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - - std::stringstream ss; - ss << "_cuda_kernel_"; - for (const auto i : c10::irange(len)) { - ss << alpha[rand() % (sizeof(alpha) - 1)]; - } - return ss.str(); -} - -} // namespace caffe2 - -#endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_ diff --git a/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc b/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc deleted file mode 100644 index dfa3981731e7..000000000000 --- a/caffe2/cuda_rtc/elemenntwise_rtc_gpu.cc +++ /dev/null @@ -1,129 +0,0 @@ -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/operator.h" -#include "caffe2/cuda_rtc/common_rtc.h" - -namespace caffe2 { -namespace { -class ElementwiseRTCFunction : public CudaRTCFunction { - public: - ElementwiseRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... args); - - private: - string name_; -}; - -template <> -string ElementwiseRTCFunction::GetSource( - int input_size, - int output_size, - const string command_string) { - std::stringstream ss; - ss << "extern \"C\" __global__ void " << name_ - << "(const size_t nthreads, \n"; - // Insert the parameter list. - int remain_params = input_size + output_size; - for (int i = 0; i < input_size; ++i) { - ss << "const float* in" << i << ((remain_params--) ? ", \n" : ""); - } - for (int i = 0; i < output_size; ++i) { - ss << "float* out" << i << ((remain_params--) ? ", \n" : ""); - } - ss << ") {\n" - "for (int index = blockIdx.x * blockDim.x + threadIdx.x;\n" - "index < nthreads; index += blockDim.x * gridDim.x) {\n" - << command_string << "\n" - << "}\n}"; - return ss.str(); -} -} // namespace - -/** - * A GPU operator that can generate limited elementwise operations. - * - * ElementwiseRTCOp allows one to do a simple and limited thing: it takes in - * multiple inputs and multiple outputs, as well as a raw string argument - * rtc_src. The runtime then generates the following kernel code: - * - * __global__ void kernel_name(const size_t nthreads, ...) 
{ - * for(int index = blockIdx.x * blockDim.x + threadIdx.x; - * index < nthreads; index += blockDim.x * gridDim.x) { - * rtc_src - * } - * } - * where the "..." part is auto generated, so one can refer to the input and - * output as in0, in1, ..., out0, out1... in the rtc_src string. - * - * For example, if one wants to do a vector multiplication, one can take two - * inputs and one outputs, and write rtc_src as - * out0[index] = in0[index] * in1[index]; - * - * This op is currently highly experimental. We do not have a gradient - * registered for it either. - */ -class ElementwiseRTCOp final : public Operator { - public: - ElementwiseRTCOp(const OperatorDef& operator_def, Workspace* ws) - : Operator(operator_def, ws) { - const string src = OperatorBase::GetSingleArgument("rtc_src", ""); - CAFFE_ENFORCE(src.size(), "Op should have a non-zero source code size."); - func_.Compile(InputSize(), OutputSize(), src); - } - ~ElementwiseRTCOp() override {} - - bool RunOnDevice() override { - static_assert( - sizeof(void*) == sizeof(size_t), - "The argbuffer relies on the assumption that void* and " - "size_t have the same size."); - vector argBuffer_vec(InputSize() + OutputSize() + 1); - size_t* argBuffer = argBuffer_vec.data(); - CAFFE_ENFORCE( - Input(0).numel() < std::numeric_limits::max(), - "The kernel function currently only supports int index."); - argBuffer[0] = Input(0).numel(); - void** ptr_buffer = reinterpret_cast(argBuffer + 1); - for (int i = 0; i < InputSize(); ++i) { - ptr_buffer[i] = const_cast(Input(i).data()); - } - for (int i = 0; i < OutputSize(); ++i) { - Output(i)->ResizeLike(Input(0)); - ptr_buffer[i + InputSize()] = Output(i)->mutable_data(); - } - size_t argBufferSize = sizeof(argBuffer); - void* config[] = { - CU_LAUNCH_PARAM_BUFFER_POINTER, - argBuffer, - CU_LAUNCH_PARAM_BUFFER_SIZE, - &argBufferSize, - CU_LAUNCH_PARAM_END}; - func_.LaunchEx( - CAFFE_GET_BLOCKS(Input(0).numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - config); - return true; - } - - private: - ElementwiseRTCFunction func_; -}; - -namespace { -REGISTER_CUDA_OPERATOR_WITH_ENGINE(ElementwiseRTC, NVRTC, ElementwiseRTCOp); -} - -} // namespace caffe2 diff --git a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc deleted file mode 100644 index 8ec14e1223ae..000000000000 --- a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc +++ /dev/null @@ -1,340 +0,0 @@ -#include - -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/cuda_rtc/common_rtc.h" -#include "caffe2/operators/pool_op.h" - -namespace caffe2 { -namespace { -class AveragePool {}; -class MaxPool {}; -} // namespace - -namespace { - -// The max pool forward function, with parameters written in const int. 
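// Referring back to ElementwiseRTCOp above: a minimal sketch of the OperatorDef
// a caller would build to get the vector-multiply kernel from that doc comment.
// Field names follow the caffe2 OperatorDef/Argument protos; PROTO_CUDA is the
// assumed device-type constant, and the NVRTC engine string matches the
// REGISTER_CUDA_OPERATOR_WITH_ENGINE call above.
OperatorDef def;
def.set_type("ElementwiseRTC");
def.set_engine("NVRTC");                    // select the RTC implementation
def.add_input("in0");
def.add_input("in1");
def.add_output("out0");
auto* arg = def.add_arg();
arg->set_name("rtc_src");
arg->set_s("out0[index] = in0[index] * in1[index];");
def.mutable_device_option()->set_device_type(PROTO_CUDA);
// ws.RunOperatorOnce(def);  // e.g. via the Workspace API removed earlier in this diff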
-const char kMaxPoolForwardNCHWSource[] = R"( -extern "C" -__global__ void %s(const float* bottom_data, float* top_data) { - const int nthreads = %d; - const int channels = %d; - const int height = %d; - const int width = %d; - const int pooled_height = %d; - const int pooled_width = %d; - const int kernel_h = %d; - const int kernel_w = %d; - const int stride_h = %d; - const int stride_w = %d; - const int pad_t = %d; - const int pad_l = %d; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < nthreads; index += blockDim.x * gridDim.x) { - int pw = index %% pooled_width; - int ph = (index / pooled_width) %% pooled_height; - int c = (index / (pooled_width * pooled_height)) %% channels; - int n = index / (pooled_width * pooled_height * channels); - int hstart = ph * stride_h - pad_t; - int wstart = pw * stride_w - pad_l; - int hend = min(hstart + kernel_h, height); - int wend = min(wstart + kernel_w, width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - float maxval = -1.0e37f; - const float* bdata_offset = bottom_data + n * channels * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - maxval = fmaxf( - bdata_offset[c * height * width + h * width + w], maxval); - } - } - top_data[index] = maxval; - } -} -)"; - -// The max pool forward function, with parameters written in const int. -const char kMaxPoolBackwardNCHWSource[] = R"( -extern "C" -__global__ void %s( - const float* const bottom_data, const float* const top_data, - const float* const top_diff, float* const bottom_diff) { - const int nthreads = %d; - const int num = %d; - const int channels = %d; - const int height = %d; - const int width = %d; - const int pooled_height = %d; - const int pooled_width = %d; - const int kernel_h = %d; - const int kernel_w = %d; - const int stride_h = %d; - const int stride_w = %d; - const int pad_t = %d; - const int pad_l = %d; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < nthreads; index += blockDim.x * gridDim.x) { - const int w = index %% width + pad_l; - const int h = (index / width) %% height + pad_t; - const int c = (index / width / height) %% channels; - const int n = index / width / height / channels; - const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; - const int phend = min(h / stride_h + 1, pooled_height); - const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; - const int pwend = min(w / stride_w + 1, pooled_width); - const int top_offset = - (n * channels + c) * pooled_height * pooled_width; - bottom_diff[index] = 0; - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - int top_local_offset = top_offset + ph * pooled_width + pw; - if (bottom_data[index] == top_data[top_local_offset]) { - bottom_diff[index] += top_diff[top_local_offset]; - } - } - } - } -} -)"; - -class MaxPoolRTCFunction : public CudaRTCFunction { - public: - MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... args); - - private: - string name_; -}; - -class MaxPoolGradientRTCFunction - : public CudaRTCFunction { - public: - MaxPoolGradientRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {} - - template - string KernelName(Args... /*args*/) { - return name_; - } - - template - string GetSource(Args... 
args); - - private: - string name_; -}; - -template <> -string MaxPoolRTCFunction::GetSource( - const int output_size, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_t, - const int pad_l) { - char buffer[65536]; - int nbytes = snprintf( - buffer, - 65536, - kMaxPoolForwardNCHWSource, - name_.c_str(), - output_size, - channels, - height, - width, - pooled_height, - pooled_width, - kernel_h, - kernel_w, - stride_h, - stride_w, - pad_t, - pad_l); - TORCH_DCHECK_GE(nbytes, 0); - TORCH_DCHECK_LT(nbytes, 65536); - return string(buffer); -} - -template <> -string MaxPoolGradientRTCFunction::GetSource( - const int output_size, - const int num, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_t, - const int pad_l) { - char buffer[65536]; - int nbytes = snprintf( - buffer, - 65536, - kMaxPoolBackwardNCHWSource, - name_.c_str(), - output_size, - num, - channels, - height, - width, - pooled_height, - pooled_width, - kernel_h, - kernel_w, - stride_h, - stride_w, - pad_t, - pad_l); - TORCH_DCHECK_GE(nbytes, 0); - TORCH_DCHECK_LT(nbytes, 65536); - return string(buffer); -} - -} // namespace - -class MaxPoolRTCOp final : public ConvPoolOpBase { - public: - MaxPoolRTCOp(const OperatorDef& operator_def, Workspace* ws) - : ConvPoolOpBase(operator_def, ws) { - CAFFE_ENFORCE_EQ( - order_, StorageOrder::NCHW, "Currently only NCHW is supported."); - } - ~MaxPoolRTCOp() override {} - - bool RunOnDeviceWithOrderNCHW() override { - auto& X = Input(0); - auto output_sizes = - ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); - auto* Y = Output(0, output_sizes, at::dtype()); - - if (input_dims_ != X.sizes()) { - // recompile - VLOG(1) << "MaxPool RTC recompiling"; - CAFFE_ENFORCE_LT(Y->numel(), std::numeric_limits::max()); - func_.Compile( - static_cast(Y->numel()), - X.dim32(1), - X.dim32(2), - X.dim32(3), - Y->dim32(2), - Y->dim32(3), - kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l()); - input_dims_ = X.sizes().vec(); - } - // Carry out the pooling computation. 
- func_.Launch( - CAFFE_GET_BLOCKS(Y->numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - X.data(), - Y->mutable_data()); - return true; - } - - bool RunOnDeviceWithOrderNHWC() override { - LOG(FATAL) << "Not implemented."; - return false; - } - - private: - MaxPoolRTCFunction func_; - vector input_dims_; -}; - -class MaxPoolGradientRTCOp final : public ConvPoolOpBase { - public: - MaxPoolGradientRTCOp(const OperatorDef& operator_def, Workspace* ws) - : ConvPoolOpBase(operator_def, ws) { - CAFFE_ENFORCE_EQ( - order_, StorageOrder::NCHW, "Currently only NCHW is supported."); - } - ~MaxPoolGradientRTCOp() override {} - - bool RunOnDeviceWithOrderNCHW() override { - auto& X = Input(0); - auto& Y = Input(1); - auto& dY = Input(2); - CAFFE_ENFORCE_EQ(dY.dim(), 4); - - auto* dX = Output(0, X.sizes(), at::dtype()); - ConvPoolOpBase::ComputePads({X.dim32(2), X.dim32(3)}); - if (input_dims_ != X.sizes()) { - VLOG(1) << "MaxPoolGradient RTC recompiling"; - CAFFE_ENFORCE_LT(X.numel(), std::numeric_limits::max()); - func_.Compile( - static_cast(X.numel()), - X.dim32(0), - X.dim32(1), - X.dim32(2), - X.dim32(3), - dY.dim32(2), - dY.dim32(3), - kernel_h(), - kernel_w(), - stride_h(), - stride_w(), - pad_t(), - pad_l()); - input_dims_ = X.sizes().vec(); - } - func_.Launch( - CAFFE_GET_BLOCKS(X.numel()), - 1, - 1, - CAFFE_CUDA_NUM_THREADS, - 1, - 1, - 0, - context_.cuda_stream(), - X.data(), - Y.data(), - dY.data(), - dX->mutable_data()); - return true; - } - - bool RunOnDeviceWithOrderNHWC() override { - LOG(FATAL) << "Not implemented."; - return false; - } - - private: - MaxPoolGradientRTCFunction func_; - vector input_dims_; -}; - -namespace { -REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp); -REGISTER_CUDA_OPERATOR_WITH_ENGINE( - MaxPoolGradient, - NVRTC, - MaxPoolGradientRTCOp); -} // namespace -} // namespace caffe2 diff --git a/caffe2/onnx/torch_ops/CMakeLists.txt b/caffe2/onnx/torch_ops/CMakeLists.txt deleted file mode 100644 index 99443af4cc9b..000000000000 --- a/caffe2/onnx/torch_ops/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -# ---[ Extra onnx files. -file(GLOB ONNX_SRCS *.cc) - -# ---[ Send the lists to the parent scope. -set(ONNX_SRCS ${ONNX_SRCS} PARENT_SCOPE) diff --git a/caffe2/onnx/torch_ops/constants.h b/caffe2/onnx/torch_ops/constants.h deleted file mode 100644 index ebd2a2464d9b..000000000000 --- a/caffe2/onnx/torch_ops/constants.h +++ /dev/null @@ -1,7 +0,0 @@ -namespace ONNX_NAMESPACE { - -const int AI_ONNX_PYTORCH_DOMAIN_MIN_OPSET = 1; -const int AI_ONNX_PYTORCH_DOMAIN_MAX_OPSET = 1; -constexpr const char* AI_ONNX_PYTORCH_DOMAIN = "ai.onnx.pytorch"; - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/defs.cc b/caffe2/onnx/torch_ops/defs.cc deleted file mode 100644 index a324cce6f284..000000000000 --- a/caffe2/onnx/torch_ops/defs.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (c) Facebook Inc. and Microsoft Corporation. -// Licensed under the MIT license. 
- -#include "./schema.h" - -namespace ONNX_NAMESPACE { - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsSumFused8BitRowwise, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsSumFused8BitRowwise operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Input(2, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T2") - .TypeConstraint( - "T1", - {"tensor(uint8)"}, - "Constrain input data to uint8 tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsSum, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsSum operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Input(2, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - SparseLengthsWeightedSum, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 SparseLengthsWeightedSum operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "WEIGHTS", "data tensor", "T1") - .Input(2, "INDICES", "indices tensor", "T2") - .Input(3, "LENGTHS", "lengths tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - BatchGather, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 BatchGather operator") - .Input(0, "DATA", "data tensor", "T1") - .Input(1, "INDICES", "indices tensor", "T2") - .Output(0, "output", "Output tensor", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int8)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(uint8)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)"}, - "Constrain index and length to integral tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - DotProduct, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 DotProduct operator") - .Input(0, "X", "Input 1 tensor", "T") - .Input(1, "Y", "Input 2 tensor", "T") - .Output(0, "Z", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", 
"tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - FCTransposed, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 FCTransposed operator") - .Input(0, "X", "Input tensor", "T") - .Input(1, "W", "Weight tensor", "T") - .Input(2, "B", "Bias tensor", "T") - .Output(0, "Z", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - BatchMatMul, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 BatchMatMul operator") - .Input(0, "X", "tensor of shape (dim0, dim1 ... M, K)", "T") - .Input(1, "Y", "tensor of shape (dim0, dim2 ... K, N)", "T") - .Output(0, "Z", "tensor of shape (dim0, dim1 ... M, N)", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,bugprone-branch-clone) -ONNX_PYTORCH_OPERATOR_SET_SCHEMA( - ExpandDims, - 1, - OpSchema() - .SetDoc("Mirror Caffe2 ExpandDims operator") - .Input(0, "X", "Input tensor", "T") - .Output(0, "Y", "Output tensor", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to float tensors.")); - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/operator_sets.h b/caffe2/onnx/torch_ops/operator_sets.h deleted file mode 100644 index f7380af3910f..000000000000 --- a/caffe2/onnx/torch_ops/operator_sets.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include "onnx/defs/schema.h" - -namespace ONNX_NAMESPACE { - -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME( - PyTorch, - 1, - SparseLengthsSumFused8BitRowwise); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims); - -// Iterate over schema from ai.onnx.pytorch domain opset 1 -class OpSet_PyTorch_ver1 { - public: - static void ForEachSchema(std::function fn) { - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); - } -}; - -inline void RegisterPyTorchOperatorSetSchema() { - RegisterOpSetSchema(); -} - -} // namespace ONNX_NAMESPACE diff --git a/caffe2/onnx/torch_ops/schema.cc b/caffe2/onnx/torch_ops/schema.cc deleted file mode 100644 index de933c2c23ab..000000000000 --- a/caffe2/onnx/torch_ops/schema.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "./schema.h" -#include "./operator_sets.h" - -namespace { -using namespace ONNX_NAMESPACE; -class PyTorchSchemasRegisterer { - public: - PyTorchSchemasRegisterer() { - OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion( - AI_ONNX_PYTORCH_DOMAIN, - AI_ONNX_PYTORCH_DOMAIN_MIN_OPSET, - AI_ONNX_PYTORCH_DOMAIN_MAX_OPSET); - RegisterPyTorchOperatorSetSchema(); - } -}; -static 
PyTorchSchemasRegisterer registerer{}; -} // namespace diff --git a/caffe2/onnx/torch_ops/schema.h b/caffe2/onnx/torch_ops/schema.h deleted file mode 100644 index 3454e366a1ee..000000000000 --- a/caffe2/onnx/torch_ops/schema.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "./constants.h" -#include "onnx/defs/schema.h" - -#define ONNX_PYTORCH_OPERATOR_SET_SCHEMA(name, ver, impl) \ - ONNX_OPERATOR_SET_SCHEMA_EX( \ - name, PyTorch, AI_ONNX_PYTORCH_DOMAIN, ver, false, impl) diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 9510ec60dfef..3d08e5c0a7bb 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -24,8 +24,6 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) if(CXX_AVX2_FOUND) add_library(Caffe2_perfkernels_avx STATIC ${avx_srcs}) add_library(Caffe2_perfkernels_avx2 STATIC ${avx2_srcs}) - add_dependencies(Caffe2_perfkernels_avx Caffe2_PROTO) - add_dependencies(Caffe2_perfkernels_avx2 Caffe2_PROTO) target_link_libraries(Caffe2_perfkernels_avx PRIVATE c10) target_link_libraries(Caffe2_perfkernels_avx2 PRIVATE c10) install(TARGETS Caffe2_perfkernels_avx Caffe2_perfkernels_avx2 @@ -62,7 +60,6 @@ if(CXX_AVX2_FOUND) if(CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) add_library(Caffe2_perfkernels_avx512 STATIC ${avx512_srcs}) - add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO) target_link_libraries(Caffe2_perfkernels_avx512 PRIVATE c10) install(TARGETS Caffe2_perfkernels_avx512 ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/caffe2/perfkernels/adagrad.cc b/caffe2/perfkernels/adagrad.cc deleted file mode 100644 index c589187cb2eb..000000000000 --- a/caffe2/perfkernels/adagrad.cc +++ /dev/null @@ -1,186 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" - -#include - -#include "caffe2/perfkernels/common.h" - -namespace caffe2 { - -void adagrad_update__base( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - const float lr, - const float weight_decay = 0.f) { - internal::adagrad_update_base_inlined( - N, w, g, h, nw, nh, decay, epsilon, lr, weight_decay); -} - -void adagrad_update_prefetch__base( - int N, - const float* w, - const float* /* w_n */, // prefetch ptr - - const float* g, - - const float* h, - const float* /* h_n */, // prefetch ptr - - float* nw, - float* /* nw_n */, // prefetch ptr - - float* nh, - float* /* nh_n */, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f) { - adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr, weight_decay); -} - -void adagrad_fp16_update_prefetch__base( - int N, - const at::Half* w, - const at::Half* /* w_n */, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* /* h_n */, // prefetch ptr - at::Half* nw, - at::Half* /* nw_n */, // prefetch ptr - at::Half* nh, - at::Half* /* nh_n */, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f) { - internal::adagrad_update_base_inlined( - N, w, g, h, nw, nh, 1.0f, epsilon, lr, weight_decay); -} - -// version without prefetching -decltype(adagrad_update__base) adagrad_update__avx2_fma; -decltype(adagrad_update__base) adagrad_update__avx512; -void adagrad_update( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay) { - AVX512_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); - AVX2_FMA_DO( - adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, 
weight_decay); - BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay); -} - -decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx2_fma; -void adagrad_update_prefetch( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay) { - AVX2_FMA_DO( - adagrad_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); - BASE_DO( - adagrad_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); -} - -// Version with prefetching for embeddings and -// momentum using fp16 -decltype(adagrad_fp16_update_prefetch__base) - adagrad_fp16_update_prefetch__avx2_fma; -void adagrad_fp16_update_prefetch( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay) { - AVX2_FMA_DO( - adagrad_fp16_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); - BASE_DO( - adagrad_fp16_update_prefetch, - N, - w, - w_n, - g, - h, - h_n, - nw, - nw_n, - nh, - nh_n, - epsilon, - lr, - weight_decay); -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad.h b/caffe2/perfkernels/adagrad.h deleted file mode 100644 index f030e3e09d60..000000000000 --- a/caffe2/perfkernels/adagrad.h +++ /dev/null @@ -1,205 +0,0 @@ -#pragma once - -#if defined(__AVX__) && !defined(__NVCC__) && \ - (defined(__x86_64__) || defined(_M_X64) || defined(__i386__)) -#define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#include -#endif -#include -#include - -namespace caffe2 { - -namespace internal { - -// The following functions inside internal namespace are inlined because they -// are performance critical. - -template -static inline void adagrad_update_base_inlined( - int N, - const T* w, - const float* g, - const T* h, - T* nw, - T* nh, - float decay, - float epsilon, - float lr, - float weight_decay = 0.f) { - for (const auto i : c10::irange(N)) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = decay * h[i] + gi * gi; - nh[i] = hi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -// version with prefetching -// TODO(msmelyan) -// Crux of the computation is computing a / (sqrt(b) + epsilon), -// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not -// change. Today it's computed using two vector sqrt and vector divide simd -// instructions. It is slow. We can take advantage of existing fast vector -// VRSQRTPS instruction that computes approximate reciprocals of square roots -// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the -// addition of epsilon is just done to avoid division by zero, we approximate a -// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can -// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for -// the test on random numbers between 0.1 and 1 the absolute error was about -// 10^-3 compared to using slower but more accurate combination of vsqrt and -// vdiv. 
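Editor's note: the per-element update performed by the removed base kernel (adagrad_update__base via adagrad_update_base_inlined, just above) is compact enough to restate as a standalone scalar sketch. This is an illustration only; the function name and signature below are not part of the removed API.

    #include <cmath>
    #include <cstddef>

    // Minimal scalar sketch of the removed adagrad update:
    //   g'  = g + weight_decay * w
    //   h'  = decay * h + g' * g'
    //   w'  = w + lr * g' / (sqrt(h') + epsilon)
    void adagrad_update_sketch(std::size_t n, const float* w, const float* g,
                               const float* h, float* nw, float* nh,
                               float epsilon, float decay, float lr,
                               float weight_decay) {
      for (std::size_t i = 0; i < n; ++i) {
        const float gi = std::fma(weight_decay, w[i], g[i]);  // gradient plus weight decay
        const float hi = decay * h[i] + gi * gi;              // decayed squared-gradient accumulator
        nh[i] = hi;
        nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);   // scaled parameter step
      }
    }

The AVX2/AVX512 variants removed in this diff compute exactly this expression eight or sixteen lanes at a time.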
Extend Marat's function with more NR iterations to get more accuracy -// for training -// TODO(msmelyan) -// explore streaming stores, but need to have unique indices (deduplication) -inline void adagrad_update_prefetch_inlined( - int N, - const float* w, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - const float* w_n, // prefetch ptr -#else - const float* /* unused */, -#endif - - const float* g, - - const float* h, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - const float* h_n, // prefetch ptr -#else - const float* /* unused */, -#endif - - float* nw, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - float* nw_n, // prefetch ptr -#else - float* /* unused */, -#endif - - float* nh, -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - float* nh_n, // prefetch ptr -#else - float* /* unused */, -#endif - - float epsilon, - float lr, - float weight_decay = 0.f) { - auto i = 0; - -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC - constexpr int kSize = 8; - for (; i + kSize <= N; i += kSize) { - _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); - - __m256 gi = _mm256_loadu_ps(g + i); - __m256 hi = _mm256_loadu_ps(h + i); - __m256 wi = _mm256_loadu_ps(w + i); -#ifdef __FMA__ - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - -#else - gi = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(weight_decay), wi), gi); -#endif - - __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); - _mm256_storeu_ps(nh + i, nhi); - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp)); - } -#endif - - adagrad_update_base_inlined( - N - i, - w + i, - g + i, - h + i, - nw + i, - nh + i, - 1.0f, - epsilon, - lr, - weight_decay); -} - -} // namespace internal - -// version with prefetching -// TODO(msmelyan) -// Crux of the computation is computing a / (sqrt(b) + epsilon), -// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not -// change. Today it's computed using two vector sqrt and vector divide simd -// instructions. It is slow. We can take advantage of existing fast vector -// VRSQRTPS instruction that computes approximate reciprocals of square roots -// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the -// addition of epsilon is just done to avoid division by zero, we approximate a -// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can -// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for -// the test on random numbers between 0.1 and 1 the absolute error was about -// 10^-3 compared to using slower but more accurate combination of vsqrt and -// vdiv. 
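Editor's note: the TODO comment above proposes replacing the vsqrt+vdiv pair with the hardware reciprocal-square-root estimate refined by Newton-Raphson, folding epsilon into the radicand. The removed kernels did not ship this; the sketch below only illustrates the idea described in the comment, and the helper name and the way epsilon is folded in are assumptions. Like the surrounding code it is x86-specific (SSE).

    #include <immintrin.h>

    // Sketch of the optimization the TODO describes: approximate a / (sqrt(b) + eps)
    // as a * rsqrt(b + eps'), using the fast hardware estimate (RSQRTSS/VRSQRTPS,
    // roughly 12-bit accurate) refined with one Newton-Raphson step:
    //   r <- r * (1.5 - 0.5 * x * r * r)
    // eps' keeps the radicand strictly positive so the estimate never sees zero.
    inline float approx_a_over_sqrt_b(float a, float b, float eps_prime) {
      const float x = b + eps_prime;
      float r = _mm_cvtss_f32(_mm_rsqrt_ss(_mm_set_ss(x)));  // hardware 1/sqrt estimate
      r = r * (1.5f - 0.5f * x * r * r);                     // one Newton-Raphson refinement
      return a * r;                                          // ~ a / sqrt(b + eps')
    }

More refinement steps, as the comment suggests, trade a few extra FMAs for accuracy closer to the exact vsqrt+vdiv path that the deleted AVX2 kernel actually uses.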
Extend Marat's function with more NR iterations to get more accuracy -// for training -// TODO(msmelyan) -// explore streaming stores, but need to have inuque indices (deduplication) -void adagrad_update_prefetch( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f); - -// Version with prefetching for embeddings and -// momentum using fp16 -void adagrad_fp16_update_prefetch( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f); - -// version without prefetching -void adagrad_update( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f); - -} // namespace caffe2 - -#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC -#endif diff --git a/caffe2/perfkernels/adagrad_avx2.cc b/caffe2/perfkernels/adagrad_avx2.cc deleted file mode 100644 index 08c9fd00d9a0..000000000000 --- a/caffe2/perfkernels/adagrad_avx2.cc +++ /dev/null @@ -1,125 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include - -namespace caffe2 { - -// version without prefetching -void adagrad_update__avx2_fma( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 8; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - __m256 gi = _mm256_loadu_ps(g + i); - __m256 hi = _mm256_loadu_ps(h + i); - __m256 wi = _mm256_loadu_ps(w + i); - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - - __m256 nhi = _mm256_add_ps( - _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi)); - _mm256_storeu_ps(nh + i, nhi); - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp)); - } - - for (; i < N; ++i) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = nh[i] = decay * h[i] + gi * gi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -void adagrad_update_prefetch__avx2_fma( - int N, - const float* w, - const float* w_n, // prefetch ptr - - const float* g, - - const float* h, - const float* h_n, // prefetch ptr - - float* nw, - float* nw_n, // prefetch ptr - - float* nh, - float* nh_n, // prefetch ptr - - float epsilon, - float lr, - float weight_decay = 0.f) { - internal::adagrad_update_prefetch_inlined( - N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr, weight_decay); -} - -// Compute adagrad sparse, assumes embedding and momentum are at::Half -void adagrad_fp16_update_prefetch__avx2_fma( - int N, - const at::Half* w, - const at::Half* w_n, // prefetch ptr - const float* g, - const at::Half* h, - const at::Half* h_n, // prefetch ptr - at::Half* nw, - at::Half* nw_n, // prefetch ptr - at::Half* nh, - at::Half* nh_n, // prefetch ptr - float epsilon, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 8; - auto i = 0; - for (; i + 
kSize <= N; i += kSize) { - _mm_prefetch(reinterpret_cast(&w_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&h_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nw_n[i]), _MM_HINT_T0); - _mm_prefetch(reinterpret_cast(&nh_n[i]), _MM_HINT_T0); - - // only convert momentum and embedding, gradient is fp32 - __m256 gi = _mm256_loadu_ps(g + i); - __m128i hhi = _mm_loadu_si128(reinterpret_cast(h + i)); - __m256 hi = _mm256_cvtph_ps(hhi); - __m128i whi = _mm_loadu_si128(reinterpret_cast(w + i)); - __m256 wi = _mm256_cvtph_ps(whi); - gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi); - - __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi)); - __m128i nhhi = _mm256_cvtps_ph(nhi, 0); - _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi); - - __m256 vtmp = _mm256_div_ps( - _mm256_mul_ps(_mm256_set1_ps(lr), gi), - _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon))); - __m256 nwi = _mm256_add_ps(wi, vtmp); - __m128i nhwi = _mm256_cvtps_ph(nwi, 0); - _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi); - } - - for (; i < N; ++i) { - float gi = std::fma( - weight_decay, - _cvtsh_ss(reinterpret_cast(w)[i]), - g[i]); - float nhi = - _cvtsh_ss(reinterpret_cast(h)[i]) + gi * gi; - reinterpret_cast(nh)[i] = _cvtss_sh(nhi, 0); - float nwi = _cvtsh_ss(reinterpret_cast(w)[i]) + - lr * gi / (std::sqrt(nhi) + epsilon); - reinterpret_cast(nw)[i] = _cvtss_sh(nwi, 0); - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/adagrad_avx512.cc b/caffe2/perfkernels/adagrad_avx512.cc deleted file mode 100644 index 417dd1ca8bab..000000000000 --- a/caffe2/perfkernels/adagrad_avx512.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "caffe2/perfkernels/adagrad.h" -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include - -namespace caffe2 { - -// version without prefetching -void adagrad_update__avx512( - int N, - const float* w, - const float* g, - const float* h, - float* nw, - float* nh, - float epsilon, - float decay, - float lr, - float weight_decay = 0.f) { - constexpr int kSize = 16; - auto i = 0; - for (; i + kSize <= N; i += kSize) { - __m512 gi = _mm512_loadu_ps(g + i); - __m512 hi = _mm512_loadu_ps(h + i); - __m512 wi = _mm512_loadu_ps(w + i); - gi = _mm512_fmadd_ps(_mm512_set1_ps(weight_decay), wi, gi); - - __m512 nhi = _mm512_add_ps( - _mm512_mul_ps(_mm512_set1_ps(decay), hi), _mm512_mul_ps(gi, gi)); - _mm512_storeu_ps(nh + i, nhi); - __m512 vtmp = _mm512_div_ps( - _mm512_mul_ps(_mm512_set1_ps(lr), gi), - _mm512_add_ps(_mm512_sqrt_ps(nhi), _mm512_set1_ps(epsilon))); - _mm512_storeu_ps(nw + i, _mm512_add_ps(wi, vtmp)); - } - - for (; i < N; ++i) { - float gi = std::fma(weight_decay, w[i], g[i]); - float hi = nh[i] = decay * h[i] + gi * gi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox.cc b/caffe2/perfkernels/batch_box_cox.cc deleted file mode 100644 index 7172f4b9d8cd..000000000000 --- a/caffe2/perfkernels/batch_box_cox.cc +++ /dev/null @@ -1,113 +0,0 @@ -#include "caffe2/perfkernels/common.h" - -#include -#include -#include - -namespace caffe2 { - -namespace { -template -void BoxCoxNaive( - std::size_t N, - std::size_t D, - const T* data_ptr, - const T* __restrict lambda1_ptr, - const T* __restrict lambda2_ptr, - T* output_ptr) { - constexpr T k_eps = static_cast(1e-6); - - for (std::size_t i = 0; i < N; i++) { - for (std::size_t j = 0; j < D; j++, data_ptr++, output_ptr++) { - T lambda1_v = lambda1_ptr[j]; - T lambda2_v = lambda2_ptr[j]; - T tmp = std::max(*data_ptr + 
lambda2_v, k_eps); - if (lambda1_v == 0) { - *output_ptr = std::log(tmp); - } else { - T lambda_1 = 1 / lambda1_v; - T pow = std::pow(tmp, lambda1_v); - *output_ptr = lambda_1 * pow - lambda_1; - } - } - } - -} -} - -#if defined(CAFFE2_PERF_WITH_AVX2) && defined(CAFFE2_PERF_USE_MKL) -namespace details { -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* data_ptr, - const T* __restrict lambda1_ptr, - const T* __restrict lambda2_ptr, - T* output_ptr); - -extern template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -extern template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* self_data, - const double* __restrict lambda1_data, - const double* __restrict lambda2_data, - double* output_data); -} // namespace detail -#endif - -template -void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* data, - const T* lambda1_data, - const T* lambda2_data, - T* output_data) { -#ifdef CAFFE2_PERF_WITH_AVX2 - AVX2_FMA_DO( - details::compute_batch_box_cox, - N, - D, - block_size, - data, - lambda1_data, - lambda2_data, - output_data); -#endif - BoxCoxNaive(N, D, data, lambda1_data, lambda2_data, output_data); -} - -template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* data, - const float* lambda1_data, - const float* lambda2_data, - float* output_data); - -template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* data, - const double* lambda1_data, - const double* lambda2_data, - double* output_data); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox.h b/caffe2/perfkernels/batch_box_cox.h deleted file mode 100644 index 60c973bbf8ea..000000000000 --- a/caffe2/perfkernels/batch_box_cox.h +++ /dev/null @@ -1,35 +0,0 @@ -// Impmenets BoxCox operator for CPU -#pragma once -#include - -namespace caffe2 { - -template -void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* lambda1_data, - const T* lambda2_data, - T* output_data); - -extern template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* data, - const float* lambda1_data, - const float* lambda2_data, - float* output_data); - -extern template void compute_batch_box_cox( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* data, - const double* lambda1_data, - const double* lambda2_data, - double* output_data); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc deleted file mode 100644 index 6171b5bfd032..000000000000 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ /dev/null @@ -1,399 +0,0 @@ -#include -#ifdef CAFFE2_PERF_USE_MKL -#include -#include -#include - -#include "vectorizer.h" - -// Enable compiler vectorized version only if numerical consistency is not -// required between dev and opt versions - disabled for now -#ifndef FAST_VECTORIZED_KERNEL -#define CPU_CAPABILITY_AVX2 -#include - -namespace at::vec { - -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument -template -Vectorized 
max(const Vectorized& a, const Vectorized& b); - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm256_max_pd(b, a); -} - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm256_max_ps(b, a); -} - -// Implements recieprocal method based on newton-rapson method -// 1. user RCP approximiation -// 2. update with RCP = RCP * (2 - X * RCP) -template -Vectorized fast_recieprocal(const Vectorized& b); -template -scalar_t fast_recieprocal(scalar_t b); - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - auto minus2 = _mm256_set1_ps(-2.f); - auto rcp = _mm256_rcp_ps(b); - rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); - rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); - return rcp; -} - -template <> -float fast_recieprocal(float b) { - auto minus2 = _mm_set_ss(-2.f); - auto b_reg = _mm_set_ss(b); - auto rcp = _mm_rcp_ss(b_reg); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - return _mm_cvtss_f32(rcp); -} - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - return b.reciprocal(); -} - -template <> -double fast_recieprocal(double b) { - return 1./b; -} - -} -#endif - -#include -#include -#include - -#include - -namespace caffe2::details { - -// MKL VML function templates. -template -void PackV(const int N, const T* a, const int* ia, T* y); -template -void UnpackV(const int N, const T* a, T* y, const int* iy); - -#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void PackV(const int N, const T* a, const int* ia, T* y) { \ - OriginalFunc(N, a, ia, y); \ - } -DELEGATE_PACKV_FUNCTION(float, vsPackV) -DELEGATE_PACKV_FUNCTION(double, vdPackV) -#undef DELEGATE_PACKV_FUNCTION - -#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void UnpackV(const int N, const T* a, T* y, const int* iy) { \ - OriginalFunc(N, a, y, iy); \ - } -DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) -DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) -#undef DELEGATE_UNPACKV_FUNCTION - -#ifndef FAST_VECTORIZED_KERNEL -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(self_data + j); - auto lambda2 = Vec::loadu(lambda2_data + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); - auto res = max.log(); - res.store(output_data + j); - } - for ( ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(data_ptr + j); - auto lambda2 = Vec::loadu(lambda2_ptr + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); - auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); - auto pow = max.pow(lambda1); - auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); - res.store(out + j); - } - for ( 
;j < D; ++j) { - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); - auto pow = std::pow(max, lambda1_ptr[j]); - out[j] = pow * lambda_over_1 - lambda_over_1; - } -} -#else -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - FAST_MATH - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lamda1 = lambda1_ptr[j]; - auto lambda_over_1 = 1 / lamda1; - if constexpr (std::is_same::value) { - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - } - auto pow = std::pow(max, lamda1); - out[j] = pow * lambda_over_1 - lambda_over_1; - } -} -#endif - -template -void box_cox_mixed_lambda( - const T* const self_data, - const std::vector& nonzeros, - const std::vector& zeros, - const T* const lambda1, - const T* const lambda2, - const T* const lambda2_z_, - T k_eps, - T* const buffer, - T* const output_data) { - PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); - box_cox_nonzero_lambda( - nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); - UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); - - PackV(zeros.size(), self_data, zeros.data(), buffer); - box_cox_zero_lambda( - zeros.size(), buffer, lambda2_z_, k_eps, buffer); - UnpackV(zeros.size(), buffer, output_data, zeros.data()); -} - -template -void TileArrayIntoVector( - const T* const a, - const size_t D, - const int K, - std::vector& b) { - b.resize(K * D); - for (const auto k : c10::irange(K)) { - std::copy(a, a + D, b.begin() + k * D); - } -} - -void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { - auto n = v.size(); - v.resize(K * n); - for (const auto k : c10::irange(1, K)) { - for (const auto j : c10::irange(n)) { - v[k * n + j] = v[j] + k * D; - } - } -} - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* __restrict lambda1_data, - const T* __restrict lambda2_data, - T* output_data) { - constexpr T k_eps = static_cast(1e-6); - - FOLLY_DECLARE_REUSED(zeros, std::vector); - FOLLY_DECLARE_REUSED(nonzeros, std::vector); - // Don't bother calling reserve; calls after the first will get a - // correctly-sized allocation anyway. - for (const auto j : c10::irange(D)) { - if (lambda1_data[j] == 0) { - zeros.push_back(j); - } else { - nonzeros.push_back(j); - } - } - - // Process K rows at a time for effective vectorization with small rows. 
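Editor's note: the "process K rows at a time" step above works by replicating the per-column lambda vectors K times so the inner loop can run over K*D contiguous elements. A minimal sketch of that tiling, with illustrative helper names, follows; the K formula mirrors the one in the removed kernel.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Repeat the D per-column parameters K times so K rows can be processed
    // as one contiguous span of K*D elements (what TileArrayIntoVector did).
    std::vector<float> tile_rowwise_params(const float* a, std::size_t d, std::size_t k) {
      std::vector<float> tiled(k * d);
      for (std::size_t t = 0; t < k; ++t) {
        std::copy(a, a + d, tiled.begin() + t * d);  // one copy of the parameters per tile
      }
      return tiled;
    }

    // Rows per batch as computed by the removed kernel:
    // enough rows to fill roughly one block_size of work, capped at N.
    std::size_t rows_per_block(std::size_t n, std::size_t d, std::size_t block_size) {
      return std::min(n, (block_size + d - 1) / d);
    }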
- const auto K = std::min(N, (block_size + D - 1) / D); - - FOLLY_DECLARE_REUSED(lambda1_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); - - if (nonzeros.size() == D) { - // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda1_data, D, K, lambda1_); - TileArrayIntoVector(lambda2_data, D, K, lambda2_); - DCHECK_EQ(K * D, lambda1_.size()); - DCHECK_EQ(K * D, lambda2_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_nonzero_lambda( - K * D, - self_data, - lambda1_.data(), - lambda2_.data(), - k_eps, - output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_nonzero_lambda( - D, self_data, lambda1_data, lambda2_data, k_eps, output_data); - } - } else if (zeros.size() == D) { - // ln(x + lambda2), if lambda1 == 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); - DCHECK_EQ(K * D, lambda2_z_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_zero_lambda( - K * D, self_data, lambda2_z_.data(), k_eps, output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_zero_lambda( - D, self_data, lambda2_data, k_eps, output_data); - } - } else { - // mix zeros and nonzeros - const size_t n = nonzeros.size(); - if (K > 1) { - TileIndicesInPlace(nonzeros, 0, K); - TileIndicesInPlace(zeros, 0, K); - } - - FOLLY_DECLARE_REUSED(buffer, std::vector); - - buffer.resize(std::max(nonzeros.size(), zeros.size())); - lambda1_.resize(nonzeros.size()); - lambda2_.resize(nonzeros.size()); - lambda2_z_.resize(zeros.size()); - PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); - PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); - PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); - - size_t i = 0; - if (K > 1) { - // Truncate to original size, and re-tile with offsets this time. - nonzeros.resize(n); - DCHECK_GT(D, n); - zeros.resize(D - n); - TileIndicesInPlace(nonzeros, D, K); - TileIndicesInPlace(zeros, D, K); - DCHECK_EQ(nonzeros.size(), lambda1_.size()); - DCHECK_EQ(nonzeros.size(), lambda2_.size()); - DCHECK_EQ(zeros.size(), lambda2_z_.size()); - - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - // Truncate to original size. 
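Editor's note: the mixed-lambda branch relies on a gather/compute/scatter pattern, with PackV gathering the lambda1 != 0 columns into a contiguous buffer and UnpackV scattering the results back. A plain-C++ sketch of that pattern is below; the transform shown is the nonzero-lambda Box-Cox formula from the naive path, and the names are illustrative, not the removed MKL-backed helpers.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Gather the nonzero-lambda columns, transform them contiguously, scatter back.
    // lambda1/lambda2 are assumed to be already packed to match nonzero_cols order,
    // as the removed kernel arranges before calling box_cox_mixed_lambda.
    void box_cox_mixed_sketch(const float* row, const std::vector<int>& nonzero_cols,
                              const float* lambda1, const float* lambda2,
                              float k_eps, float* out) {
      std::vector<float> buffer(nonzero_cols.size());
      for (std::size_t i = 0; i < nonzero_cols.size(); ++i) {
        buffer[i] = row[nonzero_cols[i]];                       // gather (PackV)
      }
      for (std::size_t i = 0; i < nonzero_cols.size(); ++i) {
        const float x = std::max(buffer[i] + lambda2[i], k_eps);
        buffer[i] = (std::pow(x, lambda1[i]) - 1.0f) / lambda1[i];  // ((x+l2)^l1 - 1)/l1
      }
      for (std::size_t i = 0; i < nonzero_cols.size(); ++i) {
        out[nonzero_cols[i]] = buffer[i];                       // scatter (UnpackV)
      }
    }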
- nonzeros.resize(n); - zeros.resize(D - n); - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - } -}; - - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -template -void compute_batch_box_cox__avx2_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* self_data, - const double* __restrict lambda1_data, - const double* __restrict lambda2_data, - double* output_data); - -} // namespace caffe2::detail -#endif diff --git a/caffe2/perfkernels/cvtsh_ss_bugfix.h b/caffe2/perfkernels/cvtsh_ss_bugfix.h deleted file mode 100644 index 6a748faa0e57..000000000000 --- a/caffe2/perfkernels/cvtsh_ss_bugfix.h +++ /dev/null @@ -1,75 +0,0 @@ -#pragma once - -// Apple clang was fixed in 8.1 -#if defined(__apple_build_version__) && \ - ((__clang_major__ < 8) || \ - ((__clang_major__ == 8) && (__clang_minor__ < 1))) -#define CAFFE2_INTERNAL_APPLE_NEED_FIX 1 -#endif - -// Regular clang was fixed in 3.9 -#if defined(__clang__) && (__clang_major__ < 4) && (__clang_minor__ < 9) -#define CAFFE2_INTERNAL_CLANG_NEED_FIX 1 -#endif - -#if defined(CAFFE2_INTERNAL_APPLE_NEED_FIX) || \ - defined(CAFFE2_INTERNAL_CLANG_NEED_FIX) - -#include -#include - -// This version of clang has a bug that _cvtsh_ss is not defined, see -// https://reviews.llvm.org/D16177 -static __inline float - __attribute__((__always_inline__, __nodebug__, __target__("f16c"))) - _cvtsh_ss(unsigned short a) { - __v8hi v = {(short)a, 0, 0, 0, 0, 0, 0, 0}; - __v4sf r = __builtin_ia32_vcvtph2ps(v); - return r[0]; -} - -static __inline unsigned short - __attribute__((__always_inline__, __nodebug__, __target__("f16c"))) - _cvtss_sh(float a, int imm8) { - unsigned short ret; - *reinterpret_cast(&ret) = a; - return ret; -} - -#endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX - -#undef __APPLE_NEED_FIX -#undef __CLANG_NEED_FIX - -#if defined(_MSC_VER) && !defined(__clang__) - -#include -#include - -// It seems that microsoft msvc does not have a _cvtsh_ss implementation so -// we will add a dummy version to it. - -static inline float _cvtsh_ss(unsigned short x) { - union { - std::uint32_t intval; - float floatval; - } t1; - std::uint32_t t2, t3; - t1.intval = x & 0x7fff; // Non-sign bits - t2 = x & 0x8000; // Sign bit - t3 = x & 0x7c00; // Exponent - t1.intval <<= 13; // Align mantissa on MSB - t2 <<= 16; // Shift sign bit into position - t1.intval += 0x38000000; // Adjust bias - t1.intval = (t3 == 0 ? 
0 : t1.intval); // Denormals-as-zero - t1.intval |= t2; // Re-insert sign bit - return t1.floatval; -} - -static inline unsigned short _cvtss_sh(float x, int imm8) { - unsigned short ret; - *reinterpret_cast(&ret) = x; - return ret; -} - -#endif // _MSC_VER diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc index 687d081301e4..96ae253b32c6 100644 --- a/caffe2/perfkernels/embedding_lookup.cc +++ b/caffe2/perfkernels/embedding_lookup.cc @@ -1,8 +1,9 @@ #include "caffe2/perfkernels/embedding_lookup.h" -#include "caffe2/core/types.h" #include "caffe2/perfkernels/common.h" +#include +#include #include namespace caffe2 { diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 48c869ee7038..c9b91dc31b88 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -2,9 +2,8 @@ #include #include +#include #include -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" #include "caffe2/perfkernels/common.h" namespace caffe2 { diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc deleted file mode 100644 index b1522ecda7e2..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" - -#include "caffe2/core/types.h" -#include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" - -#include - -namespace caffe2 { - -/** - * Base implementation does runtime dispatch for each segment of reduction - * @return false if there is an out-of-bound error - */ -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -static bool Fused8BitRowwiseEmbeddingLookupGenericSlow( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const InType* input, - const IndexType* indices, - const int* lengths, - const float* weights, // optional, can be null for sum reducer - bool normalize_by_lengths, - OutType* out) { - // block_size is the number of elements and fused_block_size is the size of - // an entire row, including scale and bias. - const auto scale_bias_offset = 8 / sizeof(InType); - const int64_t fused_block_size = block_size + scale_bias_offset; - int64_t current = 0; - for (const auto m : c10::irange(output_size)) { - memset(out, 0, sizeof(OutType) * block_size); - if (current + lengths[m] > index_size) { - return false; - } - for (int i = 0; i < lengths[m]; ++i) { - int64_t idx = indices[current]; - if (idx < 0 || idx >= data_size) { - return false; - } -#ifdef __GNUC__ - if (current + 1 < index_size) { - __builtin_prefetch( - input + fused_block_size * indices[current + 1], 0, 1); - } -#endif // __GNUC__ - - const float* scale_bias = reinterpret_cast( - input + fused_block_size * indices[current] + block_size); - - float weight = 1.0f; - if (weights) { - weight = weights[IS_WEIGHT_POSITIONAL ? 
i : current]; - } - const float scale = weight * scale_bias[0]; - const float bias = weight * scale_bias[1]; - - for (const auto j : c10::irange(block_size)) { - out[j] += scale * input[fused_block_size * indices[current] + j] + bias; - } - - ++current; - } - if (normalize_by_lengths && lengths[m]) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float scale = 1.f / lengths[m]; - for (const auto j : c10::irange(block_size)) { - out[j] *= scale; - } - } - out += block_size; - } - return current == index_size; -} - -// clang-format off -// Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \ - bool \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - return Fused8BitRowwiseEmbeddingLookupGenericSlow< \ - IndexType, \ - uint8_t, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ - bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void Fused8BitRowwiseEmbeddingLookup( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - bool success = \ - Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ - if (success) { \ - return; \ - } \ - int64_t current = 0; \ - for (int m = 0; m < output_size; ++m) { \ - for (int i = 0; i < lengths[m]; ++i) { \ - CAFFE_ENFORCE_LT(current, index_size); \ - IndexType idx = indices[current]; \ - CAFFE_ENFORCE( \ - 0 <= idx && idx < data_size, \ - "Index ", \ - current, \ - " is out of bounds: ", \ - idx, \ - ", range 0 to ", \ - data_size); \ - ++current; \ - } \ - } \ - CAFFE_ENFORCE_EQ( \ - current, \ - index_size, \ 
- "Your input seems to be incorrect: the sum of lengths values should be " \ - "the size of the indices tensor, but it appears not."); \ - } -// clang-format on - -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float); -FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float); - -#undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h deleted file mode 100644 index cfaab0d361b1..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h +++ /dev/null @@ -1,55 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -/** - * Embedding lookup with reduction. - * - * `input` of size data_size * (block_size + 8B) - * `indices` of size index_size - * `lengths` of size output_size - * `weights` nullptr or array of size index_size - * `out` of size output_size * block_size - * sum(lengths[i]) == index_size - * - * Note that block_size should be the number of quantized values per row in the - * data, i.e. excluding the scale and bias. The total (fused) block size is - * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias. - * - * Behavior is roughly equivalent to pseudocode: - * - * pos = 0 - * fused_block_size = block_size + 8B // quantized values and scale and bias - * for (i = 0..output_size-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] = 0 - * for (j = 0..lengths[i]-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] * - * (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0) - * pos += 1 - * if (normalize_weights && lengths[i] > 0) - * for (k = 0..block_size-1) - * out[i*block_size + k] /= lengths[i] - * - */ - -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -void Fused8BitRowwiseEmbeddingLookup( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t data_size, - const InType* input, - const IndexType* indices, - const int* lengths, - const float* weights, // optional, can be null for non-weighted sum - bool normalize_by_lengths, - OutType* out); -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc deleted file mode 100644 index 866298226af0..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc +++ /dev/null @@ -1,214 +0,0 @@ -#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h" - -#include "caffe2/core/types.h" -#include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" - -#include - -namespace caffe2 { - -/** - * Base implementation does runtime dispatch for each segment of reduction - * @return false if there is an out-of-bound error - */ -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx( - const int64_t block_size, - const int64_t output_size, - const int64_t index_size, - const int64_t data_size, - const InType* input, - const IndexType* indices, - const IndexType* offsets, - const float* weights, // optional, can be null for sum reducer - bool normalize_by_lengths, - OutType* out) { - // block_size is the number of elements and fused_block_size is the size of - // an entire row, including scale and bias. 
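Editor's note: the fused 8-bit rowwise layout assumed by these lookup kernels is each row's block_size uint8 codes followed by a float scale and a float bias (8 extra bytes per row), with a dequantized element being scale * code + bias. A minimal sketch of reading such a row and accumulating it, as the generic slow path does, is below; the struct and function names are illustrative.

    #include <cstdint>
    #include <cstring>

    // View over one fused row: block_size codes, then float scale, then float bias.
    struct FusedRowView {
      const std::uint8_t* codes;
      float scale;
      float bias;
    };

    FusedRowView view_fused_row(const std::uint8_t* data, std::int64_t row,
                                std::int64_t block_size) {
      const std::int64_t fused_block_size = block_size + 8;  // codes + scale + bias
      const std::uint8_t* p = data + row * fused_block_size;
      FusedRowView v{p, 0.f, 0.f};
      std::memcpy(&v.scale, p + block_size, sizeof(float));                 // trailing scale
      std::memcpy(&v.bias, p + block_size + sizeof(float), sizeof(float));  // trailing bias
      return v;
    }

    // Accumulate one (optionally weighted) row into the output segment:
    // out[j] += w * (scale * code[j] + bias), matching the generic slow path.
    void accumulate_row(const FusedRowView& r, float weight, std::int64_t block_size,
                        float* out) {
      for (std::int64_t j = 0; j < block_size; ++j) {
        out[j] += weight * (r.scale * r.codes[j] + r.bias);
      }
    }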
- const auto scale_bias_offset = 8 / sizeof(InType); - const int64_t fused_block_size = block_size + scale_bias_offset; - int64_t current = 0; - for (const auto m : c10::irange(output_size)) { - memset(out, 0, sizeof(OutType) * block_size); - if (current != offsets[m] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[m]; - int64_t end_offset = offsets[m + 1]; - int64_t length = end_offset - start_offset; - for (const auto i : c10::irange(start_offset, end_offset)) { - int64_t idx = indices[current]; - if (idx < 0 || idx >= data_size) { - return false; - } -#ifdef __GNUC__ - if (current + 1 < index_size) { - __builtin_prefetch( - input + fused_block_size * indices[current + 1], 0, 1); - } -#endif // __GNUC__ - - const float* scale_bias = reinterpret_cast( - input + fused_block_size * indices[current] + block_size); - - float weight = 1.0f; - if (weights) { - weight = weights[IS_WEIGHT_POSITIONAL ? i : current]; - } - const float scale = weight * scale_bias[0]; - const float bias = weight * scale_bias[1]; - - for (const auto j : c10::irange(block_size)) { - out[j] += scale * input[fused_block_size * indices[current] + j] + bias; - } - - ++current; - } - if (normalize_by_lengths && length) { - float scale = 1.f / length; - for (const auto j : c10::irange(block_size)) { - out[j] *= scale; - } - } - out += block_size; - } - return current == index_size; -} - -// clang-format off -// Proxy back to generic implementation -#define FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(IndexType, OutType) \ - bool \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - return Fused8BitRowwiseEmbeddingLookupGenericSlowIdx< \ - IndexType, \ - uint8_t, \ - OutType, \ - false>( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - decltype( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base) \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ - bool Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - const int32_t one = 1; \ - CAFFE_ENFORCE_EQ( \ - reinterpret_cast(&one)[0], \ - 1, \ - "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ - AVX2_FMA_DO( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - BASE_DO( \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - } \ - template <> \ - void Fused8BitRowwiseEmbeddingLookupIdx( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const 
int64_t data_size, \ - const uint8_t* input, \ - const IndexType* indices, \ - const IndexType* offsets, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - bool success = \ - Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - offsets, \ - weights, \ - normalize_by_lengths, \ - out); \ - if (success) { \ - return; \ - } \ - int64_t current = 0; \ - for (int m = 0; m < output_size; ++m) { \ - for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) { \ - CAFFE_ENFORCE_LT(current, index_size); \ - IndexType idx = indices[current]; \ - CAFFE_ENFORCE( \ - 0 <= idx && idx < data_size, \ - "Index ", \ - current, \ - " is out of bounds: ", \ - idx, \ - ", range 0 to ", \ - data_size); \ - ++current; \ - } \ - } \ - CAFFE_ENFORCE_EQ( \ - current, \ - index_size, \ - "Your input seems to be incorrect: the sum of lengths values should be " \ - "the size of the indices tensor, but it appears not."); \ - } -// clang-format on - -FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int32_t, float); -FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int64_t, float); - -#undef FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h deleted file mode 100644 index f7422bd7b752..000000000000 --- a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h +++ /dev/null @@ -1,57 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -/** - * Embedding lookup with reduction. - * - * `input` of size data_size * (block_size + 8B) - * `indices` of size index_size - * `offsets` of size output_size - * `weights` nullptr or array of size index_size - * `out` of size output_size * block_size - * - * Note that block_size should be the number of quantized values per row in the - * data, i.e. excluding the scale and bias. The total (fused) block size is - * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias. - * - * Behavior is roughly equivalent to pseudocode: - * - * pos = 0 - * fused_block_size = block_size + 8B // quantized values and scale and bias - * for (i = 0..output_size-1) - * for (k = 0..block_size-1) - * out[i*block_size + k] = 0 - * start_offset = offsets[i] - * end_offset = i == output_size-1 ? index_size : offsets[i+1] - 1 - * length = end_offset - start_offset - * for (j = start_offset..end_offset) - * for (k = 0..block_size-1) - * out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] * - * (weights ? weights[IS_WEIGHT_POSITIONAL ? 
j : pos] : 1.0) - * pos += 1 - * if (normalize_weights && length > 0) - * for (k = 0..block_size-1) - * out[i*block_size + k] /= length - * - */ - -template < - typename IndexType, - typename InType, - typename OutType, - bool IS_WEIGHT_POSITIONAL = false> -void Fused8BitRowwiseEmbeddingLookupIdx( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t data_size, - const InType* input, - const IndexType* indices, - const IndexType* offsets, - const float* weights, // optional, can be null for non-weighted sum - bool normalize_by_lengths, - OutType* out); -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc b/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc deleted file mode 100644 index 05cae2e280be..000000000000 --- a/caffe2/perfkernels/fused_nbit_rowwise_conversion.cc +++ /dev/null @@ -1,214 +0,0 @@ -#include "./fused_nbit_rowwise_conversion.h" - -#include -#include -#include - -#include "common.h" - -#ifdef USE_FBGEMM -#include "fbgemm/QuantUtils.h" -#endif - -namespace caffe2 { - -void FloatToFused8BitRowwiseQuantized__base( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { - constexpr float kEpsilon = 1e-8f; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - int output_columns = input_columns + 2 * sizeof(float); - for (std::size_t row = 0; row < input_rows; ++row) { - const float* input_row = input + row * input_columns; - std::uint8_t* output_row = output + row * output_columns; - float* output_row_scale_bias = - reinterpret_cast(output_row + input_columns); - - float minimum_element = - *std::min_element(input_row, input_row + input_columns); - float maximum_element = - *std::max_element(input_row, input_row + input_columns); - float range = maximum_element - minimum_element; - - output_row_scale_bias[0] = range / 255.0f; - output_row_scale_bias[1] = minimum_element; - const auto inverse_scale = 255.0f / (range + kEpsilon); - for (std::size_t col = 0; col < static_cast(input_columns); ++col) { - output_row[col] = - std::lrintf((input_row[col] - minimum_element) * inverse_scale); - } - } -} - -void Fused8BitRowwiseQuantizedToFloat__base( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - int output_columns = input_columns - 2 * sizeof(float); - - for (std::size_t row = 0; row < input_rows; ++row) { - const std::uint8_t* input_row = input + row * input_columns; - const float* input_row_scale_bias = - reinterpret_cast(input_row + output_columns); - float* output_row = output + row * output_columns; - - for (std::size_t col = 0; col < static_cast(output_columns); ++col) { - output_row[col] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; - } - } -} - -void FloatToFused8BitRowwiseQuantized( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { -#ifdef USE_FBGEMM - fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( - input, input_rows, input_columns, output); -#else - FloatToFused8BitRowwiseQuantized__base( - input, input_rows, input_columns, output); -#endif -} - -void Fused8BitRowwiseQuantizedToFloat( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { -#ifdef USE_FBGEMM - 
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( - input, input_rows, input_columns, output); -#else - Fused8BitRowwiseQuantizedToFloat__base( - input, input_rows, input_columns, output); -#endif -} - -void FloatToFusedNBitRowwiseQuantizedSBHalf__base( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { - int num_elem_per_byte = 8 / bit_rate; - int output_columns = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (input_columns + num_elem_per_byte - 1) / num_elem_per_byte + - 2 * sizeof(at::Half); - for (std::size_t row = 0; row < input_rows; ++row) { - const float* input_row = input + row * input_columns; - std::uint8_t* output_row = output + row * output_columns; - at::Half* output_row_scale_bias = reinterpret_cast( - output_row + - (input_columns + num_elem_per_byte - 1) / num_elem_per_byte); - - float minimum_element = - *std::min_element(input_row, input_row + input_columns); - float maximum_element = - *std::max_element(input_row, input_row + input_columns); - - minimum_element = static_cast(minimum_element); - const float range = maximum_element - minimum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1); - if (scale == 0) { - // Corner case handling when maximum_element == minimum_element - // Any scale would work because X - minimum_element will be 0 for all X - scale = 1.0f; - } - float inverse_scale = 1.0f / scale; - if (std::isinf(inverse_scale)) { - scale = 1.0f; - inverse_scale = 1.0f; - } - - output_row_scale_bias[0] = scale; - output_row_scale_bias[1] = minimum_element; - for (std::size_t col = 0; col < static_cast(input_columns); ++col) { - float X = input_row[col]; - std::uint8_t quantized = std::max( - 0, - std::min( - std::lrintf((X - minimum_element) * inverse_scale), - (1 << bit_rate) - 1)); - if (col % num_elem_per_byte == 0) { - output_row[col / num_elem_per_byte] = quantized; - } else { - output_row[col / num_elem_per_byte] |= - (quantized << ((col % num_elem_per_byte) * bit_rate)); - } - } - } -} - -void FusedNBitRowwiseQuantizedSBHalfToFloat__base( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { - int num_elem_per_byte = 8 / bit_rate; - int output_columns = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte; - - for (std::size_t row = 0; row < static_cast(input_rows); ++row) { - const std::uint8_t* input_row = input + row * input_columns; - const at::Half* input_row_scale_bias = reinterpret_cast( - input_row + - (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); - float scale = input_row_scale_bias[0]; - float bias = input_row_scale_bias[1]; - float* output_row = output + row * output_columns; - - for (std::size_t col = 0; col < static_cast(output_columns); ++col) { - std::uint8_t quantized = input_row[col / num_elem_per_byte]; - quantized >>= (col % num_elem_per_byte) * bit_rate; - quantized &= (1 << bit_rate) - 1; - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - output_row[col] = scale * quantized + bias; - } - } -} - -void FloatToFusedNBitRowwiseQuantizedSBHalf( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output) { -#ifdef USE_FBGEMM - 
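Editor's note: the N-bit base implementations above pack 8 / bit_rate codes per byte, each code occupying its own bit field, with a per-row fp16 scale and bias appended. A small sketch of just the sub-byte packing and unpacking is below; scale/bias handling is omitted and the helper names are illustrative.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Pack codes (each assumed to fit in bit_rate bits, bit_rate in {2, 4, 8})
    // into bytes, low bit-fields first, as the removed base quantizer does.
    std::vector<std::uint8_t> pack_codes(const std::vector<std::uint8_t>& codes,
                                         int bit_rate) {
      const int per_byte = 8 / bit_rate;
      std::vector<std::uint8_t> packed((codes.size() + per_byte - 1) / per_byte, 0);
      for (std::size_t col = 0; col < codes.size(); ++col) {
        packed[col / per_byte] |= static_cast<std::uint8_t>(
            codes[col] << ((col % per_byte) * bit_rate));
      }
      return packed;
    }

    // Recover one code: shift its bit field down and mask, as the base dequantizer does.
    std::uint8_t unpack_code(const std::vector<std::uint8_t>& packed,
                             std::size_t col, int bit_rate) {
      const int per_byte = 8 / bit_rate;
      std::uint8_t q = packed[col / per_byte];
      q >>= (col % per_byte) * bit_rate;
      return static_cast<std::uint8_t>(q & ((1 << bit_rate) - 1));
    }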
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( - bit_rate, input, input_rows, input_columns, output); -#else - FloatToFusedNBitRowwiseQuantizedSBHalf__base( - bit_rate, input, input_rows, input_columns, output); -#endif -} - -void FusedNBitRowwiseQuantizedSBHalfToFloat( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output) { -#ifdef USE_FBGEMM - fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( - bit_rate, input, input_rows, input_columns, output); -#else - FusedNBitRowwiseQuantizedSBHalfToFloat__base( - bit_rate, input, input_rows, input_columns, output); -#endif -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_nbit_rowwise_conversion.h b/caffe2/perfkernels/fused_nbit_rowwise_conversion.h deleted file mode 100644 index da9ec5c6cdd6..000000000000 --- a/caffe2/perfkernels/fused_nbit_rowwise_conversion.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include -#include - -namespace caffe2 { - -void FloatToFused8BitRowwiseQuantized( - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output); - -void Fused8BitRowwiseQuantizedToFloat( - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output); - -/** - * Row-wise quantization with fp16 scale and bias - * - * @param bit_rate can be 2, 4, or 8 - */ -void FloatToFusedNBitRowwiseQuantizedSBHalf( - int bit_rate, - const float* input, - size_t input_rows, - int input_columns, - std::uint8_t* output); - -void FusedNBitRowwiseQuantizedSBHalfToFloat( - int bit_rate, - const std::uint8_t* input, - size_t input_rows, - int input_columns, - float* output); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 7e4208caf655..26018c2c002c 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse diff --git a/caffe2/perfkernels/lstm_unit_cpu-impl.h b/caffe2/perfkernels/lstm_unit_cpu-impl.h deleted file mode 100644 index 239d2807f778..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu-impl.h +++ /dev/null @@ -1,141 +0,0 @@ -#pragma once -#include -#include -#include -#include "c10/util/irange.h" -#include "caffe2/utils/conversions.h" - -#include "vectorizer.h" - -namespace caffe2 { -namespace perfkernels { -namespace { -template -inline T sigmoid(T x) { - return 1 / (1 + std::exp(-x)); -} - -template -inline T host_tanh(T x) { - return 2 * sigmoid(2 * x) - 1; -} - -template -inline void LstmUnitImpl( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - const T forgetBias = convert::To(forget_bias); - for (const auto n : c10::irange(N)) { - const bool valid = seqLengths == nullptr || t < seqLengths[n]; - if (!valid) { - if (drop_states) { - memset(H, 0, sizeof(T) * D); - memset(C, 0, sizeof(T) * D); - } else { - memcpy(H, H_prev, sizeof(T) * D); - memcpy(C, C_prev, sizeof(T) * D); - } - } else { - const T* X_D = &X[D]; - const T* X_2D = &X[2 * D]; - const T* X_3D = &X[3 * D]; - VECTOR_LOOP for (const auto d : c10::irange(D)) { - const T i = sigmoid(X[d]); - const T f = sigmoid(X_D[d] + forgetBias); - const T o = sigmoid(X_2D[d]); - const T g = host_tanh(X_3D[d]); - const T c_prev = C_prev[d]; - const T c = f * c_prev + i * g; - C[d] = c; - const T host_tanh_c = host_tanh(c); - H[d] = o * host_tanh_c; - 
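Editor's note: the gate math inside LstmUnitImpl above is worth restating for a single hidden unit. The sketch below uses std::tanh in place of the kernel's host_tanh (which computes the identical 2*sigmoid(2x) - 1); the function and struct names are illustrative. The gate pre-activations are laid out in X as four consecutive blocks of size D: input, forget, output, candidate.

    #include <cmath>

    struct LstmCell { float c; float h; };

    inline float sigmoid_s(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // One element of the removed LSTM unit kernel.
    inline LstmCell lstm_unit_element(float x_i, float x_f, float x_o, float x_g,
                                      float c_prev, float forget_bias) {
      const float i = sigmoid_s(x_i);                // input gate
      const float f = sigmoid_s(x_f + forget_bias);  // forget gate with added bias
      const float o = sigmoid_s(x_o);                // output gate
      const float g = std::tanh(x_g);                // candidate cell value
      const float c = f * c_prev + i * g;            // new cell state
      return {c, o * std::tanh(c)};                  // new hidden state h = o * tanh(c)
    }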
} - } - H_prev += D; - C_prev += D; - X += 4 * D; - C += D; - H += D; - } -} - -template -inline void LstmUnitGradientImpl( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - const T localForgetBias = convert::To(forget_bias); - for (const auto n : c10::irange(N)) { - const bool valid = seqLengths == nullptr || t < seqLengths[n]; - - if (!valid) { - if (drop_states) { - memset(C_prev_diff, 0, sizeof(T) * D); - memset(H_prev_diff, 0, sizeof(T) * D); - } else { - memcpy(H_prev_diff, H_diff, sizeof(T) * D); - memcpy(C_prev_diff, C_diff, sizeof(T) * D); - } - memset(X_diff, 0, 4 * sizeof(T) * D); - } else { - VECTOR_LOOP for (const auto d : c10::irange(D)) { - T* c_prev_diff = C_prev_diff + d; - T* h_prev_diff = H_prev_diff + d; - T* i_diff = X_diff + d; - T* f_diff = X_diff + 1 * D + d; - T* o_diff = X_diff + 2 * D + d; - T* g_diff = X_diff + 3 * D + d; - - const T i = sigmoid(X[d]); - const T f = sigmoid(X[1 * D + d] + localForgetBias); - const T o = sigmoid(X[2 * D + d]); - const T g = host_tanh(X[3 * D + d]); - const T c_prev = C_prev[d]; - const T c = C[d]; - const T host_tanh_c = host_tanh(c); - const T c_term_diff = - C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c); - *c_prev_diff = c_term_diff * f; - *h_prev_diff = 0; // not used in 'valid' case - *i_diff = c_term_diff * g * i * (1 - i); - *f_diff = c_term_diff * c_prev * f * (1 - f); - *o_diff = H_diff[d] * host_tanh_c * o * (1 - o); - *g_diff = c_term_diff * i * (1 - g * g); - } - } - C_prev += D; - X += 4 * D; - C += D; - H += D; - C_diff += D; - H_diff += D; - X_diff += 4 * D; - H_prev_diff += D; - C_prev_diff += D; - } -} - -} // namespace -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu.h b/caffe2/perfkernels/lstm_unit_cpu.h deleted file mode 100644 index e9c87f3082f9..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu.h +++ /dev/null @@ -1,73 +0,0 @@ -#pragma once -#include - -namespace caffe2 { -namespace detail { - -// Forward declration of the LSTMUnit templated -// implementation -template -void LstmUnitCpu( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias); - -// Forward specialization -extern template void LstmUnitCpu( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template -void LstmUnitGradientCpu( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias); - -extern template void LstmUnitGradientCpu( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace detail -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_avx2.cc b/caffe2/perfkernels/lstm_unit_cpu_avx2.cc deleted file mode 100644 index 
ac66c6bd3f52..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_avx2.cc +++ /dev/null @@ -1,123 +0,0 @@ -#include "caffe2/perfkernels/lstm_unit_cpu-impl.h" - -namespace caffe2 { -namespace perfkernels { -namespace { -// Explicit initialize for float and AVX2 vectorization -template void LstmUnitImpl( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientImpl( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); -} // namespace - -// Define templated implementation fo LSTM kernels on CPU supporting AVX2 -template -void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - LstmUnitImpl( - N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias); -} - -template -void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - LstmUnitGradientImpl( - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); -} - -// Explicit initialize for float -template void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_common.cc b/caffe2/perfkernels/lstm_unit_cpu_common.cc deleted file mode 100644 index 72d97d832625..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_common.cc +++ /dev/null @@ -1,125 +0,0 @@ -#include "caffe2/perfkernels/lstm_unit_cpu_common.h" -#include "caffe2/perfkernels/common.h" -#include "caffe2/perfkernels/lstm_unit_cpu-impl.h" - -namespace caffe2 { -namespace detail { - -// Define templated implementation fo LSTM kernels on CPU -template -void LstmUnitCpu( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias) { - // Do CPU dispatching - AVX2_FMA_DO( - perfkernels::LstmUnitImpl, - N, - D, - t, - H_prev, - C_prev, - X, - seqLengths, - drop_states, - C, - H, - forget_bias); - perfkernels::LstmUnitImpl( - N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias); -} - -template -void LstmUnitGradientCpu( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, 
- const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias) { - // Do CPU dispatching - AVX2_FMA_DO( - perfkernels::LstmUnitGradientImpl, - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); - perfkernels::LstmUnitGradientImpl( - N, - D, - t, - C_prev, - X, - seqLengths, - C, - H, - C_diff, - H_diff, - drop_states, - H_prev_diff, - C_prev_diff, - X_diff, - forget_bias); -} - -// Explicit initialize for float -template void LstmUnitCpu( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -template void LstmUnitGradientCpu( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace detail -} // namespace caffe2 diff --git a/caffe2/perfkernels/lstm_unit_cpu_common.h b/caffe2/perfkernels/lstm_unit_cpu_common.h deleted file mode 100644 index d8680adf7d1d..000000000000 --- a/caffe2/perfkernels/lstm_unit_cpu_common.h +++ /dev/null @@ -1,71 +0,0 @@ -#pragma once -#include - -namespace caffe2 { -namespace perfkernels { - -template -void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const T* H_prev, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const bool drop_states, - T* C, - T* H, - const float forget_bias); - -template -void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const T* C_prev, - const T* X, - const int32_t* seqLengths, - const T* C, - const T* H, - const T* C_diff, - const T* H_diff, - bool drop_states, - T* H_prev_diff, - T* C_prev_diff, - T* X_diff, - const float forget_bias); - -// Forward declaration of specialized functions -extern template void LstmUnitImpl__avx2_fma( - const int N, - const int D, - const int t, - const float* H_prev, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const bool drop_states, - float* C, - float* H, - const float forget_bias); - -extern template void LstmUnitGradientImpl__avx2_fma( - int N, - int D, - int t, - const float* C_prev, - const float* X, - const int32_t* seqLengths, - const float* C, - const float* H, - const float* C_diff, - const float* H_diff, - bool drop_states, - float* H_prev_diff, - float* C_prev_diff, - float* X_diff, - const float forget_bias); - -} // namespace perfkernels -} // namespace caffe2 diff --git a/caffe2/perfkernels/math.h b/caffe2/perfkernels/math.h deleted file mode 100644 index 63380fc3f9a1..000000000000 --- a/caffe2/perfkernels/math.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -namespace math { - -// Returns the quantized and compressed values of floating inputs -// The "fused" representation stores the [bitwidth][tail][min][max] -// with the quantized data in one array. Since we store 8/bitwidth -// quantized data in one byte, the last buckets of some bytes may have -// unused bits. There are totally tail buckets are unused. -// We encode *bitwidth* and *tail* at the beginning, -// following by 32-bit floating data respresenting min and max. -// | bitwidth | tail | min | max | ... int8 data ... 
| -// | 1B | 1B | 4B | 4B | ...output_data....| -// In output_data: the b-th bucket of the i-th byte stores -// the i-th data of the b-th segment of input row - -void quantize_and_compress( - const float* input_data, - std::uint8_t* output_data, - std::uint64_t input_size, - std::uint64_t bitwidth, - bool random, - const float* random_buffer); - -void decompress_and_dequantize( - const std::uint8_t* input_data, - float* output_data, - std::uint64_t input_size); - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/math_cpu_avx2.cc b/caffe2/perfkernels/math_cpu_avx2.cc deleted file mode 100644 index 325d9c4591ef..000000000000 --- a/caffe2/perfkernels/math_cpu_avx2.cc +++ /dev/null @@ -1,246 +0,0 @@ -// Implements the math functions for CPU. -// The implementation in this file allows us to route the underlying numerical -// computation library to different compiler options (-mno-avx2 or -mavx2). - -#include -#include -#include - -#include - -using std::uint64_t; -using std::uint8_t; - -namespace caffe2 { - -namespace math { - -static constexpr double QEPSILON = 1e-8; - -void quantize_and_compress__avx2( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - __m256i shuffle_mask_v = _mm256_set_epi8( - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // 
NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00); - __m256i permute_mask_v = - _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); - - uint64_t data_per_byte = 8 / bitwidth; - uint64_t tail = input_size % data_per_byte; - tail = tail ? data_per_byte - tail : 0; - uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; - - // basic info - float minimum_element = INFINITY, maximum_element = -INFINITY; - for (const auto i : c10::irange(input_size)) { - minimum_element = - (input_data[i] < minimum_element) ? input_data[i] : minimum_element; - maximum_element = - (input_data[i] > maximum_element) ? input_data[i] : maximum_element; - } - output_data[0] = bitwidth; - output_data[1] = tail; - reinterpret_cast(output_data + 2)[0] = minimum_element; - reinterpret_cast(output_data + 2)[1] = maximum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap_inverse = 1. / (gap + QEPSILON); - uint8_t max_q = (1 << bitwidth) - 1; - uint64_t bit_start = 0; - if (random) { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? 
segment_size - : input_size - start; - uint64_t i = 0; - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]); - __m256 fval_v = _mm256_loadu_ps(input_data + start + i); - __m256 thetimes_v = _mm256_mul_ps( - _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)), - _mm256_set1_ps(gap_inverse)); - __m256 rounded_v = _mm256_floor_ps(_mm256_add_ps(thetimes_v, r_v)); - rounded_v = _mm256_max_ps( - _mm256_setzero_ps(), - _mm256_min_ps(_mm256_set1_ps(max_q), rounded_v)); - __m256i qval_v = _mm256_cvtps_epi32(rounded_v); - __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128( - reinterpret_cast(output_data + 10 + i))); - orval_v = - _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start)); - orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v); - orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v); - *reinterpret_cast(output_data + 10 + i) = - _mm256_extract_epi64(orval_v, 0); - } - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - float rounded = floor(thetimes + random_buffer[start + i]); - rounded = rounded < static_cast(max_q) - ? rounded - : static_cast(max_q); - rounded = rounded > 0.0f ? rounded : 0.0f; - uint8_t qval = rounded; - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } else { - // !random - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m256 fval_v = _mm256_loadu_ps(input_data + start + i); - __m256 thetimes_v = _mm256_mul_ps( - _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)), - _mm256_set1_ps(gap_inverse)); - thetimes_v = _mm256_max_ps( - _mm256_setzero_ps(), - _mm256_min_ps(_mm256_set1_ps(max_q), thetimes_v)); - __m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps( - thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128( - reinterpret_cast(output_data + 10 + i))); - orval_v = - _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start)); - orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v); - orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v); - *reinterpret_cast(output_data + 10 + i) = - _mm256_extract_epi64(orval_v, 0); - } - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - thetimes = thetimes < static_cast(max_q) - ? thetimes - : static_cast(max_q); - thetimes = thetimes > 0.0f ? 
thetimes : 0.0f; - uint8_t qval = nearbyint(thetimes); - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } // !random -} - -void decompress_and_dequantize__avx2( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - // basic info - const float minimum_element = - reinterpret_cast(input_data + 2)[0]; - const float maximum_element = - reinterpret_cast(input_data + 2)[1]; - const uint64_t bitwidth = input_data[0]; - const float gap = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + - QEPSILON; // for exact recovering - - const uint64_t tail = input_data[1]; - - const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; - // decoding - uint64_t bit_start = 0; - const uint64_t segment_size = input_size - 10; - for (uint64_t start = 0; start < output_size; start += segment_size) { - uint64_t stride = start + segment_size <= output_size ? segment_size - : output_size - start; - uint8_t mask = (1 << bitwidth) - 1; - uint64_t i = 0; - // Can process 8 elements at a time because we need to expand uint8_t - // to int32_t to use epi32 vector instructions. - constexpr int VLEN = 8; - for (; i < stride / VLEN * VLEN; i += VLEN) { - __m128i in_v = _mm_lddqu_si128( - reinterpret_cast(input_data + 10 + i)); - __m256i out_epi32_v = _mm256_and_si256( - _mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start), - _mm256_set1_epi32(mask)); - __m256 out_v = _mm256_fmadd_ps( - _mm256_cvtepi32_ps(out_epi32_v), - _mm256_set1_ps(gap), - _mm256_set1_ps(minimum_element)); - _mm256_storeu_ps(output_data + start + i, out_v); - } - for (; i < stride; ++i) { - output_data[start + i] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element; - } - bit_start += bitwidth; - } -} - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/math_cpu_base.cc b/caffe2/perfkernels/math_cpu_base.cc deleted file mode 100644 index fd3ba83cd4a9..000000000000 --- a/caffe2/perfkernels/math_cpu_base.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Implements the math functions for CPU. -// The implementation in this file allows us to route the underlying numerical -// computation library to different compiler options (-mno-avx2 or -mavx2). - -#include -#include -#include - -#include "common.h" -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include "math.h" - -#include - -using std::uint64_t; -using std::uint8_t; - -namespace caffe2 { - -namespace math { - -static constexpr double QEPSILON = 1e-8; - -void quantize_and_compress__base( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - uint64_t data_per_byte = 8 / bitwidth; - uint64_t tail = input_size % data_per_byte; - tail = tail ? data_per_byte - tail : 0; - uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte; - - // basic info - float minimum_element = INFINITY, maximum_element = -INFINITY; - for (const auto i : c10::irange(input_size)) { - minimum_element = - input_data[i] < minimum_element ? input_data[i] : minimum_element; - maximum_element = - input_data[i] > maximum_element ? 
input_data[i] : maximum_element; - } - output_data[0] = bitwidth; - output_data[1] = tail; - reinterpret_cast(output_data + 2)[0] = minimum_element; - reinterpret_cast(output_data + 2)[1] = maximum_element; - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f); - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - float gap_inverse = 1. / (gap + QEPSILON); - uint8_t max_q = (1 << bitwidth) - 1; - uint64_t bit_start = 0; - if (random) { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - float rounded = floor(thetimes + random_buffer[start + i]); - rounded = rounded < static_cast(max_q) - ? rounded - : static_cast(max_q); - rounded = rounded > 0.0f ? rounded : 0.0f; - uint8_t qval = rounded; - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } else { - for (uint64_t start = 0; start < input_size; start += segment_size) { - uint64_t stride = start + segment_size <= input_size ? segment_size - : input_size - start; - uint64_t i = 0; - for (; i < stride; ++i) { - float fval = input_data[start + i]; - float thetimes = (fval - minimum_element) * gap_inverse; - thetimes = thetimes < static_cast(max_q) - ? thetimes - : static_cast(max_q); - thetimes = thetimes > 0.0f ? thetimes : 0.0f; - uint8_t qval = nearbyint(thetimes); - - uint8_t orval = output_data[10 + i]; - output_data[10 + i] = orval | static_cast(qval << bit_start); - } - bit_start += bitwidth; - } - } -} - -decltype(quantize_and_compress__base) quantize_and_compress__avx2; -void quantize_and_compress( - const float* input_data, - uint8_t* output_data, - uint64_t input_size, - uint64_t bitwidth, - bool random, - const float* random_buffer) { - AVX2_DO( - quantize_and_compress, - input_data, - output_data, - input_size, - bitwidth, - random, - random_buffer); - BASE_DO( - quantize_and_compress, - input_data, - output_data, - input_size, - bitwidth, - random, - random_buffer); -} - -void decompress_and_dequantize__base( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - // basic info - const float minimum_element = - reinterpret_cast(input_data + 2)[0]; - const float maximum_element = - reinterpret_cast(input_data + 2)[1]; - const uint64_t bitwidth = input_data[0]; - const float gap = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) + - QEPSILON; // for exact recovering - - const uint64_t tail = input_data[1]; - - const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail; - // decoding - uint64_t bit_start = 0; - const uint64_t segment_size = input_size - 10; - for (uint64_t start = 0; start < output_size; start += segment_size) { - uint64_t stride = start + segment_size <= output_size ? 
segment_size - : output_size - start; - uint8_t mask = (1 << bitwidth) - 1; - uint64_t i = 0; - for (; i < stride; ++i) { - output_data[start + i] = - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) - ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element; - } - bit_start += bitwidth; - } -} - -decltype(decompress_and_dequantize__base) decompress_and_dequantize__avx2; -void decompress_and_dequantize( - const uint8_t* input_data, - float* output_data, - uint64_t input_size) { - AVX2_DO(decompress_and_dequantize, input_data, output_data, input_size); - BASE_DO(decompress_and_dequantize, input_data, output_data, input_size); -} - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy.cc b/caffe2/perfkernels/typed_axpy.cc deleted file mode 100644 index b8128ab951a4..000000000000 --- a/caffe2/perfkernels/typed_axpy.cc +++ /dev/null @@ -1,89 +0,0 @@ -#include "caffe2/perfkernels/typed_axpy.h" -#include "caffe2/core/types.h" -#include "caffe2/perfkernels/common.h" -#include "caffe2/utils/cpuid.h" - -namespace caffe2 { - -void TypedAxpy__base(int N, const float a, const float* x, float* y) { - for (int i = 0; i < N; ++i) { - y[i] += a * x[i]; - } -} - -decltype(TypedAxpy__base) TypedAxpy__avx2_fma; -decltype(TypedAxpy__base) TypedAxpy__avx_f16c; -template <> -void TypedAxpy(int N, const float a, const float* x, float* y) { - AVX2_FMA_DO(TypedAxpy, N, a, x, y); - AVX_F16C_DO(TypedAxpy, N, a, x, y); - BASE_DO(TypedAxpy, N, a, x, y); -} - -void TypedAxpyHalffloat__base( - int N, - const float a, - const at::Half* x, - float* y) { - for (int i = 0; i < N; ++i) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - union { - uint32_t intval; - float floatval; - } t1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t t2, t3; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval = x[i].x & 0x7fff; // Non-sign bits - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t2 = x[i].x & 0x8000; // Sign bit - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t3 = x[i].x & 0x7c00; // Exponent - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval <<= 13; // Align mantissa on MSB - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t2 <<= 16; // Shift sign bit into position - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - t1.intval += 0x38000000; // Adjust bias - t1.intval = (t3 == 0 ? 
0 : t1.intval); // Denormals-as-zero - t1.intval |= t2; // Re-insert sign bit - y[i] += t1.floatval * a; - } -} - -decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx2_fma; -decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx_f16c; -template <> -void TypedAxpy( - int N, - const float a, - const at::Half* x, - float* y) { - AVX2_FMA_DO(TypedAxpyHalffloat, N, a, x, y); - AVX_F16C_DO(TypedAxpyHalffloat, N, a, x, y); - BASE_DO(TypedAxpyHalffloat, N, a, x, y); -} - -void TypedAxpy_uint8_float__base( - int N, - const float a, - const std::uint8_t* x, - float* y) { - for (int i = 0; i < N; ++i) { - y[i] += (float)(x[i]) * a; - } -} - -decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx2_fma; -decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx_f16c; -template <> -void TypedAxpy( - int N, - const float a, - const std::uint8_t* x, - float* y) { - AVX2_FMA_DO(TypedAxpy_uint8_float, N, a, x, y); - BASE_DO(TypedAxpy_uint8_float, N, a, x, y); -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy.h b/caffe2/perfkernels/typed_axpy.h deleted file mode 100644 index 85b1adda0b9b..000000000000 --- a/caffe2/perfkernels/typed_axpy.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -namespace caffe2 { - -// Similar to Axpy that calculate y = a * x + y, but allowing x and y to be -// of different data types. -// It also provides a performance optimization hint (use_a) to see if a is going -// to be 1 or not. -template -void TypedAxpy(int N, const OUT a, const IN* x, OUT* y); - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy_avx.cc b/caffe2/perfkernels/typed_axpy_avx.cc deleted file mode 100644 index 2663cbc3ec79..000000000000 --- a/caffe2/perfkernels/typed_axpy_avx.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include -#include - -namespace caffe2 { - -void TypedAxpy__avx_f16c(int N, const float a, const float* x, float* y) { - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - __m256 mma = _mm256_set1_ps(a); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - _mm256_storeu_ps( - y + current, - _mm256_add_ps( - _mm256_mul_ps(mma, _mm256_loadu_ps(x + current)), - _mm256_loadu_ps(y + current))); - } - - if (bound != N) { - while (current < N) { - y[current] += x[current] * a; - ++current; - } - } -} - -void TypedAxpyHalffloat__avx_f16c( - int N, - const float a, - const at::Half* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += _cvtsh_ss((*(x++)).x) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? 
N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m128i mmx_16 = - _mm_loadu_si128(reinterpret_cast(x + current)); - __m256 mmx_32 = _mm256_cvtph_ps(mmx_16); - __m256 mmy_in = _mm256_loadu_ps(y + current); - __m256 mmmul = _mm256_mul_ps(mmx_32, mma); - __m256 mmy_out = _mm256_add_ps(mmmul, mmy_in); - _mm256_storeu_ps(y + current, mmy_out); - } - - if (bound != N) { - while (current < N) { - y[current] += _cvtsh_ss(x[current].x) * a; - ++current; - } - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/typed_axpy_avx2.cc b/caffe2/perfkernels/typed_axpy_avx2.cc deleted file mode 100644 index 2da1e7e379bd..000000000000 --- a/caffe2/perfkernels/typed_axpy_avx2.cc +++ /dev/null @@ -1,104 +0,0 @@ -#include "caffe2/perfkernels/cvtsh_ss_bugfix.h" - -#include -#include -#include - -namespace caffe2 { - -void TypedAxpy__avx2_fma(int N, const float a, const float* x, float* y) { - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - __m256 mma = _mm256_set1_ps(a); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - _mm256_storeu_ps( - y + current, - _mm256_fmadd_ps( - mma, _mm256_loadu_ps(x + current), _mm256_loadu_ps(y + current))); - } - - if (bound != N) { - while (current < N) { - y[current] += x[current] * a; - ++current; - } - } -} - -void TypedAxpyHalffloat__avx2_fma( - int N, - const float a, - const at::Half* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += _cvtsh_ss((*(x++)).x) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m128i mmx_16 = - _mm_loadu_si128(reinterpret_cast(x + current)); - __m256 mmx_32 = _mm256_cvtph_ps(mmx_16); - __m256 mmy = _mm256_loadu_ps(y + current); - mmy = _mm256_fmadd_ps(mmx_32, mma, mmy); - _mm256_storeu_ps(y + current, mmy); - } - - if (bound != N) { - while (current < N) { - y[current] += _cvtsh_ss(x[current].x) * a; - ++current; - } - } -} - -void TypedAxpy_uint8_float__avx2_fma( - int N, - const float a, - const std::uint8_t* x, - float* y) { - // if x does not start at the 16 byte boundary, we will process the first few. - // before we get to a real one. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - while ((reinterpret_cast(x) % 16) && N) { - *(y++) += static_cast(*(x++)) * a; - --N; - } - - // From now on we can do vectorized additions using __m256, which is 8 floats, - // so we will vectorize every 8 element and then resort to cvtsh_ss. - __m256 mma = _mm256_set1_ps(a); - int current = 0; - const int bound = (N % 8) ? 
N - 8 : N; - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - for (; current < bound; current += 8) { - __m256i mmx_int32 = _mm256_cvtepi8_epi32( - _mm_loadu_si128(reinterpret_cast(x + current))); - __m256 mmx_fp32 = _mm256_cvtepi32_ps(mmx_int32); - - __m256 mmy = _mm256_loadu_ps(y + current); - mmy = _mm256_fmadd_ps(mmx_fp32, mma, mmy); - _mm256_storeu_ps(y + current, mmy); - } - - if (bound != N) { - while (current < N) { - y[current] += (float)(x[current]) * a; - ++current; - } - } -} - -} // namespace caffe2 diff --git a/caffe2/perfkernels/vectorizer.h b/caffe2/perfkernels/vectorizer.h deleted file mode 100644 index be4e6bbc280f..000000000000 --- a/caffe2/perfkernels/vectorizer.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) -#if defined(__clang__) && (__clang_major__ > 7) -#define IS_SANITIZER \ - ((__has_feature(address_sanitizer) == 1) || \ - (__has_feature(memory_sanitizer) == 1) || \ - (__has_feature(thread_sanitizer) == 1) || \ - (__has_feature(undefined_sanitizer) == 1)) - -#if IS_SANITIZER == 0 -#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") -#define FAST_MATH _Pragma("clang fp contract(fast)") -#define VECTORIZED_KERNEL 1 -#endif -#elif defined(_OPENMP) && (_OPENMP >= 201511) -// Support with OpenMP4.5 and above -#define VECTOR_LOOP _Pragma("omp for simd") -#define VECTORIZED_KERNEL 1 -#define FAST_MATH -#endif -#endif - -#ifndef VECTOR_LOOP -// Not supported -#define VECTOR_LOOP -#define FAST_MATH -#endif diff --git a/caffe2/proto/BUILD.bazel b/caffe2/proto/BUILD.bazel deleted file mode 100644 index dcffaac0e3de..000000000000 --- a/caffe2/proto/BUILD.bazel +++ /dev/null @@ -1,55 +0,0 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library") - -cc_library( - name = "caffe2_pb", - hdrs = ["caffe2_pb.h"], - visibility = [ - "//:__pkg__", - ], - deps = [ - ":caffe2_cc_proto", - "//c10/core:base", - "//c10/util:base", - ], -) - -cc_proto_library( - name = "caffe2_cc_proto", - deps = [":caffe2_proto"], -) - -proto_library( - name = "caffe2_proto", - srcs = ["caffe2.proto"], -) - -cc_proto_library( - name = "torch_cc_proto", - visibility = ["//:__pkg__"], # used in torch - deps = [":torch_proto"], -) - -proto_library( - name = "torch_proto", - srcs = ["torch.proto"], - deps = [":caffe2_proto"], -) - -cc_proto_library( - name = "cc_proto", - visibility = ["//:__pkg__"], - deps = [":proto"], -) - -proto_library( - name = "proto", - srcs = [ - "caffe2_legacy.proto", - "hsm.proto", - "metanet.proto", - "predictor_consts.proto", - "prof_dag.proto", - ], - deps = [":caffe2_proto"], -) diff --git a/caffe2/proto/CMakeLists.txt b/caffe2/proto/CMakeLists.txt deleted file mode 100644 index bdbc045afb3d..000000000000 --- a/caffe2/proto/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -set(Caffe2_PROTOBUF_FILES "${CMAKE_CURRENT_SOURCE_DIR}/torch.proto;${CMAKE_CURRENT_SOURCE_DIR}/caffe2.proto") - -caffe2_protobuf_generate_cpp_py(Caffe2_PROTO_SRCS Caffe2_PROTO_HEADERS Caffe2_PROTO_PY ${Caffe2_PROTOBUF_FILES}) - -add_library(Caffe2_PROTO OBJECT ${Caffe2_PROTO_HEADERS} ${Caffe2_PROTO_SRCS}) - -if(MSVC) - if(BUILD_SHARED_LIBS) - set(TORCH_API_DEFINE "-DTORCH_API=__declspec(dllexport)") - else() - set(TORCH_API_DEFINE "-DTORCH_API=") - endif() -else() - set(TORCH_API_DEFINE "-DTORCH_API=") -endif() -target_compile_definitions( - Caffe2_PROTO PRIVATE ${TORCH_API_DEFINE}) - -install(FILES ${Caffe2_PROTO_HEADERS} DESTINATION include/caffe2/proto) 
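Editorial note on the perfkernels files removed above: each kernel was structured as a scalar __base implementation plus ISA-specific variants (__avx2_fma, __avx_f16c), with a thin wrapper selecting one at run time through the AVX2_FMA_DO / AVX_F16C_DO / BASE_DO macros that appear throughout these files (the real CPU check sits behind caffe2/utils/cpuid.h). The C++ sketch below is a minimal, hypothetical illustration of that dispatch pattern and of the scalar TypedAxpy semantics (y[i] += a * x[i] with x and y of different types); cpu_has_avx2_fma() is a placeholder, not the actual perfkernels API.

#include <cstdint>
#include <cstdio>

// Placeholder for the CPUID probe; the deleted code consults caffe2/utils/cpuid.h.
static bool cpu_has_avx2_fma() { return false; }

// Scalar reference kernel: y[i] += a * x[i], where x may be a narrower
// type than y (e.g. uint8_t accumulated into float).
template <typename IN, typename OUT>
void TypedAxpy__base(int N, const OUT a, const IN* x, OUT* y) {
  for (int i = 0; i < N; ++i) {
    y[i] += static_cast<OUT>(x[i]) * a;
  }
}

// Stand-in for the AVX2+FMA variant; the deleted typed_axpy_avx2.cc did the
// same work eight floats at a time with _mm256_fmadd_ps.
template <typename IN, typename OUT>
void TypedAxpy__avx2_fma(int N, const OUT a, const IN* x, OUT* y) {
  TypedAxpy__base<IN, OUT>(N, a, x, y);
}

// Simplified dispatcher standing in for the AVX2_FMA_DO / BASE_DO macros:
// call the fastest variant the CPU supports, otherwise fall back to __base.
template <typename IN, typename OUT>
void TypedAxpy(int N, const OUT a, const IN* x, OUT* y) {
  if (cpu_has_avx2_fma()) {
    return TypedAxpy__avx2_fma<IN, OUT>(N, a, x, y);
  }
  TypedAxpy__base<IN, OUT>(N, a, x, y);
}

int main() {
  const std::uint8_t x[4] = {1, 2, 3, 4};
  float y[4] = {0.f, 0.f, 0.f, 0.f};
  TypedAxpy<std::uint8_t, float>(4, 0.5f, x, y);
  std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);  // prints: 0.5 1 1.5 2
  return 0;
}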
diff --git a/caffe2/proto/__init__.py b/caffe2/proto/__init__.py deleted file mode 100644 index c40ca97189d1..000000000000 --- a/caffe2/proto/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -import warnings - - -# NOTE: we have to import python protobuf here **before** we load cpp extension. -# Otherwise it breaks under certain build conditions if cpp implementation of -# protobuf is used. Presumably there's some registry in protobuf library and -# python side has to initialize the dictionary first, before static -# initialization in python extension does so. Otherwise, duplicated protobuf -# descriptors will be created and it can lead to obscure errors like -# "Parameter to MergeFrom() must be instance of same class: -# expected caffe2.NetDef got caffe2.NetDef." -# -# This has to be done for all python targets, so listing them here -try: - from caffe2.proto import caffe2_pb2, metanet_pb2, torch_pb2 -except ImportError: - warnings.warn('Caffe2 support is no longer present in PyTorch.') - raise - -try: - from caffe2.caffe2.fb.session.proto import session_pb2 -except ImportError: - pass diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto deleted file mode 100644 index 077e7b0ed544..000000000000 --- a/caffe2/proto/caffe2.proto +++ /dev/null @@ -1,528 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// A few notes about the Caffe2's protobuffer convention: -// (1) Most objects are registered by their types, such as operators and nets. -// For these, we have a string-type field "type" for registration purposes. -// (2) We do not use extension because that used to create quite some conflicts -// in Caffe's protobuf design. -// (3) We have not used any proto3 specific features, such as Any or Map. This -// is mainly for backward compatibility purposes but we may consider using -// those in the future. - -// TensorProto stores serialized Tensor objects. -message TensorProto { - // The dimensions in the tensor. - repeated int64 dims = 1; - - // Data type - enum DataType { - UNDEFINED = 0; - - // Basic types - FLOAT = 1; // float - INT32 = 2; // int - BYTE = 3; // byte, when deserialized, is going to be restored as uint8 - STRING = 4; // string - - // Less-commonly used data types - BOOL = 5; // bool - UINT8 = 6; // uint8_t - INT8 = 7; // int8_t - UINT16 = 8; // uint16_t - INT16 = 9; // int16_t - INT64 = 10; // int64_t - FLOAT16 = 12; // at::Half - DOUBLE = 13; // double - - ZERO_COLLISION_HASH = 14; // zero-collision hash state - REBATCHING_BUFFER = 15; // rebatching buffer - } - // The type of the deserialized tensor data - optional DataType data_type = 2 [ default = FLOAT ]; - - // The format of the serialized data. - enum SerializationFormat { - // FMT_PROTOBUF is the existing serialization format from before the - // data_format field was introduced. Most data types are serialized using - // the protobuf typed fields, although in some cases raw little endian data - // is stored in the byte_data field instead. - FMT_PROTOBUF = 0; - // bfloat16 data stored in the raw_data field. - FMT_BFLOAT16 = 1; - } - // data_format is a SerializationFormat enum value. - // However, we intentionally store it as an integer value so we can - // distinguish between old messages that do not have a data_format value vs - // new messages that have a SerializationFormat value that we don't - // understand. If we stored this as an enum then protobuf would deserialize - // both of these cases the same way. 
- optional uint32 data_format = 15 [ default = 0 ]; - - // For float - repeated float float_data = 3 [ packed = true ]; - // For int32, uint8, int8, uint16, int16, bool, and float16 - // Note about float16: in storage we will basically convert float16 byte-wise - // to unsigned short and then store them in the int32_data field. - // Note: storing int8 and uint8 values in this field unfortunately results in - // larger serialized data than necessary, as protobuf's varint encoding - // scheme requires 2 bytes to represent int8 and uint8 values that have the - // MSB set. - repeated int32 int32_data = 4 [ packed = true ]; - // For bytes - optional bytes byte_data = 5; - // For strings - repeated bytes string_data = 6; - // For double - repeated double double_data = 9 [ packed = true ]; - // For int64 - repeated int64 int64_data = 10 [ packed = true ]; - // store the raw data, contents are serialized as little-endian - optional bytes raw_data = 13; - - // Optionally, a name for the tensor. - optional string name = 7; - - // Optionally, a TensorProto can contain the details about the device that - // it was serialized from. This is useful in cases like snapshotting a whole - // workspace in a multi-GPU environment. - optional DeviceOption device_detail = 8; - - // When loading from chunks this is going to indicate where to put data in the - // full array. When not used full data have to be present - message Segment { - required int64 begin = 1; - required int64 end = 2; - } - optional Segment segment = 11; - - // Field numbers 12 and 14 were previously used for now-deprecated fields. - // reserved 12, 14; -} - -message QTensorProto { - repeated int64 dims = 1; - required int32 precision = 2; - required double scale = 3; - required double bias = 4; - required bool is_signed = 5; - repeated int32 data = 6 [ packed = true ]; - optional string name = 7; - optional TensorProto.DataType data_type = 8 [ default = INT32 ]; - - // Multi-group quantization params - repeated double scales = 9; - repeated double biases = 10; - - // Multi-group quantization needed, indicates in which dimension - // we do the "group wise quantization" - optional int32 axis = 11; - - // It should be true if it is a multi-group quantization proto - optional bool is_multiparam = 12 [ default = false ]; -} - -// TensorProtos stores multiple TensorProto objects in one single proto. This -// is useful for small tensors; For anything big, consider using a DB for -// storage. -message TensorProtos { - repeated TensorProto protos = 1; -} - -message TensorShape { - repeated int64 dims = 1; - optional TensorProto.DataType data_type = 2 [ default = FLOAT ]; - repeated int32 unknown_dims = 3; - optional bool unknown_shape = 4 [ default = false ]; - optional string name = 5; -} - -message TensorShapes { - repeated TensorShape shapes = 1; -} - -// TensorBoundShape is used to save bound shape inference result for a tensor. -// TensorBoundShape.shape is inferred shape for this tensor. -// TensorBoundShape.dimType contains dim_type for every dimension. -// eg: for dimension i, shape.dims[i] is the inferred shape and -// dim_type[i] is corresponding dim_type. 
-message TensorBoundShape { - optional TensorShape shape = 1; - enum DimType { - UNKNOWN = 0; // unknown - CONSTANT = 1; // constant - // batch, corresponding dimension is batch_size - BATCH = 2; - // batch_of_feature_max, - // corresponding shape is inferred_feature_length * batch_size - BATCH_OF_FEATURE_MAX = 3; - // batch_of_feature_max_default - // corresponding shape is default_feature_length * batch_size - BATCH_OF_FEATURE_MAX_DEFAULT = 4; - // feature_max, corresponding shape is inferred_feature_length - FEATURE_MAX = 5; - // feature_max_default, corresponding shape is default_feature_length - FEATURE_MAX_DEFAULT = 6; - } - repeated DimType dim_type = 2; // dim_type.size() == shape.dims.size() - optional string name = 3; - // a flag to indicate whether the shape is final and cannot be changed - // eg: input/output of in-place ops - optional bool shape_is_final = 4; -} - -message TensorBoundShapes { - repeated TensorBoundShape shapes = 1; - optional int64 max_batch_size = 2; - optional int64 max_feature_len = 3; -} - -message AOTConfig { - required int64 max_batch_size = 1; - required int64 max_seq_size = 2; - required bool in_batch_broadcast = 3; - optional string onnxifi_blacklist_ops = 4; - optional int32 onnxifi_min_ops = 5; -} - -// A named argument containing either singular float, integer and string -// values, or repeated float, int and string arrays. -message Argument { - optional string name = 1; - - optional float f = 2; - optional int64 i = 3; - optional bytes s = 4; - optional TensorProto t = 10; - optional NetDef n = 8; - - repeated float floats = 5; - repeated int64 ints = 6; - repeated bytes strings = 7; - repeated TensorProto tensors = 11; - repeated NetDef nets = 9; - repeated QTensorProto qtensors = 12; -} - -// DeviceType that Caffe2 currently supports. -// Note: if you add a device type, make sure you add the corresponding device -// line in the DeviceTypeName() function in caffe2/utils/proto_utils.cc -// and update c10/core/DeviceType.h -enum DeviceTypeProto { - PROTO_CPU = 0; // In default, we will use CPU. - PROTO_CUDA = 1; // CUDA. - PROTO_MKLDNN = 2; // Reserved for explicit MKLDNN - PROTO_OPENGL = 3; // OpenGL - PROTO_OPENCL = 4; // OpenCL - PROTO_IDEEP = 5; // IDEEP. - PROTO_HIP = 6; // AMD HIP - PROTO_FPGA = 7; // FPGA - PROTO_MAIA = 8; // MAIA - PROTO_XLA = 9; // XLA / TPU - PROTO_MPS = 10; // MPS - // Change the following number if you add more devices in the code. - PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11; -} - -// Device-specific options. We do not distinguish DeviceOption protos for -// different DeviceTypes, so currently all devices share the same DeviceOption -// proto. Fields that are specific to a device type is ignored if the type does -// not match. -// Note: if you add fields to the DeviceOption, make sure you add the -// corresponding changes to IsSameDevice() function in utils/proto_utils.{h,cc}. -message DeviceOption { - // [general] Options that need to be carried out before running the execution. - // optional DeviceType device_type = 1 [ default = CPU ]; - optional int32 device_type = 1 [ default = 0 ]; // 0 is CPU. - // [general] Used together with device_type to identify the exact device - optional int32 device_id = 2; - // [general] The random seed to start the device random number generator with. - optional uint32 random_seed = 3; - // [general] What node this op should execute on. - // Used for net transformation purposes. Must be empty at execution time. 
- optional string node_name = 4; - // [CPU and Linux specific] NUMA node id - optional int32 numa_node_id = 5; - // [general] Extra information passed, not used at execution time currently. - repeated string extra_info = 6; -} - -// Operator Definition. -message OperatorDef { - repeated string input = 1; // the name of the input blobs - repeated string output = 2; // the name of output top blobs - optional string name = 3; // the operator name. This is optional. - // the operator type. This is needed to create the object from the operator - // registry. - optional string type = 4; - // arg is for the argument defined in operator schema - repeated Argument arg = 5; - - // The device option that the operator should run under. - optional DeviceOption device_option = 6; - - // Optionally, one can specify an engine when there are multiple - // implementations available simultaneously for one device type. - // If one specifies an engine but that engine does not exist in the compiled - // Caffe2 binary, Caffe2 will fall back to the default engine of that device - // type. - optional string engine = 7; - - // Additional 'fake' inputs used for expressing control dependencies - // in the operator graph. This can be used to ensure that an - // operator does not run until another operator is ready, for e.g. - // scheduling control. These are not passed as actual inputs to the - // Operator implementation, and are only used by the Net class for - // scheduling purposes. - repeated string control_input = 8; - - // is_gradient_op argument is only used as a hint in shape inference - // and has no runtime significance - optional bool is_gradient_op = 9 [ default = false ]; - - // debug information associated with the construction of the operator. - // This is an optional string with no assumed characteristics as - // operators can be constructed in any language. - optional string debug_info = 10; - - // the domain of the operator to help runtime distinguish which operator - // library this OperatorDef refers to. For example, both caffe2 and aten - // has `Add` operator, with domain, we can easily decide which operator - // to execute. to support multiple operator libs, we use domain to - // distinguish which operator lib we refer to: - // - "caffe2" means this uses Caffe2 operator library - // - "aten" means this uses ATen operator library - // - "c10" is for the fused library - // - if the domain is missing or empty, we use "caffe2", this is for - // legacy models, new serializer should always export an OperatorDef - // with domain and op_version - optional string domain = 11; - // each operator is has its own version number. - // operator version information - // each time, we change the API or semantics of the operator, - // we bump the version for the operator. - // the runtime system should check the op_version of each OperatorDef - // and decide it should reject or accept the model - optional int64 op_version = 12; -} - -// MapFieldEntry follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message MapFieldEntry { - required string key = 1; - required string val = 2; -}; - -// Used to hold backend-specific options. -message BackendOptions { - // Name of the backend that the specified options apply to. - required string backend_name = 1; - // Flexible map for passing in the options. - repeated MapFieldEntry option = 2; -}; - -// Partition definition. -message PartitionInfo { - // Name of the partition. 
- required string name = 1; - - // A list of logic device ID, indicating which devices this partition - // can be executed on. If empty, it means the partition won't run on - // device but on host CPU instead. - repeated int32 device_id = 2; - - // Extra debug info. - optional string extra_info = 3; - - // Flexible map for passing options specific to a backend. - repeated BackendOptions backend_options = 4; -} - -// Network definition. -message NetDef { - optional string name = 1; // the network's name - // Operators that the network contains. - // Note: this is not named "operator" because that is a reserved word in C++. - repeated OperatorDef op = 2; - - // The type of network that the net should be run with. This routes the - // network instantiation to different execution modes. The default mode, - // "simple", runs the operators in a sequential way as the original Caffe - // implementation does. - optional string type = 3; - - // the number of workers, if the operators in the network is to be carried out - // in parallel. - // Note: This is to be deprecated. Using the arg field with "num_workers" as - // key. - // Note 2: The old uses of this were never actually cleaned up - optional int32 num_workers = 4; - - // The device option for the network. If a network has a specific device - // option and one of its operators does not have it set, we will copy over the - // device option to the operator. This allows us to basically avoid putting - // device options at every operator. - optional DeviceOption device_option = 5; - - repeated Argument arg = 6; - - // Two optional fields to declare external input and output of a net. - // If these two are set, when a net is created, we will sanity check for - // every op whether its input is declared (either as an external input, - // or as an intermediate blob created by one of the ops), and sanity check - // if all blobs in external_output are produced. - // - // In cases of memory optimization, declaring external_input and - // external_output also ensures that storage of these blobs are persistent: - // for any blob in external_input and external_output, after a network run - // finishes, their content are actually the right content. Any intermediate - // blobs' contents may be overwritten. - repeated string external_input = 7; - repeated string external_output = 8; - - // Partitioning info, indexed by partition names. - repeated PartitionInfo partition_info = 9; -} - -// ExecutionStep is actually a sort-of-hacky way we simulate iteration right -// now. -message ExecutionStep { - // ExecutionStep should either contain a set of substeps, or a set of - // network names to run in this execution step. They should NOT both be set - // at the same time. - optional string name = 1; - // An execution step could be recursive, in which it involves a set of - // substeps. - repeated ExecutionStep substep = 2; - // Alternatively, an execution step could involve one or more networks. - // Note that you cannot have both substeps and networks. Choose one. - // Note that an execution step refers networks by their name. The actual - // network definition of the same name should be included in the network field - // of the plan. The reason is that a network object might hold internal states - // (think of a data layer), so we want to have the same network object that - // multiple steps could ask to run. - repeated string network = 3; - // Number of iterations to run this step. 
The substeps or the networks - // specified will be run sequentially, and one sequential run is considered - // one iteration. If this is not set, the number of iterations is assumed to - // be 1. - optional int64 num_iter = 4; - - // Criteria network specifies a single output (TensorCPU) of - // size (1), is run on every iteration by the executor, and - // execution terminates when the output[0] is `false`. - optional string criteria_network = 5 [ deprecated = true ]; - - // DEPRECATED. Use `run_every_ms`. - optional string report_net = 7; - optional int32 report_interval = 8; - - // If provided, execute this step at every time interval (in millisecs) - // while its sibiling execution steps execute in parallel. This step is - // guaranteed to run at least once after all non-interval siblings finished. - optional int64 run_every_ms = 11; - - // If false or not set, execute sub-steps serially. - // If true, execute all substeps concurrently, each one in a separate thread. - optional bool concurrent_substeps = 6; - - // Name of a scalar boolean tensor. - // ES checks this blob AFTER every substeps/subnets. - // If specified, and the value is true, then ES will skip the rest and return - // immediately. - // This means that the report_net and the first step will always be called. - // Use cases: - // 1) the first substep stops the rest if data condition not met - // 2) the first substep decide which of the rest of the steps should be run. - // 3) external control - // - // ** It is the user's responsibility to not to put this blob in race - // conditions. - // ** For example when setting this blob in concurrent substeps - optional string should_stop_blob = 9; - - // if only_once is true, this step will only be executed once. this ONLY takes - // effect when using should_stop_blob - optional bool only_once = 10; - - // Whether to create a child workspace for this step. - // If yes, the workflow and nets are re-created every time this step is run. - optional bool create_workspace = 12; - - // How many copies of the children execution steps to run concurrently. - optional int32 num_concurrent_instances = 13; -} - -message PlanDef { - // All the networks that are used in this execution. Note that networks should - // be ordered in the way they are executed, i.e. for a layer in a network, all - // its input blobs should already have been initialized by the layers or - // networks defined before it. - optional string name = 1; - // The networks that are going to be used in this plan. - repeated NetDef network = 2; - repeated ExecutionStep execution_step = 3; -} - -// Protobuf format for blobs that are not Tensors. We use a key to store the -// type of the blob. For example for a serialized DBProto, the type should -// be "DBReader" and the content should be a serialized DBProto object. -message BlobProto { - optional string name = 1; - optional string type = 2; - optional TensorProto tensor = 3; - optional bytes content = 4; - optional QTensorProto qtensor = 5; - // If blob is not Tensor and is divided into chunks, content_num_chunks - // contains number of chunks, into which blob was divided. - optional int32 content_num_chunks = 6; - optional int32 content_chunk_id = 7; -} - -// Protobuf format to serialize DBReader. -message DBReaderProto { - // The name for the DB object in the workspace. - optional string name = 1; - // The source of the DB - optional string source = 2; - // The type of the DB - optional string db_type = 3; - // The current key of the DB if the DB supports seeking. 
- optional string key = 4; -} - -message BlobSerializationOptions { - // This set of options will only apply to blobs whose name matches this - // pattern. If the blob_name_pattern is empty then it will be treated as - // matching all blobs. - optional string blob_name_regex = 1; - - // Note: - // - a chunk_size of 0 means "use the default chunk size". The default chunk - // size is controlled by the --caffe2_tensor_chunk_size command line flag. - // - a chunk size of -1 means to disable chunking, and serialize the blob in - // a single chunk. - optional int64 chunk_size = 2; - - enum FloatFormat { - // Use the current default serialization format, as chosen by the - // current version of the code. (At the time of writing this is PROTOBUF) - FLOAT_DEFAULT = 0; - // Store the data in the TensorProto's float_data field - FLOAT_PROTOBUF = 1; - // Serialize float values as bfloat16. Note that this conversion is lossy. - FLOAT_BFLOAT16 = 2; - } - - // Settings for how to serialize tensors containing float values - optional FloatFormat float_format = 3; -} - -message SerializationOptions { - // A set of options to use when serialializing blobs. - // This is a list, sorted from highest to lowest precedence. When - // serializing a blob, the first entry whose blob_name_pattern matches the - // blob name will be used. - repeated BlobSerializationOptions options = 1; -} diff --git a/caffe2/proto/caffe2_legacy.proto b/caffe2/proto/caffe2_legacy.proto deleted file mode 100644 index 4fb2cda002fe..000000000000 --- a/caffe2/proto/caffe2_legacy.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// Original Caffe1 Datum copy: this is used in image input op to allow us to -// load caffe1 serialized datum without having to regenerate the database. -message CaffeDatum { - optional int32 channels = 1; - optional int32 height = 2; - optional int32 width = 3; - // the actual image data, in bytes - optional bytes data = 4; - optional int32 label = 5; - // Optionally, the datum could also hold float data. - repeated float float_data = 6; - // If true data contains an encoded image that need to be decoded - optional bool encoded = 7 [ default = false ]; -} - -enum LegacyPadding { - NOTSET = 0; // Do not use old-stype padding strategies. - - // VALID and SAME are two strategies adopted in Google DistBelief: it forces - // the input shape as follows. For SAME, the output is: - // R_out = ceil(float(R) / float(S)) - // C_out = ceil(float(C) / float(S)) - // where R and C are row and column, S is the stride, and K is the kernel. - // The number of padded pixels is then computed as - // Pr = ((R_out - 1) * S + K - R) - // Pc = ((C_out - 1) * S + K - C) - // When Pr and Pc are even numbers, both sides (left and right, or top and - // bottom) get half each. When Pr and Pc are odd numbers, the right and the - // bottom gets the one additional padding pixel. - // For VALID, padding values of 0 are always used. - VALID = 1; - SAME = 2; - - // CAFFE_LEGACY_POOLING is a flag that notifies the code to use the old Caffe - // padding strategy. - // Basically, in caffe2, after padding the convolution and pooling use the - // same computation strategy: half-windows at the right and bottom are - // discarded. In Caffe, convolution follows this strategy but if there are - // some pixels in the half-windows, the pooling layer will actually put one - // additional output. 
If you set LegacyPadding to this, we will compute the - // equivalent padding strategy in caffe2 so that the output size is - // backward compatible with Caffe. - // THIS IS NOW DEPRECATED. ANY non-conventional use has to be manually - // converted. - CAFFE_LEGACY_POOLING = 3; -} diff --git a/caffe2/proto/caffe2_legacy_pb2.pyi b/caffe2/proto/caffe2_legacy_pb2.pyi deleted file mode 100644 index eaee65471eef..000000000000 --- a/caffe2/proto/caffe2_legacy_pb2.pyi +++ /dev/null @@ -1,58 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___LegacyPadding = LegacyPadding -class _LegacyPadding(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[LegacyPadding], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - NOTSET = LegacyPadding.V(0) - VALID = LegacyPadding.V(1) - SAME = LegacyPadding.V(2) - CAFFE_LEGACY_POOLING = LegacyPadding.V(3) -class LegacyPadding(metaclass=_LegacyPadding): - V = typing.NewType('V', int) -NOTSET = LegacyPadding.V(0) -VALID = LegacyPadding.V(1) -SAME = LegacyPadding.V(2) -CAFFE_LEGACY_POOLING = LegacyPadding.V(3) - -class CaffeDatum(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - CHANNELS_FIELD_NUMBER: int - HEIGHT_FIELD_NUMBER: int - WIDTH_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - LABEL_FIELD_NUMBER: int - FLOAT_DATA_FIELD_NUMBER: int - ENCODED_FIELD_NUMBER: int - channels: int = ... - height: int = ... - width: int = ... - data: bytes = ... - label: int = ... - float_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - encoded: bool = ... - - def __init__(self, - *, - channels : typing.Optional[int] = ..., - height : typing.Optional[int] = ..., - width : typing.Optional[int] = ..., - data : typing.Optional[bytes] = ..., - label : typing.Optional[int] = ..., - float_data : typing.Optional[typing.Iterable[float]] = ..., - encoded : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"channels",b"channels",u"data",b"data",u"encoded",b"encoded",u"height",b"height",u"label",b"label",u"width",b"width"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"channels",b"channels",u"data",b"data",u"encoded",b"encoded",u"float_data",b"float_data",u"height",b"height",u"label",b"label",u"width",b"width"]) -> None: ... 
-global___CaffeDatum = CaffeDatum diff --git a/caffe2/proto/caffe2_pb.h b/caffe2/proto/caffe2_pb.h deleted file mode 100644 index fc82659dc51d..000000000000 --- a/caffe2/proto/caffe2_pb.h +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include -#include -#include - -namespace caffe2 { - -using DeviceType = at::DeviceType; -constexpr DeviceType CPU = DeviceType::CPU; -constexpr DeviceType CUDA = DeviceType::CUDA; -constexpr DeviceType OPENGL = DeviceType::OPENGL; -constexpr DeviceType OPENCL = DeviceType::OPENCL; -constexpr DeviceType MKLDNN = DeviceType::MKLDNN; -constexpr DeviceType IDEEP = DeviceType::IDEEP; -constexpr DeviceType HIP = DeviceType::HIP; -constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES = - DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; - -inline TORCH_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) { - switch (p) { - case caffe2::PROTO_CPU: - return DeviceType::CPU; - case caffe2::PROTO_CUDA: - return DeviceType::CUDA; - case caffe2::PROTO_OPENGL: - return DeviceType::OPENGL; - case caffe2::PROTO_OPENCL: - return DeviceType::OPENCL; - case caffe2::PROTO_MKLDNN: - return DeviceType::MKLDNN; - case caffe2::PROTO_IDEEP: - return DeviceType::IDEEP; - case caffe2::PROTO_HIP: - return DeviceType::HIP; - case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES: - return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; - default: - AT_ERROR( - "Unknown device:", - static_cast(p), - ". If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } -} - -inline TORCH_API DeviceType ProtoToType(int p) { - return ProtoToType(static_cast(p)); -} - -inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) { - switch (t) { - case DeviceType::CPU: - return caffe2::PROTO_CPU; - case DeviceType::CUDA: - return caffe2::PROTO_CUDA; - case DeviceType::OPENGL: - return caffe2::PROTO_OPENGL; - case DeviceType::OPENCL: - return caffe2::PROTO_OPENCL; - case DeviceType::MKLDNN: - return caffe2::PROTO_MKLDNN; - case DeviceType::IDEEP: - return caffe2::PROTO_IDEEP; - case DeviceType::HIP: - return caffe2::PROTO_HIP; - case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: - return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES; - default: - AT_ERROR( - "Unknown device:", - static_cast(t), - ". If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } -} - -inline TORCH_API caffe2::DeviceOption DeviceToOption(const at::Device& device) { - caffe2::DeviceOption option; - auto type = device.type(); - option.set_device_type(TypeToProto(type)); - - switch (type) { - case DeviceType::CPU: - if (device.index() != -1) { - option.set_numa_node_id(device.index()); - } - break; - case DeviceType::CUDA: - case DeviceType::HIP: - option.set_device_id(device.index()); - break; - case DeviceType::OPENGL: - case DeviceType::OPENCL: - case DeviceType::MKLDNN: - case DeviceType::IDEEP: - case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES: - break; - default: - AT_ERROR( - "Unknown device:", - static_cast(type), - ". 
If you have recently updated the caffe2.proto file to add a new " - "device type, did you forget to update the ProtoToType() and TypeToProto" - "function to reflect such recent changes?"); - } - return option; -} - -inline TORCH_API at::Device OptionToDevice(const caffe2::DeviceOption& option) { - auto type = option.device_type(); - c10::DeviceIndex id = -1; - switch (type) { - case caffe2::PROTO_CPU: - if (option.has_numa_node_id()) { - id = static_cast(option.numa_node_id()); - } - break; - case caffe2::PROTO_CUDA: - case caffe2::PROTO_HIP: - id = static_cast(option.device_id()); - break; - } - return at::Device(ProtoToType(type), id); -} - -inline void ExtractDeviceOption( - DeviceOption* device_option, - const at::Device& device) { - AT_ASSERT(device_option); - device_option->CopyFrom(DeviceToOption(device)); -} - -} // namespace caffe2 diff --git a/caffe2/proto/caffe2_pb2.pyi b/caffe2/proto/caffe2_pb2.pyi deleted file mode 100644 index 43249ebf75db..000000000000 --- a/caffe2/proto/caffe2_pb2.pyi +++ /dev/null @@ -1,767 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___DeviceTypeProto = DeviceTypeProto -class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DeviceTypeProto], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - PROTO_CPU = DeviceTypeProto.V(0) - PROTO_CUDA = DeviceTypeProto.V(1) - PROTO_MKLDNN = DeviceTypeProto.V(2) - PROTO_OPENGL = DeviceTypeProto.V(3) - PROTO_OPENCL = DeviceTypeProto.V(4) - PROTO_IDEEP = DeviceTypeProto.V(5) - PROTO_HIP = DeviceTypeProto.V(6) - PROTO_FPGA = DeviceTypeProto.V(7) - PROTO_MAIA = DeviceTypeProto.V(8) - PROTO_XLA = DeviceTypeProto.V(9) - PROTO_MPS = DeviceTypeProto.V(10) - PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) -class DeviceTypeProto(metaclass=_DeviceTypeProto): - V = typing.NewType('V', int) -PROTO_CPU = DeviceTypeProto.V(0) -PROTO_CUDA = DeviceTypeProto.V(1) -PROTO_MKLDNN = DeviceTypeProto.V(2) -PROTO_OPENGL = DeviceTypeProto.V(3) -PROTO_OPENCL = DeviceTypeProto.V(4) -PROTO_IDEEP = DeviceTypeProto.V(5) -PROTO_HIP = DeviceTypeProto.V(6) -PROTO_FPGA = DeviceTypeProto.V(7) -PROTO_MAIA = DeviceTypeProto.V(8) -PROTO_XLA = DeviceTypeProto.V(9) -PROTO_MPS = DeviceTypeProto.V(10) -PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) - -class TensorProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _DataType(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DataType], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... 
- UNDEFINED = TensorProto.DataType.V(0) - FLOAT = TensorProto.DataType.V(1) - INT32 = TensorProto.DataType.V(2) - BYTE = TensorProto.DataType.V(3) - STRING = TensorProto.DataType.V(4) - BOOL = TensorProto.DataType.V(5) - UINT8 = TensorProto.DataType.V(6) - INT8 = TensorProto.DataType.V(7) - UINT16 = TensorProto.DataType.V(8) - INT16 = TensorProto.DataType.V(9) - INT64 = TensorProto.DataType.V(10) - FLOAT16 = TensorProto.DataType.V(12) - DOUBLE = TensorProto.DataType.V(13) - ZERO_COLLISION_HASH = TensorProto.DataType.V(14) - REBATCHING_BUFFER = TensorProto.DataType.V(15) - class DataType(metaclass=_DataType): - V = typing.NewType('V', int) - UNDEFINED = TensorProto.DataType.V(0) - FLOAT = TensorProto.DataType.V(1) - INT32 = TensorProto.DataType.V(2) - BYTE = TensorProto.DataType.V(3) - STRING = TensorProto.DataType.V(4) - BOOL = TensorProto.DataType.V(5) - UINT8 = TensorProto.DataType.V(6) - INT8 = TensorProto.DataType.V(7) - UINT16 = TensorProto.DataType.V(8) - INT16 = TensorProto.DataType.V(9) - INT64 = TensorProto.DataType.V(10) - FLOAT16 = TensorProto.DataType.V(12) - DOUBLE = TensorProto.DataType.V(13) - ZERO_COLLISION_HASH = TensorProto.DataType.V(14) - REBATCHING_BUFFER = TensorProto.DataType.V(15) - - class _SerializationFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[SerializationFormat], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - FMT_PROTOBUF = TensorProto.SerializationFormat.V(0) - FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1) - class SerializationFormat(metaclass=_SerializationFormat): - V = typing.NewType('V', int) - FMT_PROTOBUF = TensorProto.SerializationFormat.V(0) - FMT_BFLOAT16 = TensorProto.SerializationFormat.V(1) - - class Segment(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BEGIN_FIELD_NUMBER: int - END_FIELD_NUMBER: int - begin: int = ... - end: int = ... - - def __init__(self, - *, - begin : typing.Optional[int] = ..., - end : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"begin",b"begin",u"end",b"end"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"begin",b"begin",u"end",b"end"]) -> None: ... - - DIMS_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - DATA_FORMAT_FIELD_NUMBER: int - FLOAT_DATA_FIELD_NUMBER: int - INT32_DATA_FIELD_NUMBER: int - BYTE_DATA_FIELD_NUMBER: int - STRING_DATA_FIELD_NUMBER: int - DOUBLE_DATA_FIELD_NUMBER: int - INT64_DATA_FIELD_NUMBER: int - RAW_DATA_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - DEVICE_DETAIL_FIELD_NUMBER: int - SEGMENT_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - data_type: global___TensorProto.DataType = ... - data_format: int = ... - float_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - int32_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - byte_data: bytes = ... - string_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[bytes] = ... - double_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - int64_data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - raw_data: bytes = ... - name: typing.Text = ... - - @property - def device_detail(self) -> global___DeviceOption: ... - - @property - def segment(self) -> global___TensorProto.Segment: ... 
- - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - data_format : typing.Optional[int] = ..., - float_data : typing.Optional[typing.Iterable[float]] = ..., - int32_data : typing.Optional[typing.Iterable[int]] = ..., - byte_data : typing.Optional[bytes] = ..., - string_data : typing.Optional[typing.Iterable[bytes]] = ..., - double_data : typing.Optional[typing.Iterable[float]] = ..., - int64_data : typing.Optional[typing.Iterable[int]] = ..., - raw_data : typing.Optional[bytes] = ..., - name : typing.Optional[typing.Text] = ..., - device_detail : typing.Optional[global___DeviceOption] = ..., - segment : typing.Optional[global___TensorProto.Segment] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"byte_data",b"byte_data",u"data_format",b"data_format",u"data_type",b"data_type",u"device_detail",b"device_detail",u"name",b"name",u"raw_data",b"raw_data",u"segment",b"segment"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"byte_data",b"byte_data",u"data_format",b"data_format",u"data_type",b"data_type",u"device_detail",b"device_detail",u"dims",b"dims",u"double_data",b"double_data",u"float_data",b"float_data",u"int32_data",b"int32_data",u"int64_data",b"int64_data",u"name",b"name",u"raw_data",b"raw_data",u"segment",b"segment",u"string_data",b"string_data"]) -> None: ... -global___TensorProto = TensorProto - -class QTensorProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - PRECISION_FIELD_NUMBER: int - SCALE_FIELD_NUMBER: int - BIAS_FIELD_NUMBER: int - IS_SIGNED_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - SCALES_FIELD_NUMBER: int - BIASES_FIELD_NUMBER: int - AXIS_FIELD_NUMBER: int - IS_MULTIPARAM_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - precision: int = ... - scale: float = ... - bias: float = ... - is_signed: bool = ... - data: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - name: typing.Text = ... - data_type: global___TensorProto.DataType = ... - scales: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - biases: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - axis: int = ... - is_multiparam: bool = ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - precision : typing.Optional[int] = ..., - scale : typing.Optional[float] = ..., - bias : typing.Optional[float] = ..., - is_signed : typing.Optional[bool] = ..., - data : typing.Optional[typing.Iterable[int]] = ..., - name : typing.Optional[typing.Text] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - scales : typing.Optional[typing.Iterable[float]] = ..., - biases : typing.Optional[typing.Iterable[float]] = ..., - axis : typing.Optional[int] = ..., - is_multiparam : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"axis",b"axis",u"bias",b"bias",u"data_type",b"data_type",u"is_multiparam",b"is_multiparam",u"is_signed",b"is_signed",u"name",b"name",u"precision",b"precision",u"scale",b"scale"]) -> bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal[u"axis",b"axis",u"bias",b"bias",u"biases",b"biases",u"data",b"data",u"data_type",b"data_type",u"dims",b"dims",u"is_multiparam",b"is_multiparam",u"is_signed",b"is_signed",u"name",b"name",u"precision",b"precision",u"scale",b"scale",u"scales",b"scales"]) -> None: ... -global___QTensorProto = QTensorProto - -class TensorProtos(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROTOS_FIELD_NUMBER: int - - @property - def protos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorProto]: ... - - def __init__(self, - *, - protos : typing.Optional[typing.Iterable[global___TensorProto]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"protos",b"protos"]) -> None: ... -global___TensorProtos = TensorProtos - -class TensorShape(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - UNKNOWN_DIMS_FIELD_NUMBER: int - UNKNOWN_SHAPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - data_type: global___TensorProto.DataType = ... - unknown_dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - unknown_shape: bool = ... - name: typing.Text = ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - data_type : typing.Optional[global___TensorProto.DataType] = ..., - unknown_dims : typing.Optional[typing.Iterable[int]] = ..., - unknown_shape : typing.Optional[bool] = ..., - name : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"data_type",b"data_type",u"name",b"name",u"unknown_shape",b"unknown_shape"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"data_type",b"data_type",u"dims",b"dims",u"name",b"name",u"unknown_dims",b"unknown_dims",u"unknown_shape",b"unknown_shape"]) -> None: ... -global___TensorShape = TensorShape - -class TensorShapes(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SHAPES_FIELD_NUMBER: int - - @property - def shapes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorShape]: ... - - def __init__(self, - *, - shapes : typing.Optional[typing.Iterable[global___TensorShape]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"shapes",b"shapes"]) -> None: ... -global___TensorShapes = TensorShapes - -class TensorBoundShape(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _DimType(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[DimType], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... 
- UNKNOWN = TensorBoundShape.DimType.V(0) - CONSTANT = TensorBoundShape.DimType.V(1) - BATCH = TensorBoundShape.DimType.V(2) - BATCH_OF_FEATURE_MAX = TensorBoundShape.DimType.V(3) - BATCH_OF_FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(4) - FEATURE_MAX = TensorBoundShape.DimType.V(5) - FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(6) - class DimType(metaclass=_DimType): - V = typing.NewType('V', int) - UNKNOWN = TensorBoundShape.DimType.V(0) - CONSTANT = TensorBoundShape.DimType.V(1) - BATCH = TensorBoundShape.DimType.V(2) - BATCH_OF_FEATURE_MAX = TensorBoundShape.DimType.V(3) - BATCH_OF_FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(4) - FEATURE_MAX = TensorBoundShape.DimType.V(5) - FEATURE_MAX_DEFAULT = TensorBoundShape.DimType.V(6) - - SHAPE_FIELD_NUMBER: int - DIM_TYPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - SHAPE_IS_FINAL_FIELD_NUMBER: int - dim_type: google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___TensorBoundShape.DimType] = ... - name: typing.Text = ... - shape_is_final: bool = ... - - @property - def shape(self) -> global___TensorShape: ... - - def __init__(self, - *, - shape : typing.Optional[global___TensorShape] = ..., - dim_type : typing.Optional[typing.Iterable[global___TensorBoundShape.DimType]] = ..., - name : typing.Optional[typing.Text] = ..., - shape_is_final : typing.Optional[bool] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name",u"shape",b"shape",u"shape_is_final",b"shape_is_final"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"dim_type",b"dim_type",u"name",b"name",u"shape",b"shape",u"shape_is_final",b"shape_is_final"]) -> None: ... -global___TensorBoundShape = TensorBoundShape - -class TensorBoundShapes(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SHAPES_FIELD_NUMBER: int - MAX_BATCH_SIZE_FIELD_NUMBER: int - MAX_FEATURE_LEN_FIELD_NUMBER: int - max_batch_size: int = ... - max_feature_len: int = ... - - @property - def shapes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorBoundShape]: ... - - def __init__(self, - *, - shapes : typing.Optional[typing.Iterable[global___TensorBoundShape]] = ..., - max_batch_size : typing.Optional[int] = ..., - max_feature_len : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"max_batch_size",b"max_batch_size",u"max_feature_len",b"max_feature_len"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"max_batch_size",b"max_batch_size",u"max_feature_len",b"max_feature_len",u"shapes",b"shapes"]) -> None: ... -global___TensorBoundShapes = TensorBoundShapes - -class AOTConfig(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - MAX_BATCH_SIZE_FIELD_NUMBER: int - MAX_SEQ_SIZE_FIELD_NUMBER: int - IN_BATCH_BROADCAST_FIELD_NUMBER: int - ONNXIFI_BLACKLIST_OPS_FIELD_NUMBER: int - ONNXIFI_MIN_OPS_FIELD_NUMBER: int - max_batch_size: int = ... - max_seq_size: int = ... - in_batch_broadcast: bool = ... - onnxifi_blacklist_ops: typing.Text = ... - onnxifi_min_ops: int = ... - - def __init__(self, - *, - max_batch_size : typing.Optional[int] = ..., - max_seq_size : typing.Optional[int] = ..., - in_batch_broadcast : typing.Optional[bool] = ..., - onnxifi_blacklist_ops : typing.Optional[typing.Text] = ..., - onnxifi_min_ops : typing.Optional[int] = ..., - ) -> None: ... 
- def HasField(self, field_name: typing_extensions.Literal[u"in_batch_broadcast",b"in_batch_broadcast",u"max_batch_size",b"max_batch_size",u"max_seq_size",b"max_seq_size",u"onnxifi_blacklist_ops",b"onnxifi_blacklist_ops",u"onnxifi_min_ops",b"onnxifi_min_ops"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"in_batch_broadcast",b"in_batch_broadcast",u"max_batch_size",b"max_batch_size",u"max_seq_size",b"max_seq_size",u"onnxifi_blacklist_ops",b"onnxifi_blacklist_ops",u"onnxifi_min_ops",b"onnxifi_min_ops"]) -> None: ... -global___AOTConfig = AOTConfig - -class Argument(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - F_FIELD_NUMBER: int - I_FIELD_NUMBER: int - S_FIELD_NUMBER: int - T_FIELD_NUMBER: int - N_FIELD_NUMBER: int - FLOATS_FIELD_NUMBER: int - INTS_FIELD_NUMBER: int - STRINGS_FIELD_NUMBER: int - TENSORS_FIELD_NUMBER: int - NETS_FIELD_NUMBER: int - QTENSORS_FIELD_NUMBER: int - name: typing.Text = ... - f: float = ... - i: int = ... - s: bytes = ... - floats: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - ints: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - strings: google.protobuf.internal.containers.RepeatedScalarFieldContainer[bytes] = ... - - @property - def t(self) -> global___TensorProto: ... - - @property - def n(self) -> global___NetDef: ... - - @property - def tensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorProto]: ... - - @property - def nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetDef]: ... - - @property - def qtensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___QTensorProto]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - f : typing.Optional[float] = ..., - i : typing.Optional[int] = ..., - s : typing.Optional[bytes] = ..., - t : typing.Optional[global___TensorProto] = ..., - n : typing.Optional[global___NetDef] = ..., - floats : typing.Optional[typing.Iterable[float]] = ..., - ints : typing.Optional[typing.Iterable[int]] = ..., - strings : typing.Optional[typing.Iterable[bytes]] = ..., - tensors : typing.Optional[typing.Iterable[global___TensorProto]] = ..., - nets : typing.Optional[typing.Iterable[global___NetDef]] = ..., - qtensors : typing.Optional[typing.Iterable[global___QTensorProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"f",b"f",u"i",b"i",u"n",b"n",u"name",b"name",u"s",b"s",u"t",b"t"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"f",b"f",u"floats",b"floats",u"i",b"i",u"ints",b"ints",u"n",b"n",u"name",b"name",u"nets",b"nets",u"qtensors",b"qtensors",u"s",b"s",u"strings",b"strings",u"t",b"t",u"tensors",b"tensors"]) -> None: ... -global___Argument = Argument - -class DeviceOption(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DEVICE_TYPE_FIELD_NUMBER: int - DEVICE_ID_FIELD_NUMBER: int - RANDOM_SEED_FIELD_NUMBER: int - NODE_NAME_FIELD_NUMBER: int - NUMA_NODE_ID_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - device_type: int = ... - device_id: int = ... - random_seed: int = ... - node_name: typing.Text = ... - numa_node_id: int = ... - extra_info: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... 
- - def __init__(self, - *, - device_type : typing.Optional[int] = ..., - device_id : typing.Optional[int] = ..., - random_seed : typing.Optional[int] = ..., - node_name : typing.Optional[typing.Text] = ..., - numa_node_id : typing.Optional[int] = ..., - extra_info : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"device_id",b"device_id",u"device_type",b"device_type",u"node_name",b"node_name",u"numa_node_id",b"numa_node_id",u"random_seed",b"random_seed"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"device_id",b"device_id",u"device_type",b"device_type",u"extra_info",b"extra_info",u"node_name",b"node_name",u"numa_node_id",b"numa_node_id",u"random_seed",b"random_seed"]) -> None: ... -global___DeviceOption = DeviceOption - -class OperatorDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - INPUT_FIELD_NUMBER: int - OUTPUT_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - ARG_FIELD_NUMBER: int - DEVICE_OPTION_FIELD_NUMBER: int - ENGINE_FIELD_NUMBER: int - CONTROL_INPUT_FIELD_NUMBER: int - IS_GRADIENT_OP_FIELD_NUMBER: int - DEBUG_INFO_FIELD_NUMBER: int - DOMAIN_FIELD_NUMBER: int - OP_VERSION_FIELD_NUMBER: int - input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - output: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - name: typing.Text = ... - type: typing.Text = ... - engine: typing.Text = ... - control_input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - is_gradient_op: bool = ... - debug_info: typing.Text = ... - domain: typing.Text = ... - op_version: int = ... - - @property - def arg(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Argument]: ... - - @property - def device_option(self) -> global___DeviceOption: ... - - def __init__(self, - *, - input : typing.Optional[typing.Iterable[typing.Text]] = ..., - output : typing.Optional[typing.Iterable[typing.Text]] = ..., - name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - arg : typing.Optional[typing.Iterable[global___Argument]] = ..., - device_option : typing.Optional[global___DeviceOption] = ..., - engine : typing.Optional[typing.Text] = ..., - control_input : typing.Optional[typing.Iterable[typing.Text]] = ..., - is_gradient_op : typing.Optional[bool] = ..., - debug_info : typing.Optional[typing.Text] = ..., - domain : typing.Optional[typing.Text] = ..., - op_version : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"debug_info",b"debug_info",u"device_option",b"device_option",u"domain",b"domain",u"engine",b"engine",u"is_gradient_op",b"is_gradient_op",u"name",b"name",u"op_version",b"op_version",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"arg",b"arg",u"control_input",b"control_input",u"debug_info",b"debug_info",u"device_option",b"device_option",u"domain",b"domain",u"engine",b"engine",u"input",b"input",u"is_gradient_op",b"is_gradient_op",u"name",b"name",u"op_version",b"op_version",u"output",b"output",u"type",b"type"]) -> None: ... -global___OperatorDef = OperatorDef - -class MapFieldEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VAL_FIELD_NUMBER: int - key: typing.Text = ... 
- val: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - val : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"val",b"val"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"val",b"val"]) -> None: ... -global___MapFieldEntry = MapFieldEntry - -class BackendOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BACKEND_NAME_FIELD_NUMBER: int - OPTION_FIELD_NUMBER: int - backend_name: typing.Text = ... - - @property - def option(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MapFieldEntry]: ... - - def __init__(self, - *, - backend_name : typing.Optional[typing.Text] = ..., - option : typing.Optional[typing.Iterable[global___MapFieldEntry]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"backend_name",b"backend_name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"backend_name",b"backend_name",u"option",b"option"]) -> None: ... -global___BackendOptions = BackendOptions - -class PartitionInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - DEVICE_ID_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - BACKEND_OPTIONS_FIELD_NUMBER: int - name: typing.Text = ... - device_id: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - extra_info: typing.Text = ... - - @property - def backend_options(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BackendOptions]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - device_id : typing.Optional[typing.Iterable[int]] = ..., - extra_info : typing.Optional[typing.Text] = ..., - backend_options : typing.Optional[typing.Iterable[global___BackendOptions]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"extra_info",b"extra_info",u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"backend_options",b"backend_options",u"device_id",b"device_id",u"extra_info",b"extra_info",u"name",b"name"]) -> None: ... -global___PartitionInfo = PartitionInfo - -class NetDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - OP_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - NUM_WORKERS_FIELD_NUMBER: int - DEVICE_OPTION_FIELD_NUMBER: int - ARG_FIELD_NUMBER: int - EXTERNAL_INPUT_FIELD_NUMBER: int - EXTERNAL_OUTPUT_FIELD_NUMBER: int - PARTITION_INFO_FIELD_NUMBER: int - name: typing.Text = ... - type: typing.Text = ... - num_workers: int = ... - external_input: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - external_output: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def op(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OperatorDef]: ... - - @property - def device_option(self) -> global___DeviceOption: ... - - @property - def arg(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Argument]: ... - - @property - def partition_info(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PartitionInfo]: ... 
- - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - op : typing.Optional[typing.Iterable[global___OperatorDef]] = ..., - type : typing.Optional[typing.Text] = ..., - num_workers : typing.Optional[int] = ..., - device_option : typing.Optional[global___DeviceOption] = ..., - arg : typing.Optional[typing.Iterable[global___Argument]] = ..., - external_input : typing.Optional[typing.Iterable[typing.Text]] = ..., - external_output : typing.Optional[typing.Iterable[typing.Text]] = ..., - partition_info : typing.Optional[typing.Iterable[global___PartitionInfo]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"device_option",b"device_option",u"name",b"name",u"num_workers",b"num_workers",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"arg",b"arg",u"device_option",b"device_option",u"external_input",b"external_input",u"external_output",b"external_output",u"name",b"name",u"num_workers",b"num_workers",u"op",b"op",u"partition_info",b"partition_info",u"type",b"type"]) -> None: ... -global___NetDef = NetDef - -class ExecutionStep(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - SUBSTEP_FIELD_NUMBER: int - NETWORK_FIELD_NUMBER: int - NUM_ITER_FIELD_NUMBER: int - CRITERIA_NETWORK_FIELD_NUMBER: int - REPORT_NET_FIELD_NUMBER: int - REPORT_INTERVAL_FIELD_NUMBER: int - RUN_EVERY_MS_FIELD_NUMBER: int - CONCURRENT_SUBSTEPS_FIELD_NUMBER: int - SHOULD_STOP_BLOB_FIELD_NUMBER: int - ONLY_ONCE_FIELD_NUMBER: int - CREATE_WORKSPACE_FIELD_NUMBER: int - NUM_CONCURRENT_INSTANCES_FIELD_NUMBER: int - name: typing.Text = ... - network: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - num_iter: int = ... - criteria_network: typing.Text = ... - report_net: typing.Text = ... - report_interval: int = ... - run_every_ms: int = ... - concurrent_substeps: bool = ... - should_stop_blob: typing.Text = ... - only_once: bool = ... - create_workspace: bool = ... - num_concurrent_instances: int = ... - - @property - def substep(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ExecutionStep]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - substep : typing.Optional[typing.Iterable[global___ExecutionStep]] = ..., - network : typing.Optional[typing.Iterable[typing.Text]] = ..., - num_iter : typing.Optional[int] = ..., - criteria_network : typing.Optional[typing.Text] = ..., - report_net : typing.Optional[typing.Text] = ..., - report_interval : typing.Optional[int] = ..., - run_every_ms : typing.Optional[int] = ..., - concurrent_substeps : typing.Optional[bool] = ..., - should_stop_blob : typing.Optional[typing.Text] = ..., - only_once : typing.Optional[bool] = ..., - create_workspace : typing.Optional[bool] = ..., - num_concurrent_instances : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"concurrent_substeps",b"concurrent_substeps",u"create_workspace",b"create_workspace",u"criteria_network",b"criteria_network",u"name",b"name",u"num_concurrent_instances",b"num_concurrent_instances",u"num_iter",b"num_iter",u"only_once",b"only_once",u"report_interval",b"report_interval",u"report_net",b"report_net",u"run_every_ms",b"run_every_ms",u"should_stop_blob",b"should_stop_blob"]) -> bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal[u"concurrent_substeps",b"concurrent_substeps",u"create_workspace",b"create_workspace",u"criteria_network",b"criteria_network",u"name",b"name",u"network",b"network",u"num_concurrent_instances",b"num_concurrent_instances",u"num_iter",b"num_iter",u"only_once",b"only_once",u"report_interval",b"report_interval",u"report_net",b"report_net",u"run_every_ms",b"run_every_ms",u"should_stop_blob",b"should_stop_blob",u"substep",b"substep"]) -> None: ... -global___ExecutionStep = ExecutionStep - -class PlanDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - NETWORK_FIELD_NUMBER: int - EXECUTION_STEP_FIELD_NUMBER: int - name: typing.Text = ... - - @property - def network(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetDef]: ... - - @property - def execution_step(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ExecutionStep]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - network : typing.Optional[typing.Iterable[global___NetDef]] = ..., - execution_step : typing.Optional[typing.Iterable[global___ExecutionStep]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"execution_step",b"execution_step",u"name",b"name",u"network",b"network"]) -> None: ... -global___PlanDef = PlanDef - -class BlobProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - TENSOR_FIELD_NUMBER: int - CONTENT_FIELD_NUMBER: int - QTENSOR_FIELD_NUMBER: int - CONTENT_NUM_CHUNKS_FIELD_NUMBER: int - CONTENT_CHUNK_ID_FIELD_NUMBER: int - name: typing.Text = ... - type: typing.Text = ... - content: bytes = ... - content_num_chunks: int = ... - content_chunk_id: int = ... - - @property - def tensor(self) -> global___TensorProto: ... - - @property - def qtensor(self) -> global___QTensorProto: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - tensor : typing.Optional[global___TensorProto] = ..., - content : typing.Optional[bytes] = ..., - qtensor : typing.Optional[global___QTensorProto] = ..., - content_num_chunks : typing.Optional[int] = ..., - content_chunk_id : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"content",b"content",u"content_chunk_id",b"content_chunk_id",u"content_num_chunks",b"content_num_chunks",u"name",b"name",u"qtensor",b"qtensor",u"tensor",b"tensor",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"content",b"content",u"content_chunk_id",b"content_chunk_id",u"content_num_chunks",b"content_num_chunks",u"name",b"name",u"qtensor",b"qtensor",u"tensor",b"tensor",u"type",b"type"]) -> None: ... -global___BlobProto = BlobProto - -class DBReaderProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - SOURCE_FIELD_NUMBER: int - DB_TYPE_FIELD_NUMBER: int - KEY_FIELD_NUMBER: int - name: typing.Text = ... - source: typing.Text = ... - db_type: typing.Text = ... - key: typing.Text = ... 
- - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - source : typing.Optional[typing.Text] = ..., - db_type : typing.Optional[typing.Text] = ..., - key : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"db_type",b"db_type",u"key",b"key",u"name",b"name",u"source",b"source"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"db_type",b"db_type",u"key",b"key",u"name",b"name",u"source",b"source"]) -> None: ... -global___DBReaderProto = DBReaderProto - -class BlobSerializationOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - class _FloatFormat(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[FloatFormat], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0) - FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1) - FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2) - class FloatFormat(metaclass=_FloatFormat): - V = typing.NewType('V', int) - FLOAT_DEFAULT = BlobSerializationOptions.FloatFormat.V(0) - FLOAT_PROTOBUF = BlobSerializationOptions.FloatFormat.V(1) - FLOAT_BFLOAT16 = BlobSerializationOptions.FloatFormat.V(2) - - BLOB_NAME_REGEX_FIELD_NUMBER: int - CHUNK_SIZE_FIELD_NUMBER: int - FLOAT_FORMAT_FIELD_NUMBER: int - blob_name_regex: typing.Text = ... - chunk_size: int = ... - float_format: global___BlobSerializationOptions.FloatFormat = ... - - def __init__(self, - *, - blob_name_regex : typing.Optional[typing.Text] = ..., - chunk_size : typing.Optional[int] = ..., - float_format : typing.Optional[global___BlobSerializationOptions.FloatFormat] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"blob_name_regex",b"blob_name_regex",u"chunk_size",b"chunk_size",u"float_format",b"float_format"]) -> None: ... -global___BlobSerializationOptions = BlobSerializationOptions - -class SerializationOptions(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - OPTIONS_FIELD_NUMBER: int - - @property - def options(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobSerializationOptions]: ... - - def __init__(self, - *, - options : typing.Optional[typing.Iterable[global___BlobSerializationOptions]] = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal[u"options",b"options"]) -> None: ... -global___SerializationOptions = SerializationOptions - -DeviceType = int - -# These are freedom-patched into caffe2_pb2 in caffe2/proto/__init__.py -CPU: int = DeviceType.PROTO_CPU -CUDA: int = DeviceType.PROTO_CUDA -MKLDNN: int = DeviceType.PROTO_MKLDNN -OPENGL: int = DeviceType.PROTO_OPENGL -OPENCL: int = DeviceType.PROTO_OPENCL -IDEEP: int = DeviceType.PROTO_IDEEP -HIP: int = DeviceType.PROTO_HIP -COMPILE_TIME_MAX_DEVICE_TYPES: int = DeviceType.PROTO_COMPILE_TIME_MAX_DEVICE_TYPES diff --git a/caffe2/proto/gen_proto_typestubs.sh b/caffe2/proto/gen_proto_typestubs.sh deleted file mode 100755 index 85503936ea1a..000000000000 --- a/caffe2/proto/gen_proto_typestubs.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env bash - -# Generate type stubs for .proto definition files. 
- -# This should be run from as -# ./gen_proto_typestubs.sh -# (i.e., from inside the proto/ directory) - -# assumes mypy-protobuf installed to ~/.local; i.e. via -# pip3 install mypy-protobuf --user - -set -euxo pipefail - -MYPY_PROTOBUF_HOME="${1:-${HOME}/.local/bin}" - -pushd ../../ -buck run fbsource//third-party/protobuf:protoc -- --plugin=protoc-gen-mypy="${MYPY_PROTOBUF_HOME}"/protoc-gen-mypy --mypy_out=./ caffe2/proto/*.proto -popd - -# get rid of 'builtins.' prefix, which pyre does not like -sed -E -i 's/builtins\.//g' ./*.pyi - -# mypy-protobuf references types from other mypy-protobuf-generated stubs as -# 'type.V', but it should just be 'type', so we get rid of the '.V' suffix -# when it's not followed by parens to indicate a particular enum value. -sed -E -i 's/\.V([^(_[:alnum:]])/\1/g' ./*.pyi - -# --------------------------- -# Freedom-patched DeviceTypes -# --------------------------- -# -# In order to make DeviceTypes like CPU, CUDA, etc. directly accessible from -# the caffe2_pb2 module, they are currently freedom-patched into it in -# caffe2/python/__init__.py. This is not ideal: it would be better if these -# were autogenerated when the protobuf definitions were created by using -# allow_alias = true in the DeviceTypeProto definition in caffe2.proto. -# -# However, it is impossible to do this currently without significant effort. -# The issue is that the generated proto constants would conflict with various -# constants defined in the C++ caffe2 codebase (`caffe2_pb2.h`). We cannot -# simply remove these constants and replace them with the caffe2 -# DeviceTypeProto constants, because a huge portion of code expects -# at::DeviceType constants defined in `core/DeviceType.h` (apparently -# duplicated to avoid having to figure out how to autogenerate the protobuf -# definitions using cmake for ATen). -# -# Instead, we make a best-effort to add additional definitions in -# `caffe2_pb2.py` by looking for any freedom-patched constants in -# `caffe2/python/__init__.py` and making sure they have corresponding stubs in -# the pyi (see `gen_proto_typestubs_helper.py`). - -python3 ./gen_proto_typestubs_helper.py >> caffe2_pb2.pyi diff --git a/caffe2/proto/gen_proto_typestubs_helper.py b/caffe2/proto/gen_proto_typestubs_helper.py deleted file mode 100644 index 4ed83f55998f..000000000000 --- a/caffe2/proto/gen_proto_typestubs_helper.py +++ /dev/null @@ -1,15 +0,0 @@ -import ast - -with open("../python/__init__.py", "r") as f: - tree = ast.parse(f.read()) - -print("\nDeviceType = int\n") -print("# These are freedom-patched into caffe2_pb2 in caffe2/proto/__init__.py") -for stmt in tree.body: - if not isinstance(stmt, ast.Assign): - continue - target = stmt.targets[0] - if not isinstance(target, ast.Attribute): - continue - if isinstance(target.value, ast.Name) and target.value.id == "caffe2_pb2": - print(f"{target.attr}: int = DeviceType.PROTO_{target.attr}") diff --git a/caffe2/proto/hsm.proto b/caffe2/proto/hsm.proto deleted file mode 100644 index 2e3152cc332e..000000000000 --- a/caffe2/proto/hsm.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// Hierarchical Softmax protobuffer convention: -// The HSM operator requires a hierarchy of vocabulary words in the form of a -// tree from the user. This tree is expressed using the proto format. -// TreeProto points to the root NodeProto which can recursively contain children -// NodeProtos (internal nodes) or word_ids (leaf nodes). 
- -// The aforementioned TreeProto is internally translated into a list of word_ids -// tagged with a list of NodeProtos that lie in the path from the root to that -// word_id using hsm_util.create_hierarchy(tree_proto). -// Specifically, HierarchyProto contains a list of PathProtos. Each PathProto -// belongs to a word_id and contains a list of PathNodeProtos. Each -// PathNodeProto contains information about the number of children the node has -// (length), the index of the child node that lies in the path from root to -// word_id (target) and a cumulative sum of children nodes (index; this acts as -// the weight parameter matrix offset). - -// Each node in the hierarchy contains links to either leaf nodes or more -// non-terminal nodes -message NodeProto { - // Links to non-terminal children nodes - repeated NodeProto children = 1; - // Links to terminal (leaf) nodes - repeated int32 word_ids = 2; - optional int32 offset = 3; - optional string name = 4; - repeated float scores = 5; -} - -// Protobuf format to accept hierarchy for hierarchical softmax operator. -// TreeProto points to the root node. -message TreeProto { - optional NodeProto root_node = 1; -} - -// Internal Protobuf format which represents the path in the tree hierarchy for -// each word in the vocabulary. -message HierarchyProto { - optional int32 size = 1; - repeated PathProto paths = 2; -} - -// Each PathProto belongs to a word and is an array of nodes in the -// path from the root to the leaf (which is the word itself) in the tree. -message PathProto { - optional int32 word_id = 1; - repeated PathNodeProto path_nodes = 2; -} - -// Represents a node in the path from the root node all the way down to the -// word (leaf). -message PathNodeProto { - // Parameter matrix offset for this node - optional int32 index = 1; - // Number of children - optional int32 length = 2; - // Index of the next node in the path - optional int32 target = 3; -} diff --git a/caffe2/proto/hsm_pb2.pyi b/caffe2/proto/hsm_pb2.pyi deleted file mode 100644 index 86a47f58d17c..000000000000 --- a/caffe2/proto/hsm_pb2.pyi +++ /dev/null @@ -1,109 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class NodeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - CHILDREN_FIELD_NUMBER: int - WORD_IDS_FIELD_NUMBER: int - OFFSET_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - SCORES_FIELD_NUMBER: int - word_ids: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - offset: int = ... - name: typing.Text = ... - scores: google.protobuf.internal.containers.RepeatedScalarFieldContainer[float] = ... - - @property - def children(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NodeProto]: ... - - def __init__(self, - *, - children : typing.Optional[typing.Iterable[global___NodeProto]] = ..., - word_ids : typing.Optional[typing.Iterable[int]] = ..., - offset : typing.Optional[int] = ..., - name : typing.Optional[typing.Text] = ..., - scores : typing.Optional[typing.Iterable[float]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"name",b"name",u"offset",b"offset"]) -> bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal[u"children",b"children",u"name",b"name",u"offset",b"offset",u"scores",b"scores",u"word_ids",b"word_ids"]) -> None: ... -global___NodeProto = NodeProto - -class TreeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - ROOT_NODE_FIELD_NUMBER: int - - @property - def root_node(self) -> global___NodeProto: ... - - def __init__(self, - *, - root_node : typing.Optional[global___NodeProto] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"root_node",b"root_node"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"root_node",b"root_node"]) -> None: ... -global___TreeProto = TreeProto - -class HierarchyProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SIZE_FIELD_NUMBER: int - PATHS_FIELD_NUMBER: int - size: int = ... - - @property - def paths(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PathProto]: ... - - def __init__(self, - *, - size : typing.Optional[int] = ..., - paths : typing.Optional[typing.Iterable[global___PathProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"size",b"size"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"paths",b"paths",u"size",b"size"]) -> None: ... -global___HierarchyProto = HierarchyProto - -class PathProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - WORD_ID_FIELD_NUMBER: int - PATH_NODES_FIELD_NUMBER: int - word_id: int = ... - - @property - def path_nodes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PathNodeProto]: ... - - def __init__(self, - *, - word_id : typing.Optional[int] = ..., - path_nodes : typing.Optional[typing.Iterable[global___PathNodeProto]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"word_id",b"word_id"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"path_nodes",b"path_nodes",u"word_id",b"word_id"]) -> None: ... -global___PathProto = PathProto - -class PathNodeProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - INDEX_FIELD_NUMBER: int - LENGTH_FIELD_NUMBER: int - TARGET_FIELD_NUMBER: int - index: int = ... - length: int = ... - target: int = ... - - def __init__(self, - *, - index : typing.Optional[int] = ..., - length : typing.Optional[int] = ..., - target : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"index",b"index",u"length",b"length",u"target",b"target"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"index",b"index",u"length",b"length",u"target",b"target"]) -> None: ... 
-global___PathNodeProto = PathNodeProto diff --git a/caffe2/proto/metanet.proto b/caffe2/proto/metanet.proto deleted file mode 100644 index 8008610ac0fa..000000000000 --- a/caffe2/proto/metanet.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto2"; - -import "caffe2/proto/caffe2.proto"; - -package caffe2; - -message ModelInfo { - optional string project = 1; - optional string modelClass = 2; - optional string version = 3; - optional string predictorType = 4 [ default = "SINGLE_PREDICTOR" ]; - optional string modelId = 5; -} - -message BlobsMap { - required string key = 1; - repeated string value = 2; -} - -message NetsMap { - required string key = 1; - required NetDef value = 2; -} - -message PlansMap { - required string key = 1; - required PlanDef value = 2; -} - -message StringMap { - required string key = 1; - required string value = 2; -} - -message MetaNetDef { - repeated BlobsMap blobs = 1; - // Text-format serialized NetDefs. - repeated NetsMap nets = 2; - // Info about where the model comes from. Possible use cases: - // 1) sanity check or diagnose - // 2) provide info for evaluation. - optional ModelInfo modelInfo = 3; - repeated PlansMap plans = 4; - repeated StringMap applicationSpecificInfo = 5; - repeated string blobsOrder = 6; - repeated string preLoadBlobs = 7; - optional TensorBoundShapes tensorBoundShapes = 8; - repeated string requestOnlyEmbeddings = 9; - optional AOTConfig aotConfig = 10; -} diff --git a/caffe2/proto/metanet_pb2.pyi b/caffe2/proto/metanet_pb2.pyi deleted file mode 100644 index 096fd90df876..000000000000 --- a/caffe2/proto/metanet_pb2.pyi +++ /dev/null @@ -1,160 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import caffe2.proto.caffe2_pb2 -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class ModelInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROJECT_FIELD_NUMBER: int - MODELCLASS_FIELD_NUMBER: int - VERSION_FIELD_NUMBER: int - PREDICTORTYPE_FIELD_NUMBER: int - MODELID_FIELD_NUMBER: int - project: typing.Text = ... - modelClass: typing.Text = ... - version: typing.Text = ... - predictorType: typing.Text = ... - modelId: typing.Text = ... - - def __init__(self, - *, - project : typing.Optional[typing.Text] = ..., - modelClass : typing.Optional[typing.Text] = ..., - version : typing.Optional[typing.Text] = ..., - predictorType : typing.Optional[typing.Text] = ..., - modelId : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"modelClass",b"modelClass",u"modelId",b"modelId",u"predictorType",b"predictorType",u"project",b"project",u"version",b"version"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"modelClass",b"modelClass",u"modelId",b"modelId",u"predictorType",b"predictorType",u"project",b"project",u"version",b"version"]) -> None: ... -global___ModelInfo = ModelInfo - -class BlobsMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - value: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... 
- def HasField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___BlobsMap = BlobsMap - -class NetsMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - - @property - def value(self) -> caffe2.proto.caffe2_pb2.NetDef: ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[caffe2.proto.caffe2_pb2.NetDef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___NetsMap = NetsMap - -class PlansMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - - @property - def value(self) -> caffe2.proto.caffe2_pb2.PlanDef: ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[caffe2.proto.caffe2_pb2.PlanDef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___PlansMap = PlansMap - -class StringMap(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - VALUE_FIELD_NUMBER: int - key: typing.Text = ... - value: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - value : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ... -global___StringMap = StringMap - -class MetaNetDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - BLOBS_FIELD_NUMBER: int - NETS_FIELD_NUMBER: int - MODELINFO_FIELD_NUMBER: int - PLANS_FIELD_NUMBER: int - APPLICATIONSPECIFICINFO_FIELD_NUMBER: int - BLOBSORDER_FIELD_NUMBER: int - PRELOADBLOBS_FIELD_NUMBER: int - TENSORBOUNDSHAPES_FIELD_NUMBER: int - REQUESTONLYEMBEDDINGS_FIELD_NUMBER: int - AOTCONFIG_FIELD_NUMBER: int - blobsOrder: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - preLoadBlobs: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - requestOnlyEmbeddings: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def blobs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobsMap]: ... - - @property - def nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NetsMap]: ... - - @property - def modelInfo(self) -> global___ModelInfo: ... - - @property - def plans(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PlansMap]: ... - - @property - def applicationSpecificInfo(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StringMap]: ... - - @property - def tensorBoundShapes(self) -> caffe2.proto.caffe2_pb2.TensorBoundShapes: ... 
- - @property - def aotConfig(self) -> caffe2.proto.caffe2_pb2.AOTConfig: ... - - def __init__(self, - *, - blobs : typing.Optional[typing.Iterable[global___BlobsMap]] = ..., - nets : typing.Optional[typing.Iterable[global___NetsMap]] = ..., - modelInfo : typing.Optional[global___ModelInfo] = ..., - plans : typing.Optional[typing.Iterable[global___PlansMap]] = ..., - applicationSpecificInfo : typing.Optional[typing.Iterable[global___StringMap]] = ..., - blobsOrder : typing.Optional[typing.Iterable[typing.Text]] = ..., - preLoadBlobs : typing.Optional[typing.Iterable[typing.Text]] = ..., - tensorBoundShapes : typing.Optional[caffe2.proto.caffe2_pb2.TensorBoundShapes] = ..., - requestOnlyEmbeddings : typing.Optional[typing.Iterable[typing.Text]] = ..., - aotConfig : typing.Optional[caffe2.proto.caffe2_pb2.AOTConfig] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"aotConfig",b"aotConfig",u"modelInfo",b"modelInfo",u"tensorBoundShapes",b"tensorBoundShapes"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"aotConfig",b"aotConfig",u"applicationSpecificInfo",b"applicationSpecificInfo",u"blobs",b"blobs",u"blobsOrder",b"blobsOrder",u"modelInfo",b"modelInfo",u"nets",b"nets",u"plans",b"plans",u"preLoadBlobs",b"preLoadBlobs",u"requestOnlyEmbeddings",b"requestOnlyEmbeddings",u"tensorBoundShapes",b"tensorBoundShapes"]) -> None: ... -global___MetaNetDef = MetaNetDef diff --git a/caffe2/proto/predictor_consts.proto b/caffe2/proto/predictor_consts.proto deleted file mode 100644 index d45ecb8396c7..000000000000 --- a/caffe2/proto/predictor_consts.proto +++ /dev/null @@ -1,36 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -message PredictorConsts { - // Important - to ensure ordered traversal of the DB, these must be - // set in the given (lexicographic) order in the input DBReader. - optional string META_NET_DEF = 1 [ default = "!!META_NET_DEF" ]; - - // The key the Predictor sets in the global workspace for DBReader - // consumed by the LoadOp in GLOBAL_INIT_NET. - - optional string PREDICTOR_DBREADER = 2 [ default = "!!PREDICTOR_DBREADER" ]; - - // Blob types used in MetaNetDef blobs - optional string PARAMETERS_BLOB_TYPE = 3 [ default = "PARAMETERS_BLOB_TYPE" ]; - optional string INPUTS_BLOB_TYPE = 4 [ default = "INPUTS_BLOB_TYPE" ]; - optional string OUTPUTS_BLOB_TYPE = 5 [ default = "OUTPUTS_BLOB_TYPE" ]; - - // Net types used in MetaNetDef nets - optional string GLOBAL_INIT_NET_TYPE = 6 [ default = "GLOBAL_INIT_NET_TYPE" ]; - optional string PREDICT_INIT_NET_TYPE = 7 - [ default = "PREDICT_INIT_NET_TYPE" ]; - optional string PREDICT_NET_TYPE = 8 [ default = "PREDICT_NET_TYPE" ]; - optional string SINGLE_PREDICTOR = 9 [ default = "SINGLE_PREDICTOR" ]; - optional string MULTI_PREDICTOR = 10 [ default = "MULTI_PREDICTOR" ]; - optional string TRAIN_INIT_PLAN_TYPE = 11 - [ default = "TRAIN_INIT_PLAN_TYPE" ]; - optional string TRAIN_PLAN_TYPE = 12 [ default = "TRAIN_PLAN_TYPE" ]; - - // Shape info blob name - optional string SHAPE_INFO_BLOB = 13 [ default = "SHAPE_INFO_BLOB" ]; - // Sequential blob reader name - optional string DEFERRED_BLOB_READER = 14 - [ default = "__DEFERRED_BLOB_READER__" ]; -} diff --git a/caffe2/proto/predictor_consts_pb2.pyi b/caffe2/proto/predictor_consts_pb2.pyi deleted file mode 100644 index 83b62ae0e949..000000000000 --- a/caffe2/proto/predictor_consts_pb2.pyi +++ /dev/null @@ -1,63 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! 
-isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class PredictorConsts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - META_NET_DEF_FIELD_NUMBER: int - PREDICTOR_DBREADER_FIELD_NUMBER: int - PARAMETERS_BLOB_TYPE_FIELD_NUMBER: int - INPUTS_BLOB_TYPE_FIELD_NUMBER: int - OUTPUTS_BLOB_TYPE_FIELD_NUMBER: int - GLOBAL_INIT_NET_TYPE_FIELD_NUMBER: int - PREDICT_INIT_NET_TYPE_FIELD_NUMBER: int - PREDICT_NET_TYPE_FIELD_NUMBER: int - SINGLE_PREDICTOR_FIELD_NUMBER: int - MULTI_PREDICTOR_FIELD_NUMBER: int - TRAIN_INIT_PLAN_TYPE_FIELD_NUMBER: int - TRAIN_PLAN_TYPE_FIELD_NUMBER: int - SHAPE_INFO_BLOB_FIELD_NUMBER: int - DEFERRED_BLOB_READER_FIELD_NUMBER: int - META_NET_DEF: typing.Text = ... - PREDICTOR_DBREADER: typing.Text = ... - PARAMETERS_BLOB_TYPE: typing.Text = ... - INPUTS_BLOB_TYPE: typing.Text = ... - OUTPUTS_BLOB_TYPE: typing.Text = ... - GLOBAL_INIT_NET_TYPE: typing.Text = ... - PREDICT_INIT_NET_TYPE: typing.Text = ... - PREDICT_NET_TYPE: typing.Text = ... - SINGLE_PREDICTOR: typing.Text = ... - MULTI_PREDICTOR: typing.Text = ... - TRAIN_INIT_PLAN_TYPE: typing.Text = ... - TRAIN_PLAN_TYPE: typing.Text = ... - SHAPE_INFO_BLOB: typing.Text = ... - DEFERRED_BLOB_READER: typing.Text = ... - - def __init__(self, - *, - META_NET_DEF : typing.Optional[typing.Text] = ..., - PREDICTOR_DBREADER : typing.Optional[typing.Text] = ..., - PARAMETERS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - INPUTS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - OUTPUTS_BLOB_TYPE : typing.Optional[typing.Text] = ..., - GLOBAL_INIT_NET_TYPE : typing.Optional[typing.Text] = ..., - PREDICT_INIT_NET_TYPE : typing.Optional[typing.Text] = ..., - PREDICT_NET_TYPE : typing.Optional[typing.Text] = ..., - SINGLE_PREDICTOR : typing.Optional[typing.Text] = ..., - MULTI_PREDICTOR : typing.Optional[typing.Text] = ..., - TRAIN_INIT_PLAN_TYPE : typing.Optional[typing.Text] = ..., - TRAIN_PLAN_TYPE : typing.Optional[typing.Text] = ..., - SHAPE_INFO_BLOB : typing.Optional[typing.Text] = ..., - DEFERRED_BLOB_READER : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"DEFERRED_BLOB_READER",b"DEFERRED_BLOB_READER",u"GLOBAL_INIT_NET_TYPE",b"GLOBAL_INIT_NET_TYPE",u"INPUTS_BLOB_TYPE",b"INPUTS_BLOB_TYPE",u"META_NET_DEF",b"META_NET_DEF",u"MULTI_PREDICTOR",b"MULTI_PREDICTOR",u"OUTPUTS_BLOB_TYPE",b"OUTPUTS_BLOB_TYPE",u"PARAMETERS_BLOB_TYPE",b"PARAMETERS_BLOB_TYPE",u"PREDICTOR_DBREADER",b"PREDICTOR_DBREADER",u"PREDICT_INIT_NET_TYPE",b"PREDICT_INIT_NET_TYPE",u"PREDICT_NET_TYPE",b"PREDICT_NET_TYPE",u"SHAPE_INFO_BLOB",b"SHAPE_INFO_BLOB",u"SINGLE_PREDICTOR",b"SINGLE_PREDICTOR",u"TRAIN_INIT_PLAN_TYPE",b"TRAIN_INIT_PLAN_TYPE",u"TRAIN_PLAN_TYPE",b"TRAIN_PLAN_TYPE"]) -> bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal[u"DEFERRED_BLOB_READER",b"DEFERRED_BLOB_READER",u"GLOBAL_INIT_NET_TYPE",b"GLOBAL_INIT_NET_TYPE",u"INPUTS_BLOB_TYPE",b"INPUTS_BLOB_TYPE",u"META_NET_DEF",b"META_NET_DEF",u"MULTI_PREDICTOR",b"MULTI_PREDICTOR",u"OUTPUTS_BLOB_TYPE",b"OUTPUTS_BLOB_TYPE",u"PARAMETERS_BLOB_TYPE",b"PARAMETERS_BLOB_TYPE",u"PREDICTOR_DBREADER",b"PREDICTOR_DBREADER",u"PREDICT_INIT_NET_TYPE",b"PREDICT_INIT_NET_TYPE",u"PREDICT_NET_TYPE",b"PREDICT_NET_TYPE",u"SHAPE_INFO_BLOB",b"SHAPE_INFO_BLOB",u"SINGLE_PREDICTOR",b"SINGLE_PREDICTOR",u"TRAIN_INIT_PLAN_TYPE",b"TRAIN_INIT_PLAN_TYPE",u"TRAIN_PLAN_TYPE",b"TRAIN_PLAN_TYPE"]) -> None: ... -global___PredictorConsts = PredictorConsts diff --git a/caffe2/proto/prof_dag.proto b/caffe2/proto/prof_dag.proto deleted file mode 100644 index ab427a1c66fa..000000000000 --- a/caffe2/proto/prof_dag.proto +++ /dev/null @@ -1,68 +0,0 @@ -syntax = "proto2"; - -package caffe2; - -// A few notes about the Caffe2's protobuffer convention: -// (1) Most objects are registered by their types, such as operators and nets. -// For these, we have a string-type field "type" for registration purposes. -// (2) We do not use extension because that used to create quite some conflicts -// in Caffe's protobuf design. -// (3) We have not used any proto3 specific features, such as Any or Map. This -// is mainly for backward compatibility purposes but we may consider using -// those in the future. - -// A two number summary for a value. It also has count for restoring. -message TwoNumberStatsProto { - optional float mean = 1; - optional float stddev = 2; - optional int64 count = 3; -} - -// Blob profiling information. Profile for a blob is created every time -// a node outputs to the blob. -message BlobProfile { - // Name of the blob (corresponds to OperatorDef.output). - optional string name = 1; // required - - // Profiling statistics. - optional TwoNumberStatsProto bytes_used = 3; -} - -// Protobuf format to serialize profiler data. -message ProfDAGProto { - // The name for the operator - required string name = 1; - // The mean execution time - required float mean = 2; - // The standard deviation - required float stddev = 3; - - // New field to represent the numbers above, and with count. - optional TwoNumberStatsProto execution_time = 4; - - // Blob profiles that this node outputs. - repeated BlobProfile output_profile = 5; - - // The extra_info from the operator device option. - repeated string extra_info = 7; -} - -// Operator profiling information. -// -// Note: The indices for elements of 'stats' and the indices of -// 'output_profile' inside each 'stats' are assumed to match the -// indices of 'op' elements of a corresponding NetDef and the 'output' -// indices within each 'op'. -message ProfDAGProtos { - repeated ProfDAGProto stats = 1; - optional string net_name = 2; - repeated OpProfile ops_stats = 3; -} - -// Represents specification of an operation cost. -message OpProfile { - optional string idx = 1; - optional string net_name = 2; - optional string type = 3; - optional float exec_time_secs = 4; -} diff --git a/caffe2/proto/prof_dag_pb2.pyi b/caffe2/proto/prof_dag_pb2.pyi deleted file mode 100644 index 98affd51fd0b..000000000000 --- a/caffe2/proto/prof_dag_pb2.pyi +++ /dev/null @@ -1,126 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! 
-isort:skip_file -""" -import builtins -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -class TwoNumberStatsProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - MEAN_FIELD_NUMBER: int - STDDEV_FIELD_NUMBER: int - COUNT_FIELD_NUMBER: int - mean: float = ... - stddev: float = ... - count: int = ... - - def __init__(self, - *, - mean : typing.Optional[float] = ..., - stddev : typing.Optional[float] = ..., - count : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"count",b"count",u"mean",b"mean",u"stddev",b"stddev"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"count",b"count",u"mean",b"mean",u"stddev",b"stddev"]) -> None: ... -global___TwoNumberStatsProto = TwoNumberStatsProto - -class BlobProfile(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - BYTES_USED_FIELD_NUMBER: int - name: typing.Text = ... - - @property - def bytes_used(self) -> global___TwoNumberStatsProto: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - bytes_used : typing.Optional[global___TwoNumberStatsProto] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"bytes_used",b"bytes_used",u"name",b"name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"bytes_used",b"bytes_used",u"name",b"name"]) -> None: ... -global___BlobProfile = BlobProfile - -class ProfDAGProto(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - NAME_FIELD_NUMBER: int - MEAN_FIELD_NUMBER: int - STDDEV_FIELD_NUMBER: int - EXECUTION_TIME_FIELD_NUMBER: int - OUTPUT_PROFILE_FIELD_NUMBER: int - EXTRA_INFO_FIELD_NUMBER: int - name: typing.Text = ... - mean: float = ... - stddev: float = ... - extra_info: google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text] = ... - - @property - def execution_time(self) -> global___TwoNumberStatsProto: ... - - @property - def output_profile(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___BlobProfile]: ... - - def __init__(self, - *, - name : typing.Optional[typing.Text] = ..., - mean : typing.Optional[float] = ..., - stddev : typing.Optional[float] = ..., - execution_time : typing.Optional[global___TwoNumberStatsProto] = ..., - output_profile : typing.Optional[typing.Iterable[global___BlobProfile]] = ..., - extra_info : typing.Optional[typing.Iterable[typing.Text]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"execution_time",b"execution_time",u"mean",b"mean",u"name",b"name",u"stddev",b"stddev"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"execution_time",b"execution_time",u"extra_info",b"extra_info",u"mean",b"mean",u"name",b"name",u"output_profile",b"output_profile",u"stddev",b"stddev"]) -> None: ... -global___ProfDAGProto = ProfDAGProto - -class ProfDAGProtos(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - STATS_FIELD_NUMBER: int - NET_NAME_FIELD_NUMBER: int - OPS_STATS_FIELD_NUMBER: int - net_name: typing.Text = ... - - @property - def stats(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProfDAGProto]: ... 
- - @property - def ops_stats(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OpProfile]: ... - - def __init__(self, - *, - stats : typing.Optional[typing.Iterable[global___ProfDAGProto]] = ..., - net_name : typing.Optional[typing.Text] = ..., - ops_stats : typing.Optional[typing.Iterable[global___OpProfile]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"net_name",b"net_name"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"net_name",b"net_name",u"ops_stats",b"ops_stats",u"stats",b"stats"]) -> None: ... -global___ProfDAGProtos = ProfDAGProtos - -class OpProfile(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - IDX_FIELD_NUMBER: int - NET_NAME_FIELD_NUMBER: int - TYPE_FIELD_NUMBER: int - EXEC_TIME_SECS_FIELD_NUMBER: int - idx: typing.Text = ... - net_name: typing.Text = ... - type: typing.Text = ... - exec_time_secs: float = ... - - def __init__(self, - *, - idx : typing.Optional[typing.Text] = ..., - net_name : typing.Optional[typing.Text] = ..., - type : typing.Optional[typing.Text] = ..., - exec_time_secs : typing.Optional[float] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"exec_time_secs",b"exec_time_secs",u"idx",b"idx",u"net_name",b"net_name",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"exec_time_secs",b"exec_time_secs",u"idx",b"idx",u"net_name",b"net_name",u"type",b"type"]) -> None: ... -global___OpProfile = OpProfile diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto deleted file mode 100644 index 1ac4f5443579..000000000000 --- a/caffe2/proto/torch.proto +++ /dev/null @@ -1,114 +0,0 @@ -syntax = "proto2"; - -import "caffe2/proto/caffe2.proto"; - -package torch; - -message RecordRef { - optional string key = 1; -} - -message TensorDef { - repeated int64 dims = 1; - optional int64 offset = 2; - repeated int64 strides = 3; - // whether we compute the gradient for the parameter - optional bool requires_grad = 4; - optional caffe2.TensorProto.DataType data_type = 5; - - optional RecordRef data = 6; - - // device field stores the canonical device string, and it follows the - // format below: `(cpu|cuda)[:]`, e.g., 'cuda:0' - optional string device = 7; - - optional bool is_quantized = 8; - optional double scale = 9; - optional int64 zero_point = 10; -} - -message AttributeDef { - // The mypy type of this attribute - required string type = 1; - required string name = 2; - - // Offset into attribute table - required int64 id = 3; -} - -message ParameterDef { - // whether this parameter is registered as buffer or not - optional bool is_buffer = 1; - - // the offset into the tensor table where this parameter is stored - optional int64 tensor_id = 2; - - optional string name = 3; -} - -message ModuleDef { - repeated ModuleDef submodules = 1; - - optional RecordRef torchscript_arena = 2; - - repeated caffe2.NetDef caffe2_nets = 3; - - // because the old pickle modules may not be supported by torch_script, - // have to stored as pickle_arena at this moment. - optional RecordRef pickle_arena = 4; - // should be exposed by the Class Archive, so user can save - // module specific data which cannot be store in the graph or torch_script - optional RecordRef cpp_arena = 5; - - // the parameters of this module - repeated ParameterDef parameters = 6; - - // the names of inputs and outputs of the module are inferred - // from the main method. 
- - optional string name = 7; - - // whether apply the optimizations to this module, only applicable to - // script modules - optional bool optimize = 8; - - repeated AttributeDef attributes = 9; - - // Used for retrieving module state from the pickled IValues table - optional int64 get_state_attribute_id = 10; - - optional RecordRef torchscript_debug_arena = 11; -} - -// Represents all non-module code that the model depends on. -// Right now it's just a straight list of classes, defined in dependency order -// (i.e. dependencies appear before their dependers) -message LibDef { - optional RecordRef torchscript_arena = 1; -} - -enum ProtoVersion { PROTO_VERSION_NEWEST = 0x0000000000000006; } - -message ModelDef { - // numbers of fields that have been removed. Do not reuse them! - reserved 9; - reserved "libs"; - // for the proto version, to keep both backward and forward - // compatibility, please bump the proto_version when we add any - // change in the proto. runtime decides whether accept the - // model based on the ir_version. - optional int64 proto_version = 1; - - // main module of the model - optional ModuleDef main_module = 2; - - // to distinguish whether exported from c2 or torch - optional string producer_name = 3; - - // put build version here - optional string producer_version = 4; - - // the table contains all the tensor information - // the tensor id is defined as TensorProto.name - repeated TensorDef tensors = 5; -} diff --git a/caffe2/proto/torch_pb2.pyi b/caffe2/proto/torch_pb2.pyi deleted file mode 100644 index 33826e2aff5d..000000000000 --- a/caffe2/proto/torch_pb2.pyi +++ /dev/null @@ -1,218 +0,0 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" -import builtins -import caffe2.proto.caffe2_pb2 -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import typing -import typing_extensions - -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ... - -global___ProtoVersion = ProtoVersion -class _ProtoVersion(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ProtoVersion], type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor = ... - PROTO_VERSION_NEWEST = ProtoVersion.V(6) -class ProtoVersion(metaclass=_ProtoVersion): - V = typing.NewType('V', int) -PROTO_VERSION_NEWEST = ProtoVersion.V(6) - -class RecordRef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - KEY_FIELD_NUMBER: int - key: typing.Text = ... - - def __init__(self, - *, - key : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key"]) -> None: ... -global___RecordRef = RecordRef - -class TensorDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - DIMS_FIELD_NUMBER: int - OFFSET_FIELD_NUMBER: int - STRIDES_FIELD_NUMBER: int - REQUIRES_GRAD_FIELD_NUMBER: int - DATA_TYPE_FIELD_NUMBER: int - DATA_FIELD_NUMBER: int - DEVICE_FIELD_NUMBER: int - IS_QUANTIZED_FIELD_NUMBER: int - SCALE_FIELD_NUMBER: int - ZERO_POINT_FIELD_NUMBER: int - dims: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - offset: int = ... - strides: google.protobuf.internal.containers.RepeatedScalarFieldContainer[int] = ... - requires_grad: bool = ... - data_type: caffe2.proto.caffe2_pb2.TensorProto.DataType = ... 
- device: typing.Text = ... - is_quantized: bool = ... - scale: float = ... - zero_point: int = ... - - @property - def data(self) -> global___RecordRef: ... - - def __init__(self, - *, - dims : typing.Optional[typing.Iterable[int]] = ..., - offset : typing.Optional[int] = ..., - strides : typing.Optional[typing.Iterable[int]] = ..., - requires_grad : typing.Optional[bool] = ..., - data_type : typing.Optional[caffe2.proto.caffe2_pb2.TensorProto.DataType] = ..., - data : typing.Optional[global___RecordRef] = ..., - device : typing.Optional[typing.Text] = ..., - is_quantized : typing.Optional[bool] = ..., - scale : typing.Optional[float] = ..., - zero_point : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"data",b"data",u"data_type",b"data_type",u"device",b"device",u"is_quantized",b"is_quantized",u"offset",b"offset",u"requires_grad",b"requires_grad",u"scale",b"scale",u"zero_point",b"zero_point"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"data",b"data",u"data_type",b"data_type",u"device",b"device",u"dims",b"dims",u"is_quantized",b"is_quantized",u"offset",b"offset",u"requires_grad",b"requires_grad",u"scale",b"scale",u"strides",b"strides",u"zero_point",b"zero_point"]) -> None: ... -global___TensorDef = TensorDef - -class AttributeDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - TYPE_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - ID_FIELD_NUMBER: int - type: typing.Text = ... - name: typing.Text = ... - id: int = ... - - def __init__(self, - *, - type : typing.Optional[typing.Text] = ..., - name : typing.Optional[typing.Text] = ..., - id : typing.Optional[int] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"id",b"id",u"name",b"name",u"type",b"type"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"id",b"id",u"name",b"name",u"type",b"type"]) -> None: ... -global___AttributeDef = AttributeDef - -class ParameterDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - IS_BUFFER_FIELD_NUMBER: int - TENSOR_ID_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - is_buffer: bool = ... - tensor_id: int = ... - name: typing.Text = ... - - def __init__(self, - *, - is_buffer : typing.Optional[bool] = ..., - tensor_id : typing.Optional[int] = ..., - name : typing.Optional[typing.Text] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"is_buffer",b"is_buffer",u"name",b"name",u"tensor_id",b"tensor_id"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"is_buffer",b"is_buffer",u"name",b"name",u"tensor_id",b"tensor_id"]) -> None: ... -global___ParameterDef = ParameterDef - -class ModuleDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - SUBMODULES_FIELD_NUMBER: int - TORCHSCRIPT_ARENA_FIELD_NUMBER: int - CAFFE2_NETS_FIELD_NUMBER: int - PICKLE_ARENA_FIELD_NUMBER: int - CPP_ARENA_FIELD_NUMBER: int - PARAMETERS_FIELD_NUMBER: int - NAME_FIELD_NUMBER: int - OPTIMIZE_FIELD_NUMBER: int - ATTRIBUTES_FIELD_NUMBER: int - GET_STATE_ATTRIBUTE_ID_FIELD_NUMBER: int - TORCHSCRIPT_DEBUG_ARENA_FIELD_NUMBER: int - name: typing.Text = ... - optimize: bool = ... - get_state_attribute_id: int = ... - - @property - def submodules(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ModuleDef]: ... 
- - @property - def torchscript_arena(self) -> global___RecordRef: ... - - @property - def caffe2_nets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[caffe2.proto.caffe2_pb2.NetDef]: ... - - @property - def pickle_arena(self) -> global___RecordRef: ... - - @property - def cpp_arena(self) -> global___RecordRef: ... - - @property - def parameters(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ParameterDef]: ... - - @property - def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AttributeDef]: ... - - @property - def torchscript_debug_arena(self) -> global___RecordRef: ... - - def __init__(self, - *, - submodules : typing.Optional[typing.Iterable[global___ModuleDef]] = ..., - torchscript_arena : typing.Optional[global___RecordRef] = ..., - caffe2_nets : typing.Optional[typing.Iterable[caffe2.proto.caffe2_pb2.NetDef]] = ..., - pickle_arena : typing.Optional[global___RecordRef] = ..., - cpp_arena : typing.Optional[global___RecordRef] = ..., - parameters : typing.Optional[typing.Iterable[global___ParameterDef]] = ..., - name : typing.Optional[typing.Text] = ..., - optimize : typing.Optional[bool] = ..., - attributes : typing.Optional[typing.Iterable[global___AttributeDef]] = ..., - get_state_attribute_id : typing.Optional[int] = ..., - torchscript_debug_arena : typing.Optional[global___RecordRef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"cpp_arena",b"cpp_arena",u"get_state_attribute_id",b"get_state_attribute_id",u"name",b"name",u"optimize",b"optimize",u"pickle_arena",b"pickle_arena",u"torchscript_arena",b"torchscript_arena",u"torchscript_debug_arena",b"torchscript_debug_arena"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"attributes",b"attributes",u"caffe2_nets",b"caffe2_nets",u"cpp_arena",b"cpp_arena",u"get_state_attribute_id",b"get_state_attribute_id",u"name",b"name",u"optimize",b"optimize",u"parameters",b"parameters",u"pickle_arena",b"pickle_arena",u"submodules",b"submodules",u"torchscript_arena",b"torchscript_arena",u"torchscript_debug_arena",b"torchscript_debug_arena"]) -> None: ... -global___ModuleDef = ModuleDef - -class LibDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - TORCHSCRIPT_ARENA_FIELD_NUMBER: int - - @property - def torchscript_arena(self) -> global___RecordRef: ... - - def __init__(self, - *, - torchscript_arena : typing.Optional[global___RecordRef] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"torchscript_arena",b"torchscript_arena"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"torchscript_arena",b"torchscript_arena"]) -> None: ... -global___LibDef = LibDef - -class ModelDef(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor = ... - PROTO_VERSION_FIELD_NUMBER: int - MAIN_MODULE_FIELD_NUMBER: int - PRODUCER_NAME_FIELD_NUMBER: int - PRODUCER_VERSION_FIELD_NUMBER: int - TENSORS_FIELD_NUMBER: int - proto_version: int = ... - producer_name: typing.Text = ... - producer_version: typing.Text = ... - - @property - def main_module(self) -> global___ModuleDef: ... - - @property - def tensors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TensorDef]: ... 
- - def __init__(self, - *, - proto_version : typing.Optional[int] = ..., - main_module : typing.Optional[global___ModuleDef] = ..., - producer_name : typing.Optional[typing.Text] = ..., - producer_version : typing.Optional[typing.Text] = ..., - tensors : typing.Optional[typing.Iterable[global___TensorDef]] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal[u"main_module",b"main_module",u"producer_name",b"producer_name",u"producer_version",b"producer_version",u"proto_version",b"proto_version"]) -> bool: ... - def ClearField(self, field_name: typing_extensions.Literal[u"main_module",b"main_module",u"producer_name",b"producer_name",u"producer_version",b"producer_version",u"proto_version",b"proto_version",u"tensors",b"tensors"]) -> None: ... -global___ModelDef = ModelDef diff --git a/caffe2/release-notes.md b/caffe2/release-notes.md deleted file mode 100644 index d449e98f78e3..000000000000 --- a/caffe2/release-notes.md +++ /dev/null @@ -1,175 +0,0 @@ -# Caffe2 v0.7.0 Release Notes - -## Installation - -This build is confirmed for: - -* Ubuntu 14.04 -* Ubuntu 16.06 - -### Required Dependencies - -```bash -sudo apt-get update -sudo apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler \ - python-dev \ - python-pip -sudo pip install numpy protobuf -``` - -### Optional GPU Support - -If you plan to use GPU instead of CPU only, then you should install NVIDIA CUDA and cuDNN, a GPU-accelerated library of primitives for deep neural networks. -[NVIDIA's detailed instructions](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu-installation) or if you're feeling lucky try the quick install set of commands below. - -**Update your graphics card drivers first!** Otherwise you may suffer from a wide range of difficult to diagnose errors. - -**For Ubuntu 14.04** - -```bash -sudo apt-get update && sudo apt-get install wget -y --no-install-recommends -wget "http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/cuda-repo-ubuntu1404_8.0.61-1_amd64.deb" -sudo dpkg -i cuda-repo-ubuntu1404_8.0.61-1_amd64.deb -sudo apt-get update -sudo apt-get install cuda -``` - -**For Ubuntu 16.04** - -```bash -sudo apt-get update && sudo apt-get install wget -y --no-install-recommends -wget "http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb" -sudo dpkg -i cuda-repo-ubuntu1604_8.0.61-1_amd64.deb -sudo apt-get update -sudo apt-get install cuda -``` - -#### Install cuDNN (all Ubuntu versions) - -``` -CUDNN_URL="http://developer.download.nvidia.com/compute/redist/cudnn/v5.1/cudnn-8.0-linux-x64-v5.1.tgz" -wget ${CUDNN_URL} -sudo tar -xzf cudnn-8.0-linux-x64-v5.1.tgz -C /usr/local -rm cudnn-8.0-linux-x64-v5.1.tgz && sudo ldconfig -``` - -### Optional Dependencies - -> Note `libgflags2` is for Ubuntu 14.04. `libgflags-dev` is for Ubuntu 16.04. 
- -```bash -# for Ubuntu 14.04 -sudo apt-get install -y --no-install-recommends libgflags2 -``` - -```bash -# for Ubuntu 16.04 -sudo apt-get install -y --no-install-recommends libgflags-dev -``` - -```bash -# for both Ubuntu 14.04 and 16.04 -sudo apt-get install -y --no-install-recommends \ - libgtest-dev \ - libiomp-dev \ - libleveldb-dev \ - liblmdb-dev \ - libopencv-dev \ - libopenmpi-dev \ - libsnappy-dev \ - openmpi-bin \ - openmpi-doc \ - python-pydot -sudo pip install \ - flask \ - graphviz \ - hypothesis \ - jupyter \ - matplotlib \ - pydot python-nvd3 \ - pyyaml \ - requests \ - scikit-image \ - scipy \ - setuptools \ - tornado -``` - -### Clone & Build - -```bash -git clone --recursive https://github.com/caffe2/caffe2.git && cd caffe2 -make && cd build && sudo make install -python -c 'from caffe2.python import core' 2>/dev/null && echo "Success" || echo "Failure" -``` - -Run this command below to test if your GPU build was a success. You will get a test output either way, but it will warn you at the top of the output if CPU was used instead along with other errors like missing libraries. - -```bash -python -m caffe2.python.operator_test.relu_op_test -``` - -### Environment Variables - -These environment variables may assist you depending on your current configuration. When using the install instructions above on the AWS Deep Learning AMI you don't need to set these variables. However, our Docker scripts built on Ubuntu-14.04 or NVIDIA's CUDA images seem to benefit from having these set. If you ran into problems with the build tests above then these are good things to check. Echo them first and see what you have and possibly append or replace with these directories. Also visit the troubleshooting section below. - -```bash -echo $PYTHONPATH -# export PYTHONPATH=/usr/local:$PYTHONPATH -# export PYTHONPATH=$PYTHONPATH:/home/ubuntu/caffe2/build -echo $LD_LIBRARY_PATH -# export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH -``` - -### Setting Up Tutorials & Jupyter Server - -If you're running this all on a cloud computer, you probably won't have a UI or way to view the IPython notebooks by default. Typically, you would launch them locally with `ipython notebook` and you would see a localhost:8888 webpage pop up with the directory of notebooks running. The following example will show you how to launch the Jupyter server and connect to remotely via an SSH tunnel. - -First configure your cloud server to accept port 8889, or whatever you want, but change the port in the following commands. On AWS you accomplish this by adding a rule to your server's security group allowing a TCP inbound on port 8889. Otherwise you would adjust iptables for this. - -Next you launch the Jupyter server. - -``` -jupyter notebook --no-browser --port=8889 -``` - -Then create the SSH tunnel. This will pass the cloud server's Jupyter instance to your localhost 8888 port for you to use locally. The example below is templated after how you would connect AWS, where `your-public-cert.pem` is your own public certificate and `ubuntu@super-rad-GPU-instance.compute-1.amazonaws.com` is your login to your cloud server. You can easily grab this on AWS by going to Instances > Connect and copy the part after `ssh` and swap that out in the command below. 
- -``` -ssh -N -f -L localhost:8888:localhost:8889 -i "your-public-cert.pem" ubuntu@super-rad-GPU-instance.compute-1.amazonaws.com -``` - -### Troubleshooting - -|Python errors|| -|----|-----| -|Python version | [Python](https://www.python.org/) is core to run Caffe2. We currently require [Python2.7](https://www.python.org/download/releases/2.7/). *Ubuntu 14.04 and greater have Python built in by default*, and that can be used to run Caffe2. To check your version: `python --version`| -|Solution | If you want the developer version of python, you could install the `dev` package for Python: `sudo apt-get install python-dev`| -|Python environment | You may have another version of Python installed or need to support Python version 3 for other projects.| -|Solution | Try virtualenv or Anaconda. The [Anaconda](https://www.continuum.io/downloads) platform provides a single script to install many of the necessary packages for Caffe2, including Python. Using Anaconda is outside the scope of these instructions, but if you are interested, it may work well for you.| -|pip version | If you plan to use Python with Caffe2 then you need pip.| -|Solution | `sudo apt-get install python-pip` and also try using pip2 instead of pip.| -|"AttributeError: 'module' object has no attribute 'MakeArgument'" | Occurs when calling `core.CreateOperator`| -|Solution | Check your install directory (`/usr/local/`), and remove the folder `/caffe2/python/utils`| - -|Building from source|| -|----|-----| -|OS version | Caffe2 requires Ubuntu 14.04 or greater.| -|git | While you can download the Caffe2 source code and submodules directly from GitHub as a zip, using git makes it much easier.| -|Solution | `sudo apt-get install git`| -|protobuf | You may experience an error related to protobuf during the make step.| -|Solution | Make sure you've installed protobuf in **both** of these two ways: `sudo apt-get install libprotobuf-dev protobuf-compiler && sudo pip install protobuf`| -|libgflags2 error | This optional dependency is for Ubuntu 14.04.| -|Solution | Use `apt-get install libgflags-dev` for Ubuntu 16.04.| - -|GPU Support|| -|----|-----| -|GPU errors | Unsupported GPU or wrong version| -|Solution | You need to know the specific `deb` for your version of Linux. `sudo dpkg -i| |cuda-repo-__.deb` Refer to NVIDIA's [installation guide](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu-installation).| -|Build issues | Be warned that installing CUDA and cuDNN will increase the size of your build by about 4GB, so plan to have at least 12GB for your Ubuntu disk size.| diff --git a/caffe2/requirements.txt b/caffe2/requirements.txt deleted file mode 100644 index aa8d2be43aa5..000000000000 --- a/caffe2/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy -enum34 -pyyaml -requests diff --git a/caffe2/test/assets/squeeze_predict_net.pb b/caffe2/test/assets/squeeze_predict_net.pb deleted file mode 100644 index ac4c476b91cc..000000000000 Binary files a/caffe2/test/assets/squeeze_predict_net.pb and /dev/null differ diff --git a/caffe2/test/caffe2_gtest_main.cc b/caffe2/test/caffe2_gtest_main.cc deleted file mode 100644 index 920b79ef4d65..000000000000 --- a/caffe2/test/caffe2_gtest_main.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2006, Google Inc. -// All rights reserved. 
-// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#include - -#include -#include "caffe2/core/flags.h" -#include "caffe2/core/init.h" - -C10_DEFINE_string( - caffe_test_root, - "gen/", - "The root of the caffe test folder."); - -GTEST_API_ int main(int argc, char** argv) { - // std::cout << "Running main() from gtest_main.cc\n"; - testing::InitGoogleTest(&argc, argv); - caffe2::GlobalInit(&argc, &argv); - return RUN_ALL_TESTS(); -} diff --git a/caffe2/utils/GpuAtomics.cuh b/caffe2/utils/GpuAtomics.cuh deleted file mode 100644 index 2bbcc14fa7da..000000000000 --- a/caffe2/utils/GpuAtomics.cuh +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_ATOMICS_H_ -#define CAFFE2_UTILS_GPU_ATOMICS_H_ - -#include - -namespace caffe2 { - -namespace { - -template -inline __device__ void gpu_atomic_add(T* address, const T val) { - atomicAdd(address, val); -} - -template <> -inline __device__ void gpu_atomic_add(float* address, const float val) { -#if defined(USE_ROCM) && defined(__gfx908__) - atomicAddNoRet(address, val); -#else - atomicAdd(address, val); -#endif -} - -} // namespace - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_ATOMICS_H_ diff --git a/caffe2/utils/GpuBitonicSort.cuh b/caffe2/utils/GpuBitonicSort.cuh deleted file mode 100644 index 45cb298733a8..000000000000 --- a/caffe2/utils/GpuBitonicSort.cuh +++ /dev/null @@ -1,178 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_BITONIC_SORT_H_ -#define CAFFE2_UTILS_GPU_BITONIC_SORT_H_ - -#include "caffe2/utils/math.h" -#include "caffe2/utils/GpuDefs.cuh" - -namespace caffe2 { - -// Returns true if the given integer type is a power-of-2 (positive only) -// Note(jiayq): windows reported an error per -// https://github.com/caffe2/caffe2/issues/997 -// and as a result will make it a macro. 
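Much of what the deleted GpuBitonicSort.cuh header does on the device rests on two small, host-checkable ideas: the power-of-two guard described in the comment above (the integerIsPowerOf2 helper, a macro on MSVC) and a keyed compare-and-swap step (bitonicSwap). A short Python reference of both, for illustration only; it mirrors the intent of the CUDA helpers, not their exact comparator signature:

```python
def integer_is_power_of_2(v: int) -> bool:
    """CPU reference for the deleted integerIsPowerOf2 check:
    true only for positive integers with a single bit set."""
    return v > 0 and (v & (v - 1)) == 0

def compare_swap(keys: list, vals: list, i: int, j: int, ascending: bool = True) -> None:
    """One compare-and-swap step of a sorting network, in the spirit of
    bitonicSwap: reorder the (key, value) pairs at i and j to match the
    requested direction."""
    if (keys[i] > keys[j]) == ascending:
        keys[i], keys[j] = keys[j], keys[i]
        vals[i], vals[j] = vals[j], vals[i]

assert integer_is_power_of_2(4096) and not integer_is_power_of_2(96)

k, v = [3, 1], ["c", "a"]
compare_swap(k, v, 0, 1)          # ascending: keys sorted, values carried along
assert k == [1, 3] and v == ["a", "c"]
```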
-#ifdef _MSC_VER -#define integerIsPowerOf2(v) ((v) && !((v) & ((v) - 1))) -#else // _MSC_VER -template -constexpr bool integerIsPowerOf2(T v) { - return (v && !(v & (v - 1))); -} -#endif // _MSC_VER - -/// The maximum in-block bitonic sort we support -constexpr int kMaxBitonicSortSize = 4096; - -template -__device__ inline void swapVars(T& t1, T& t2) { - T tmp = t1; - t1 = t2; - t2 = tmp; -} - -template -__device__ inline void bitonicSwap(K& kA, V& vA, - K& kB, V& vB, - bool dir, - const Comparator& comp) { - bool swap = comp(kA, vA, kB, vB); - if (swap == dir) { - swapVars(kA, kB); - swapVars(vA, vB); - } -}; - -template -__device__ inline void bitonicSort(K* keys, - V* values, - const Comparator& comp) { - static_assert(Power2SortSize <= kMaxBitonicSortSize, - "sort size <= 4096 only supported"); - // Assume the sort is taking place in shared memory - // static_assert(Power2SortSize * (sizeof(K) + sizeof(V)) < 32768, - // "sort data too large (>32768 bytes)"); - static_assert(integerIsPowerOf2(Power2SortSize), - "sort size must be power of 2"); - static_assert(integerIsPowerOf2(ThreadsPerBlock), - "threads in block must be power of 2"); - - // If what we are sorting is too small, then not all threads - // participate - constexpr int numThreadsForSort = Power2SortSize / 2; - constexpr bool allThreads = numThreadsForSort >= ThreadsPerBlock; - - // If what we are sorting is too large, then threads must loop more - // than once - constexpr int loopPerThread = - allThreads ? numThreadsForSort / ThreadsPerBlock : 1; - -#pragma unroll - for (int size = 2; size < Power2SortSize; size *= 2) { - -#pragma unroll - for (int stride = size / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * ThreadsPerBlock + threadIdx.x; - bool flag = ((threadId & (size / 2)) != 0); - - int pos = 2 * threadId - (threadId & (stride - 1)); - - if (allThreads || (threadId < numThreadsForSort)) { - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - flag, comp); - } - - __syncthreads(); - } - } - } - -#pragma unroll - for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * ThreadsPerBlock + threadIdx.x; - - int pos = 2 * threadId - (threadId & (stride - 1)); - - if (allThreads || (threadId < numThreadsForSort)) { - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - false, comp); - } - - __syncthreads(); - } - } -} - -template -__device__ inline void warpBitonicSort(K* keys, - V* values, - const Comparator& comp) { - // Smaller sorts should use a warp shuffle sort - static_assert(Power2SortSize > kWarpSize, - "sort not large enough"); - static_assert(integerIsPowerOf2(Power2SortSize), - "sort size must be power of 2"); - static_assert(Power2SortSize <= kMaxBitonicSortSize, - "sort size <= 4096 only supported"); - - // If what we are sorting is too large, then lanes must loop more - // than once - constexpr int loopPerThread = (Power2SortSize / 2) / kWarpSize; - int laneId = getLaneId(); - -#pragma unroll - for (int size = 2; size < Power2SortSize; size *= 2) { - -#pragma unroll - for (int stride = size / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * kWarpSize + laneId; - bool flag = ((threadId & (size / 2)) != 0); - - int pos = 2 * threadId - (threadId & (stride - 1)); - - bitonicSwap( - keys[pos], 
values[pos], - keys[pos + stride], values[pos + stride], - flag, comp); - - __threadfence_block(); - } - } - } - -#pragma unroll - for (int stride = Power2SortSize / 2; stride > 0; stride /= 2) { - -#pragma unroll - for (int loop = 0; loop < loopPerThread; ++loop) { - int threadId = loop * kWarpSize + laneId; - - int pos = 2 * threadId - (threadId & (stride - 1)); - - bitonicSwap( - keys[pos], values[pos], - keys[pos + stride], values[pos + stride], - false, comp); - - __threadfence_block(); - } - } -} - - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_BITONIC_SORT_H_ diff --git a/caffe2/utils/GpuDefs.cuh b/caffe2/utils/GpuDefs.cuh deleted file mode 100644 index fcf2c64ddcb1..000000000000 --- a/caffe2/utils/GpuDefs.cuh +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_DEFS_H_ -#define CAFFE2_UTILS_GPU_DEFS_H_ - -#include - -namespace caffe2 { - -// Static definition of GPU warp size for unrolling and code generation - -#if defined(USE_ROCM) -constexpr int kWarpSize = warpSize; // = 64 (Defined in hip_runtime.h) -#else -constexpr int kWarpSize = 32; -#endif // __CUDA_ARCH__ - -// -// Interfaces to PTX instructions for which there appears to be no -// intrinsic -// - -template -struct Bitfield {}; - -template <> -struct Bitfield { - static __device__ __forceinline__ - unsigned int getBitfield(unsigned int val, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned int m = (1u << len) - 1u; - return (val >> pos) & m; -#else - unsigned int ret; - asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } - - static __device__ __forceinline__ - unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned int m = (1u << len) - 1u; - toInsert &= m; - toInsert <<= pos; - m <<= pos; - - return (val & ~m) | toInsert; -#else - unsigned int ret; - asm("bfi.b32 %0, %1, %2, %3, %4;" : - "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } -}; - -template <> -struct Bitfield { - static __device__ __forceinline__ - unsigned long long int getBitfield(unsigned long long int val, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned long long int m = (1u << len) - 1u; - return (val >> pos) & m; -#else - unsigned long long int ret; - asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } - - static __device__ __forceinline__ - unsigned long long int setBitfield(unsigned long long int val, unsigned long long int toInsert, int pos, int len) { -#if defined(USE_ROCM) - pos &= 0xff; - len &= 0xff; - - unsigned long long int m = (1u << len) - 1u; - toInsert &= m; - toInsert <<= pos; - m <<= pos; - - return (val & ~m) | toInsert; -#else - unsigned long long int ret; - asm("bfi.b64 %0, %1, %2, %3, %4;" : - "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len)); - return ret; -#endif // USE_ROCM - } -}; - -__device__ __forceinline__ int getLaneId() { -#if defined(USE_ROCM) - return __lane_id(); -#else - int laneId; - asm("mov.s32 %0, %%laneid;" : "=r"(laneId) ); - return laneId; -#endif // USE_ROCM -} - -#if defined(USE_ROCM) -__device__ __forceinline__ unsigned long long int getLaneMaskLt() { - unsigned long long int m = (1ull << getLaneId()) - 1ull; - return m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskLe() { - unsigned long long int m = UINT64_MAX >> 
(sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); - return m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskGt() { - unsigned long long int m = getLaneMaskLe(); - return m ? ~m : m; -} - -__device__ __forceinline__ unsigned long long int getLaneMaskGe() { - unsigned long long int m = getLaneMaskLt(); - return ~m; -} -#else -__device__ __forceinline__ unsigned getLaneMaskLt() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskLe() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskGt() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); - return mask; -} - -__device__ __forceinline__ unsigned getLaneMaskGe() { - unsigned mask; - asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); - return mask; -} -#endif // USE_ROCM - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_DEFS_H_ diff --git a/caffe2/utils/GpuScanUtils.cuh b/caffe2/utils/GpuScanUtils.cuh deleted file mode 100644 index 0f6823d8e85e..000000000000 --- a/caffe2/utils/GpuScanUtils.cuh +++ /dev/null @@ -1,133 +0,0 @@ -#ifndef CAFFE2_UTILS_GPU_SCAN_UTILS_H_ -#define CAFFE2_UTILS_GPU_SCAN_UTILS_H_ - -#include "caffe2/utils/GpuDefs.cuh" - -namespace caffe2 { - -// from the cutorch library; can probably be replaced with their CUB -// equivalents -// Collection of in-kernel scan / prefix sum utilities - -// Inclusive prefix sum using shared memory -template -__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) { - // FIXME: this is a slow, simple implementation; need up/down sweep, - // prevent smem conflicts - smem[threadIdx.x] = in; - - __syncthreads(); - - for (int offset = 1; offset < blockDim.x; offset *= 2) { - T val = 0; - - if (threadIdx.x >= offset) { - val = binop(smem[threadIdx.x - offset], smem[threadIdx.x]); - } - - __syncthreads(); - if (threadIdx.x >= offset) { - smem[threadIdx.x] = val; - } - - __syncthreads(); - } - - *out = smem[threadIdx.x]; - - // Prevent write-after-read dependencies on smem usage above if necessary - if (KillWARDependency) { - __syncthreads(); - } -} - -// Exclusive prefix sum using shared memory -template -__device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunction binop) { - // FIXME: crappy implementation - // We kill write-after-read dependencies separately below, hence the `false` - inclusivePrefixScan(smem, in, out, binop); - - *out -= in; - *carry = smem[blockDim.x - 1]; - - // Prevent write-after-read dependencies on smem usage above if necessary - if (KillWARDependency) { - __syncthreads(); - } -} - -// Inclusive prefix sum for binary vars using intra-warp voting + -// shared memory -template -__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) { - // Within-warp, we use warp voting. -#if defined(USE_ROCM) - unsigned long long int vote = __ballot(in); - - T index = __popcll(getLaneMaskLe() & vote); - T carry = __popcll(vote); -#else - T vote = __ballot_sync(__activemask(), in); - T index = __popc(getLaneMaskLe() & vote); - T carry = __popc(vote); -#endif // USE_ROCM - - int warp = threadIdx.x / kWarpSize; - - // Per each warp, write out a value - if (getLaneId() == 0) { - smem[warp] = carry; - } - - __syncthreads(); - - // Sum across warps in one thread. 
This appears to be faster than a - // warp shuffle scan for CC 3.0+ - if (threadIdx.x == 0) { - int current = 0; - for (int i = 0; i < blockDim.x / kWarpSize; ++i) { - T v = smem[i]; - smem[i] = binop(smem[i], current); - current = binop(current, v); - } - } - - __syncthreads(); - - // load the carry from the preceding warp - if (warp >= 1) { - index = binop(index, smem[warp - 1]); - } - - *out = index; - - if (KillWARDependency) { - __syncthreads(); - } -} - -// Exclusive prefix sum for binary vars using intra-warp voting + -// shared memory -template -__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) { - inclusiveBinaryPrefixScan(smem, in, out, binop); - - // Inclusive to exclusive - *out -= (T) in; - - // The outgoing carry for all threads is the last warp's sum -#if defined(USE_ROCM) - *carry = smem[math::DivUp(blockDim.x, kWarpSize) - 1]; -#else - *carry = smem[(blockDim.x / kWarpSize) - 1]; -#endif // USE_ROCM - - if (KillWARDependency) { - __syncthreads(); - } -} - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_GPU_SCAN_UTILS_H_ diff --git a/caffe2/utils/bench_utils.cc b/caffe2/utils/bench_utils.cc deleted file mode 100644 index baa8d34fd146..000000000000 --- a/caffe2/utils/bench_utils.cc +++ /dev/null @@ -1,120 +0,0 @@ -#if !defined(__s390x__) && !defined(__powerpc__) -#include -#else -#include -#endif -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include -// NOLINTNEXTLINE(modernize-deprecated-headers) -#include - -#include "caffe2/core/logging.h" -#include "caffe2/utils/bench_utils.h" - -namespace caffe2 { - -uint32_t wipe_cache() { - static uint32_t* wipe_buffer = nullptr; - static size_t wipe_size = 0; - - if (wipe_buffer == nullptr) { -#if !defined(__s390x__) && !defined(__powerpc__) - CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo"); - const cpuinfo_processor* processor = cpuinfo_get_processor(0); - if (processor->cache.l4 != nullptr) { - wipe_size = processor->cache.l4->size; - } else if (processor->cache.l3 != nullptr) { - wipe_size = processor->cache.l3->size; - } else if (processor->cache.l2 != nullptr) { - wipe_size = processor->cache.l2->size; - } else { - wipe_size = processor->cache.l1d->size; - } -#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 - /* - * On ARM precise cache size is not available, and cpuinfo may - * underestimate. 
Use max for uArch (see src/arm/cache.c) - */ - switch (processor->core->uarch) { - case cpuinfo_uarch_cortex_a5: - wipe_size = 512 * 1024; /* Max observed */ - break; - case cpuinfo_uarch_cortex_a7: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a8: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a9: - wipe_size = 1024 * 1024; /* Max observed */ - break; - case cpuinfo_uarch_cortex_a12: - case cpuinfo_uarch_cortex_a17: - wipe_size = 8 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a15: - wipe_size = 4 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a35: - wipe_size = 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a53: - wipe_size = 2 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a57: - wipe_size = 2 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a72: - wipe_size = 4 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a73: - wipe_size = 8 * 1024 * 1024; /* uArch max */ - break; - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a75: - case cpuinfo_uarch_meerkat_m3: - wipe_size = 4 * 1024 * 1024; /* DynamIQ max */ - break; - default: - wipe_size = 60 * 1024 * 1024; - break; - } -#endif -#elif defined (__s390x__) - wipe_size = sysconf(_SC_LEVEL4_CACHE_SIZE); - if (wipe_size <= 0) - { - /* - * Take current max L4 cache size for s390x - */ - wipe_size = 1024 * 1024 * 1024; - } -#else - /* ppc64le */ - wipe_size = sysconf(_SC_LEVEL4_CACHE_SIZE); - if (wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL3_CACHE_SIZE); - if (wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL2_CACHE_SIZE); - if(wipe_size <= 0) { - wipe_size = sysconf(_SC_LEVEL1D_CACHE_SIZE); - } - } - } -#endif - LOG(INFO) << "Allocating cache wipe buffer of size " << wipe_size; - // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) - wipe_buffer = static_cast(malloc(wipe_size)); - CAFFE_ENFORCE(wipe_buffer != nullptr); - } - uint32_t hash = 0; - for (uint32_t i = 0; i * sizeof(uint32_t) < wipe_size; i += 8) { - // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign) - hash ^= wipe_buffer[i]; - wipe_buffer[i] = hash; - } - /* Make sure compiler doesn't optimize the loop away */ - return hash; -} - -} /* namespace caffe2 */ diff --git a/caffe2/utils/bench_utils.h b/caffe2/utils/bench_utils.h deleted file mode 100644 index 59997edad58d..000000000000 --- a/caffe2/utils/bench_utils.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef CAFFE2_UTILS_BENCH_UTILS_H_ -#define CAFFE2_UTILS_BENCH_UTILS_H_ - -#include - -#include - -namespace caffe2 { - -TORCH_API uint32_t wipe_cache(); - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_BENCH_UTILS_H_ diff --git a/caffe2/utils/cast.h b/caffe2/utils/cast.h deleted file mode 100644 index 6f9db0837946..000000000000 --- a/caffe2/utils/cast.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -namespace cast { - -inline TensorProto_DataType GetCastDataType(const ArgumentHelper& helper, std::string arg) { - TensorProto_DataType to; - if (helper.HasSingleArgumentOfType(arg)) { - string s = helper.GetSingleArgument(arg, "float"); - std::transform(s.begin(), s.end(), s.begin(), ::toupper); -#ifndef CAFFE2_USE_LITE_PROTO - CAFFE_ENFORCE(TensorProto_DataType_Parse(s, &to), "Unknown 'to' argument: ", s); -#else - -// Manually implement in the lite proto case. -#define X(t) \ - if (s == #t) { \ - return TensorProto_DataType_##t; \ - } - - X(FLOAT); - X(INT32); - X(BYTE); - X(STRING); - X(BOOL); - X(UINT8); - X(INT8); - X(UINT16); - X(INT16); - X(INT64); - X(FLOAT16); - X(DOUBLE); -#undef X - CAFFE_THROW("Unhandled type argument: ", s); - -#endif - } else { - to = static_cast( - helper.GetSingleArgument(arg, TensorProto_DataType_FLOAT)); - } - return to; -} - -}; // namespace cast - -}; // namespace caffe2 diff --git a/caffe2/utils/cast_test.cc b/caffe2/utils/cast_test.cc deleted file mode 100644 index 680e87b3aecc..000000000000 --- a/caffe2/utils/cast_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include - -#include - -#include "caffe2/utils/cast.h" - -namespace caffe2 { - -TEST(CastTest, GetCastDataType) { - auto castOp = [](std::string t) { - // Ensure lowercase. - std::transform(t.begin(), t.end(), t.begin(), ::tolower); - auto op = CreateOperatorDef("Cast", "", {}, {}); - AddArgument("to", t, &op); - return op; - }; - -#define X(t) \ - EXPECT_EQ( \ - TensorProto_DataType_##t, \ - cast::GetCastDataType(ArgumentHelper(castOp(#t)), "to")); - - X(FLOAT); - X(INT32); - X(BYTE); - X(STRING); - X(BOOL); - X(UINT8); - X(INT8); - X(UINT16); - X(INT16); - X(INT64); - X(FLOAT16); - X(DOUBLE); -#undef X -} - -} // namespace caffe2 diff --git a/caffe2/utils/cblas.h b/caffe2/utils/cblas.h deleted file mode 100644 index c91b8bf8c530..000000000000 --- a/caffe2/utils/cblas.h +++ /dev/null @@ -1,606 +0,0 @@ -// This is the exact cblas.h header file, placed here purely in order to get -// the enums. 
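The vendored cblas.h removed below was kept mainly for the CBLAS_ORDER / CBLAS_TRANSPOSE / CBLAS_UPLO enums that caffe2's math wrappers pass through to whichever BLAS is linked. As a reminder of how those enums drive a call, here is a small single-precision GEMM against any conforming CBLAS; this assumes a system cblas.h is available and is not tied to the removed copy:

    #include <cblas.h>

    // C (2x2) = 1.0 * A (2x3) * B (3x2) + 0.0 * C, all row-major.
    void gemm_example() {
      const float A[6] = {1, 2, 3, 4, 5, 6};
      const float B[6] = {1, 0, 0, 1, 1, 1};
      float C[4] = {0, 0, 0, 0};
      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                  /*M=*/2, /*N=*/2, /*K=*/3,
                  /*alpha=*/1.0f, A, /*lda=*/3, B, /*ldb=*/2,
                  /*beta=*/0.0f, C, /*ldc=*/2);
    }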
- -#include "caffe2/core/macros.h" - -#ifndef CBLAS_H -#ifdef CAFFE2_USE_MKL -#include -#else // CAFFE2_USE_MKL - -#ifndef CBLAS_ENUM_DEFINED_H - #define CBLAS_ENUM_DEFINED_H - enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; - enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, - AtlasConj=114}; - enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; - enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; - enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; -#endif - -#ifndef CBLAS_ENUM_ONLY -#define CBLAS_H -#define CBLAS_INDEX int - -int cblas_errprn(int ierr, int info, char *form, ...); -void cblas_xerbla(int p, const char *rout, const char *form, ...); - -/* - * =========================================================================== - * Prototypes for level 1 BLAS functions (complex are recast as routines) - * =========================================================================== - */ -float cblas_sdsdot(const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY); -double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, - const int incY); -float cblas_sdot(const int N, const float *X, const int incX, - const float *Y, const int incY); -double cblas_ddot(const int N, const double *X, const int incX, - const double *Y, const int incY); -/* - * Functions having prefixes Z and C only - */ -void cblas_cdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); -void cblas_cdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - -void cblas_zdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); -void cblas_zdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - - -/* - * Functions having prefixes S D SC DZ - */ -float cblas_snrm2(const int N, const float *X, const int incX); -float cblas_sasum(const int N, const float *X, const int incX); - -double cblas_dnrm2(const int N, const double *X, const int incX); -double cblas_dasum(const int N, const double *X, const int incX); - -float cblas_scnrm2(const int N, const void *X, const int incX); -float cblas_scasum(const int N, const void *X, const int incX); - -double cblas_dznrm2(const int N, const void *X, const int incX); -double cblas_dzasum(const int N, const void *X, const int incX); - - -/* - * Functions having standard 4 prefixes (S D C Z) - */ -CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); -CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); -CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); -CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); - -/* - * =========================================================================== - * Prototypes for level 1 BLAS routines - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (s, d, c, z) - */ -void cblas_sswap(const int N, float *X, const int incX, - float *Y, const int incY); -void cblas_scopy(const int N, const float *X, const int incX, - float *Y, const int incY); -void cblas_saxpy(const int N, const float alpha, const float *X, - const int incX, float *Y, const int incY); -void catlas_saxpby(const int N, const float alpha, const float *X, - const int incX, const float beta, float *Y, const int incY); -void catlas_sset - (const int N, const float alpha, float *X, const 
int incX); - -void cblas_dswap(const int N, double *X, const int incX, - double *Y, const int incY); -void cblas_dcopy(const int N, const double *X, const int incX, - double *Y, const int incY); -void cblas_daxpy(const int N, const double alpha, const double *X, - const int incX, double *Y, const int incY); -void catlas_daxpby(const int N, const double alpha, const double *X, - const int incX, const double beta, double *Y, const int incY); -void catlas_dset - (const int N, const double alpha, double *X, const int incX); - -void cblas_cswap(const int N, void *X, const int incX, - void *Y, const int incY); -void cblas_ccopy(const int N, const void *X, const int incX, - void *Y, const int incY); -void cblas_caxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); -void catlas_caxpby(const int N, const void *alpha, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void catlas_cset - (const int N, const void *alpha, void *X, const int incX); - -void cblas_zswap(const int N, void *X, const int incX, - void *Y, const int incY); -void cblas_zcopy(const int N, const void *X, const int incX, - void *Y, const int incY); -void cblas_zaxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); -void catlas_zaxpby(const int N, const void *alpha, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void catlas_zset - (const int N, const void *alpha, void *X, const int incX); - - -/* - * Routines with S and D prefix only - */ -void cblas_srotg(float *a, float *b, float *c, float *s); -void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); -void cblas_srot(const int N, float *X, const int incX, - float *Y, const int incY, const float c, const float s); -void cblas_srotm(const int N, float *X, const int incX, - float *Y, const int incY, const float *P); - -void cblas_drotg(double *a, double *b, double *c, double *s); -void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); -void cblas_drot(const int N, double *X, const int incX, - double *Y, const int incY, const double c, const double s); -void cblas_drotm(const int N, double *X, const int incX, - double *Y, const int incY, const double *P); - - -/* - * Routines with S D C Z CS and ZD prefixes - */ -void cblas_sscal(const int N, const float alpha, float *X, const int incX); -void cblas_dscal(const int N, const double alpha, double *X, const int incX); -void cblas_cscal(const int N, const void *alpha, void *X, const int incX); -void cblas_zscal(const int N, const void *alpha, void *X, const int incX); -void cblas_csscal(const int N, const float alpha, void *X, const int incX); -void cblas_zdscal(const int N, const double alpha, void *X, const int incX); - -/* - * Extra reference routines provided by ATLAS, but not mandated by the standard - */ -void cblas_crotg(void *a, void *b, void *c, void *s); -void cblas_zrotg(void *a, void *b, void *c, void *s); -void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY, - const float c, const float s); -void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY, - const double c, const double s); - -/* - * =========================================================================== - * Prototypes for level 2 BLAS - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (S, D, C, Z) - */ -void cblas_sgemv(const enum CBLAS_ORDER 
Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *X, const int incX, const float beta, - float *Y, const int incY); -void cblas_sgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const float alpha, - const float *A, const int lda, const float *X, - const int incX, const float beta, float *Y, const int incY); -void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, - float *X, const int incX); -void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); -void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); -void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, float *X, - const int incX); -void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); -void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); - -void cblas_dgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *X, const int incX, const double beta, - double *Y, const int incY); -void cblas_dgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const double alpha, - const double *A, const int lda, const double *X, - const int incX, const double beta, double *Y, const int incY); -void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, - double *X, const int incX); -void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); -void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *Ap, double *X, const int incX); -void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, double *X, - const int incX); -void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); -void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const 
double *Ap, double *X, const int incX); - -void cblas_cgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); -void cblas_cgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); -void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); -void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); -void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - -void cblas_zgemv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); -void cblas_zgbmv(const enum CBLAS_ORDER Order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); -void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); -void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); -void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); -void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); -void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum 
CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - - -/* - * Routines with S and D prefixes only - */ -void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *Ap, - const float *X, const int incX, - const float beta, float *Y, const int incY); -void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N, - const float alpha, const float *X, const int incX, - const float *Y, const int incY, float *A, const int lda); -void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *A, const int lda); -void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *Ap); -void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A, - const int lda); -void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A); - -void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *Ap, - const double *X, const int incX, - const double beta, double *Y, const int incY); -void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N, - const double alpha, const double *X, const int incX, - const double *Y, const int incY, double *A, const int lda); -void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *A, const int lda); -void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *Ap); -void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A, - const int lda); -void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A); - - -/* - * Routines with C and Z prefixes only - */ -void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const 
int incY); -void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, const int incX, - void *A, const int lda); -void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, - const int incX, void *A); -void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - -void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); -void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, const int incX, - void *A, const int lda); -void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, - const int incX, void *A); -void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); -void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - -/* - * =========================================================================== - * Prototypes for level 3 BLAS - * =========================================================================== - */ - -/* - * Routines with standard 4 prefixes (S, D, C, Z) - */ -void cblas_sgemm(const enum 
CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const float alpha, const float *A, - const int lda, const float *B, const int ldb, - const float beta, float *C, const int ldc); -void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); -void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float beta, float *C, const int ldc); -void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); -void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); -void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); - -void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const double alpha, const double *A, - const int lda, const double *B, const int ldb, - const double beta, double *C, const int ldc); -void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); -void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double beta, double *C, const int ldc); -void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); -void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); -void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); - -void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); -void cblas_csymm(const enum CBLAS_ORDER 
Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); -void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); -void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - -void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); -void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); -void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); -void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - - -/* - * Routines with prefixes C and Z only - */ -void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const void *A, const int lda, - const float beta, void *C, const int ldc); -void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int 
K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const float beta, - void *C, const int ldc); -void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); -void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const void *A, const int lda, - const double beta, void *C, const int ldc); -void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const double beta, - void *C, const int ldc); - -int cblas_errprn(int ierr, int info, char *form, ...); - -#endif /* end #ifdef CBLAS_ENUM_ONLY */ -#endif // CAFFE2_USE_MKL -#endif diff --git a/caffe2/utils/cpu_neon.h b/caffe2/utils/cpu_neon.h deleted file mode 100644 index 7e68d73c1bef..000000000000 --- a/caffe2/utils/cpu_neon.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef CAFFE2_UTILS_CPU_NEON_H_ -#define CAFFE2_UTILS_CPU_NEON_H_ - -// Provides a variety of ARM NEON-specific utility functions -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#include - -namespace caffe2 { - -template -inline bool isPointerAligned(T* p, size_t align) { - return (reinterpret_cast(p) % align == 0); -} - -inline float32x4_t vert_sum_f32(float32x4_t v0, - float32x4_t v1, - float32x4_t v2, - float32x4_t v3) { - v0 = vaddq_f32(v0, v1); - v2 = vaddq_f32(v2, v3); - return vaddq_f32(v0, v2); -} - -inline float horizontal_sum_f32(float32x4_t v0, - float32x4_t v1, - float32x4_t v2, - float32x4_t v3) { - v0 = vert_sum_f32(v0, v1, v2, v3); - float32x2_t v = vadd_f32(vget_high_f32(v0), vget_low_f32(v0)); - return vget_lane_f32(vpadd_f32(v, v), 0); -} - -// Load/store functions that assume alignment - -inline float32x4_t vld1q_f32_aligned(const float* p) { - return vld1q_f32((const float*) - __builtin_assume_aligned(p, sizeof(float32x4_t))); -} - -inline void vst1q_f32_aligned(float* p, float32x4_t v) { - vst1q_f32((float*) __builtin_assume_aligned(p, sizeof(float32x4_t)), v); -} - -inline void vst4_u8_aligned(uint8_t* p, uint8x8x4_t v) { - vst4_u8((uint8_t*) - __builtin_assume_aligned(p, sizeof(uint8x8x4_t)), v); -} - -} // namespace caffe2 - -#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) - -#endif // CAFFE2_UTILS_CPU_NEON_H_ diff --git a/caffe2/utils/cpuid.cc b/caffe2/utils/cpuid.cc deleted file mode 100644 index 2ba1d2dd8840..000000000000 --- a/caffe2/utils/cpuid.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "caffe2/utils/cpuid.h" - -namespace caffe2 { - -const CpuId& GetCpuId() { - static CpuId cpuid_singleton; - return cpuid_singleton; -} - -TORCH_API uint32_t CpuId::f1c_ = 0; -TORCH_API uint32_t CpuId::f1d_ = 0; -TORCH_API uint32_t CpuId::f7b_ = 0; -TORCH_API uint32_t CpuId::f7c_ = 0; - -CpuId::CpuId() { -#ifdef _MSC_VER - int reg[4]; - __cpuid(static_cast(reg), 0); - const int n = reg[0]; - if (n >= 1) { - __cpuid(static_cast(reg), 1); - f1c_ = uint32_t(reg[2]); - f1d_ = uint32_t(reg[3]); - } - if (n >= 7) { - __cpuidex(static_cast(reg), 7, 0); - f7b_ = uint32_t(reg[1]); - f7c_ = uint32_t(reg[2]); - } -#elif defined(__i386__) && defined(__PIC__) && !defined(__clang__) && \ - defined(__GNUC__) - // The following block like the normal cpuid branch below, but gcc - // reserves ebx 
for use of its pic register so we must specially - // handle the save and restore to avoid clobbering the register - uint32_t n; - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "popl %%ebx\n\t" - : "=a"(n) - : "a"(0) - : "ecx", "edx"); - if (n >= 1) { - uint32_t f1a; - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "popl %%ebx\n\t" - : "=a"(f1a), "=c"(f1c_), "=d"(f1d_) - : "a"(1) - :); - } - if (n >= 7) { - __asm__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "movl %%ebx, %%eax\n\r" - "popl %%ebx" - : "=a"(f7b_), "=c"(f7c_) - : "a"(7), "c"(0) - : "edx"); - } -#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386__) - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t n; - __asm__("cpuid" : "=a"(n) : "a"(0) : "ebx", "ecx", "edx"); - if (n >= 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t f1a; - __asm__("cpuid" : "=a"(f1a), "=c"(f1c_), "=d"(f1d_) : "a"(1) : "ebx"); - } - if (n >= 7) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t f7a; - __asm__("cpuid" - : "=a"(f7a), "=b"(f7b_), "=c"(f7c_) - : "a"(7), "c"(0) - : "edx"); - } -#endif -} - -} // namespace caffe2 diff --git a/caffe2/utils/cpuid.h b/caffe2/utils/cpuid.h deleted file mode 100644 index 2cac7637ba32..000000000000 --- a/caffe2/utils/cpuid.h +++ /dev/null @@ -1,146 +0,0 @@ -#pragma once - -#include - -#ifdef _MSC_VER -#include -#endif - -#include - -namespace caffe2 { - -class CpuId; - -TORCH_API const CpuId& GetCpuId(); - -/////////////////////////////////////////////////////////////////////////////// -// Implementation of CpuId that is borrowed from folly. -/////////////////////////////////////////////////////////////////////////////// - -// TODO: It might be good to use cpuinfo third-party dependency instead for -// consistency sake. - -/** - * Identification of an Intel CPU. - * Supports CPUID feature flags (EAX=1) and extended features (EAX=7, ECX=0). - * Values from - * http://www.intel.com/content/www/us/en/processors/processor-identification-cpuid-instruction-note.html - */ -class CpuId { - public: - CpuId(); - -#define X(name, r, bit) \ - inline bool name() const { \ - return ((r) & (1U << bit)) != 0; \ - } - -// cpuid(1): Processor Info and Feature Bits. -#define C(name, bit) X(name, f1c_, bit) - C(sse3, 0) - C(pclmuldq, 1) - C(dtes64, 2) - C(monitor, 3) - C(dscpl, 4) - C(vmx, 5) - C(smx, 6) - C(eist, 7) - C(tm2, 8) - C(ssse3, 9) - C(cnxtid, 10) - C(fma, 12) - C(cx16, 13) - C(xtpr, 14) - C(pdcm, 15) - C(pcid, 17) - C(dca, 18) - C(sse41, 19) - C(sse42, 20) - C(x2apic, 21) - C(movbe, 22) - C(popcnt, 23) - C(tscdeadline, 24) - C(aes, 25) - C(xsave, 26) - C(osxsave, 27) - C(avx, 28) - C(f16c, 29) - C(rdrand, 30) -#undef C - -#define D(name, bit) X(name, f1d_, bit) - D(fpu, 0) - D(vme, 1) - D(de, 2) - D(pse, 3) - D(tsc, 4) - D(msr, 5) - D(pae, 6) - D(mce, 7) - D(cx8, 8) - D(apic, 9) - D(sep, 11) - D(mtrr, 12) - D(pge, 13) - D(mca, 14) - D(cmov, 15) - D(pat, 16) - D(pse36, 17) - D(psn, 18) - D(clfsh, 19) - D(ds, 21) - D(acpi, 22) - D(mmx, 23) - D(fxsr, 24) - D(sse, 25) - D(sse2, 26) - D(ss, 27) - D(htt, 28) - D(tm, 29) - D(pbe, 31) -#undef D - -// cpuid(7): Extended Features. 
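The CpuId class being removed caches the raw CPUID feature words (f1c_, f1d_, f7b_, f7c_) once at construction and exposes one inline predicate per feature bit through the X/C/D/B/E macros. Each macro expands to a simple mask test; a reduced sketch of that expansion (names illustrative, not the removed header):

    #include <cstdint>

    struct CpuIdSketch {
      uint32_t f7b_ = 0;  // EBX from CPUID leaf 7, subleaf 0

      // What `B(avx2, 5)` expands to: test one bit of the cached register.
      bool avx2() const {
        return (f7b_ & (1u << 5)) != 0;
      }
    };

    // Typical use mirrored the removed API:
    //   if (GetCpuId().avx2()) { /* dispatch to the AVX2 kernel */ }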
-#define B(name, bit) X(name, f7b_, bit) - B(bmi1, 3) - B(hle, 4) - B(avx2, 5) - B(smep, 7) - B(bmi2, 8) - B(erms, 9) - B(invpcid, 10) - B(rtm, 11) - B(mpx, 14) - B(avx512f, 16) - B(avx512dq, 17) - B(rdseed, 18) - B(adx, 19) - B(smap, 20) - B(avx512ifma, 21) - B(pcommit, 22) - B(clflushopt, 23) - B(clwb, 24) - B(avx512pf, 26) - B(avx512er, 27) - B(avx512cd, 28) - B(sha, 29) - B(avx512bw, 30) - B(avx512vl, 31) -#undef B - -#define E(name, bit) X(name, f7c_, bit) - E(prefetchwt1, 0) - E(avx512vbmi, 1) -#undef E - -#undef X - - private: - TORCH_API static uint32_t f1c_; - TORCH_API static uint32_t f1d_; - TORCH_API static uint32_t f7b_; - TORCH_API static uint32_t f7c_; -}; - -} // namespace caffe2 diff --git a/caffe2/utils/cpuid_test.cc b/caffe2/utils/cpuid_test.cc deleted file mode 100644 index f3694f5d0bac..000000000000 --- a/caffe2/utils/cpuid_test.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include "caffe2/utils/cpuid.h" - -namespace caffe2 { - -TEST(CpuIdTest, ShouldAlwaysHaveMMX) { - EXPECT_TRUE(GetCpuId().mmx()); -} - -} // namespace caffe2 diff --git a/caffe2/utils/cub_namespace.cuh b/caffe2/utils/cub_namespace.cuh deleted file mode 100644 index 188a9936f9c6..000000000000 --- a/caffe2/utils/cub_namespace.cuh +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -// cub sort support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: -// https://github.com/NVIDIA/cub/pull/326 -// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake -// starting from CUDA 11.5 -#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) -#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true -#else -#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false -#endif - -#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE() -namespace caffe2 { -namespace cub = ::CUB_WRAPPED_NAMESPACE::cub; -} -#endif diff --git a/caffe2/utils/eigen_utils.h b/caffe2/utils/eigen_utils.h deleted file mode 100644 index c6c34dba9b5a..000000000000 --- a/caffe2/utils/eigen_utils.h +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. 
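The cub_namespace.cuh header deleted just above made `cub` inside namespace caffe2 resolve to `::CUB_WRAPPED_NAMESPACE::cub` whenever the toolkit wraps CUB (CUDA 11.5+ defines CUB_WRAPPED_NAMESPACE globally via cmake/Dependencies.cmake). A hedged sketch of what a call site relying on that alias looked like before this removal; the function name is illustrative and the include of the removed header is shown only for context:

    #include <cuda_runtime.h>
    #include <cub/device/device_reduce.cuh>
    #include "caffe2/utils/cub_namespace.cuh"  // removed in this diff; provided the alias

    namespace caffe2 {
    // Unqualified cub:: resolves either to ::cub or, via the alias, to the
    // wrapped namespace, so the call site is the same either way.
    void SumOnDevice(void* d_temp, size_t& temp_bytes,
                     const float* d_in, float* d_out, int n, cudaStream_t stream) {
      cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
    }
    } // namespace caffe2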
- -#ifndef CAFFE2_OPERATORS_UTILS_EIGEN_H_ -#define CAFFE2_OPERATORS_UTILS_EIGEN_H_ - -#include "Eigen/Core" -#include "Eigen/Dense" - -#include -#include - -namespace caffe2 { - -// Common Eigen types that we will often use -template -using EigenMatrixMap = - Eigen::Map>; -template -using EigenArrayMap = - Eigen::Map>; -template -using EigenVectorMap = Eigen::Map>; -template -using EigenVectorArrayMap = Eigen::Map>; -template -using ConstEigenMatrixMap = - Eigen::Map>; -template -using ConstEigenArrayMap = - Eigen::Map>; -template -using ConstEigenVectorMap = - Eigen::Map>; -template -using ConstEigenVectorArrayMap = - Eigen::Map>; - -using EigenOuterStride = Eigen::OuterStride; -using EigenInnerStride = Eigen::InnerStride; -using EigenStride = Eigen::Stride; -template -using EigenOuterStridedMatrixMap = Eigen:: - Map, 0, EigenOuterStride>; -template -using EigenOuterStridedArrayMap = Eigen:: - Map, 0, EigenOuterStride>; -template -using ConstEigenOuterStridedMatrixMap = Eigen::Map< - const Eigen::Matrix, - 0, - EigenOuterStride>; -template -using ConstEigenOuterStridedArrayMap = Eigen::Map< - const Eigen::Array, - 0, - EigenOuterStride>; -template -using EigenStridedMatrixMap = Eigen:: - Map, 0, EigenStride>; -template -using EigenStridedArrayMap = - Eigen::Map, 0, EigenStride>; -template -using ConstEigenStridedMatrixMap = Eigen:: - Map, 0, EigenStride>; -template -using ConstEigenStridedArrayMap = Eigen:: - Map, 0, EigenStride>; - -// 1-d array -template -using EArrXt = Eigen::Array; -using EArrXf = Eigen::ArrayXf; -using EArrXd = Eigen::ArrayXd; -using EArrXi = Eigen::ArrayXi; -using EArrXb = EArrXt; -using EArrXI32 = EArrXt; -using EArrXU16 = EArrXt; -using EArrXU8 = EArrXt; -using EArr3U8 = Eigen::Array; - -// 2-d array, column major -template -using EArrXXt = Eigen::Array; -using EArrXXf = Eigen::ArrayXXf; -using EArrXXI32 = EArrXXt; -using EArrXXU16 = EArrXXt; -using EArrXXU8 = EArrXXt; -using EArrXXi = EArrXXt; - -// 2-d array, row major -template -using ERArrXXt = - Eigen::Array; -using ERArrXXf = ERArrXXt; -using ERArrXXI32t = ERArrXXt; -using ERArrXXU16t = ERArrXXt; -using ERArrXXU8t = ERArrXXt; -using ERArrXXi = ERArrXXt; -using ERArrXXi64t = ERArrXXt; -using ERArrXXi32t = ERArrXXt; - -// 1-d vector -template -using EVecXt = Eigen::Matrix; -using EVecXd = Eigen::VectorXd; -using EVecXf = Eigen::VectorXf; - -// 1-d row vector -using ERVecXd = Eigen::RowVectorXd; -using ERVecXf = Eigen::RowVectorXf; - -// 2-d matrix, column major -template -using EMatXt = Eigen::Matrix; -using EMatXd = Eigen::MatrixXd; -using EMatXf = Eigen::MatrixXf; -using EMatXU8 = EMatXt; -using EMatXU16 = EMatXt; - -// 2-d matrix, row major -template -using ERMatXt = - Eigen::Matrix; -using ERMatXd = ERMatXt; -using ERMatXf = ERMatXt; -using ERMatXU8 = ERMatXt; - -namespace utils { - -template -Eigen::Map> AsEArrXt(const std::vector& arr) { - return {arr.data(), static_cast(arr.size())}; -} -template -Eigen::Map> AsEArrXt(std::vector& arr) { - return {arr.data(), static_cast(arr.size())}; -} - -// return a sub array of 'array' based on indices 'indices' -template -void GetSubArray( - const Eigen::ArrayBase& array, - const Eigen::ArrayBase& indices, - Eigen::ArrayBase* out_array) { - CAFFE_ENFORCE_EQ(array.cols(), 1); - // using T = typename Derived::Scalar; - - out_array->derived().resize(indices.size()); - for (const auto i : c10::irange(indices.size())) { - TORCH_DCHECK_LT(indices[i], array.size()); - (*out_array)[i] = array[indices[i]]; - } -} - -// return a sub array of 'array' based on indices 
'indices' -template -EArrXt GetSubArray( - const Eigen::ArrayBase& array, - const Eigen::ArrayBase& indices) { - using T = typename Derived::Scalar; - EArrXt ret(indices.size()); - GetSubArray(array, indices, &ret); - return ret; -} - -// return a sub array of 'array' based on indices 'indices' -template -EArrXt GetSubArray( - const Eigen::ArrayBase& array, - const std::vector& indices) { - return GetSubArray(array, AsEArrXt(indices)); -} - -// return 2d sub array of 'array' based on row indices 'row_indices' -template -void GetSubArrayRows( - const Eigen::ArrayBase& array2d, - const Eigen::ArrayBase& row_indices, - Eigen::ArrayBase* out_array) { - out_array->derived().resize(row_indices.size(), array2d.cols()); - - for (const auto i : c10::irange(row_indices.size())) { - TORCH_DCHECK_LT(row_indices[i], array2d.size()); - out_array->row(i) = - array2d.row(row_indices[i]).template cast(); - } -} - -// return indices of 1d array for elements evaluated to true -template -std::vector GetArrayIndices(const Eigen::ArrayBase& array) { - std::vector ret; - for (const auto i : c10::irange(array.size())) { - if (array[i]) { - ret.push_back(i); - } - } - return ret; -} - -} // namespace utils -} // namespace caffe2 - -#endif diff --git a/caffe2/utils/fatal_signal_asan_no_sig_test.cc b/caffe2/utils/fatal_signal_asan_no_sig_test.cc deleted file mode 100644 index 9c64102981c3..000000000000 --- a/caffe2/utils/fatal_signal_asan_no_sig_test.cc +++ /dev/null @@ -1,148 +0,0 @@ -#include "caffe2/utils/signal_handler.h" -#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) -#include -#include -#include - -#include -#include -#include - -#include "caffe2/core/common.h" - -namespace { -void* dummy_thread(void*) { - while (1) { - } - return nullptr; -} - -bool forkAndPipe( - std::string& stderrBuffer, - std::function callback) { - std::array stderrPipe; - if (pipe(stderrPipe.data()) != 0) { - perror("STDERR pipe"); - return false; - } - pid_t child = fork(); - if (child == 0) { - // Replace this process' stderr so we can read it. - if (dup2(stderrPipe[1], STDERR_FILENO) < 0) { - close(stderrPipe[0]); - close(stderrPipe[1]); - perror("dup2 STDERR"); - exit(5); - } - - // This is for the parent to work with. - close(stderrPipe[0]); - close(stderrPipe[1]); - - callback(); - exit(7); - } else if (child > 0) { - const int bufferSize = 128; - std::array buffer; - - // We want to close the writing end of the pipe right away so our - // read actually gets an EOF. - close(stderrPipe[1]); - - // wait for child to finish crashing. - int statloc; - if (wait(&statloc) < 0) { - close(stderrPipe[0]); - perror("wait"); - return false; - } - - ssize_t bytesRead; - while ((bytesRead = read(stderrPipe[0], buffer.data(), bufferSize)) > 0) { - const std::string tmp(buffer.data(), bytesRead); - std::cout << tmp; - stderrBuffer += tmp; - } - - // The child should have exited due to signal. 
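The deleted fatal-signal test drives everything through forkAndPipe(): the child swaps its stderr for the write end of a pipe and then crashes, while the parent drains the read end and checks via the wait status that the child really died from a signal rather than exiting normally. A condensed sketch of that fork/pipe/wait pattern (POSIX only, error handling trimmed, names illustrative):

    #include <sys/types.h>
    #include <sys/wait.h>
    #include <unistd.h>
    #include <csignal>
    #include <string>

    // Returns true if `crash()` terminated the child via a signal; the child's
    // stderr output is collected into `err`.
    static bool runChildAndCapture(std::string& err, void (*crash)()) {
      int fds[2];
      if (pipe(fds) != 0) return false;
      pid_t child = fork();
      if (child == 0) {            // child: stderr -> pipe, then crash
        dup2(fds[1], STDERR_FILENO);
        close(fds[0]);
        close(fds[1]);
        crash();
        _exit(7);                  // only reached if crash() returns
      }
      close(fds[1]);               // parent: close write end so read() sees EOF
      int status = 0;
      waitpid(child, &status, 0);  // wait for the child to finish crashing
      char buf[128];
      ssize_t n;
      while ((n = read(fds[0], buf, sizeof(buf))) > 0) {
        err.append(buf, static_cast<size_t>(n));
      }
      close(fds[0]);
      return WIFSIGNALED(status);
    }

    // e.g. runChildAndCapture(out, [] { raise(SIGSEGV); });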
- if (!WIFSIGNALED(statloc)) { - fprintf(stderr, "Child didn't exit because it received a signal\n"); - if (WIFEXITED(statloc)) { - fprintf(stderr, "Exited with code: %d\n", WEXITSTATUS(statloc) & 0xff); - } - return false; - } - - if (bytesRead < 0) { - perror("read"); - return false; - } - - close(stderrPipe[0]); - return true; - } else { - perror("fork"); - return false; - } -} -} // namespace - -#define _TEST_FATAL_SIGNAL(signum, name, threadCount, print, expected) \ - do { \ - std::string stderrBuffer; \ - ASSERT_TRUE(forkAndPipe(stderrBuffer, [=]() { \ - caffe2::setPrintStackTracesOnFatalSignal(print); \ - pthread_t pt; \ - for (int i = 0; i < threadCount; i++) { \ - if (pthread_create(&pt, nullptr, ::dummy_thread, nullptr)) { \ - perror("pthread_create"); \ - } \ - } \ - raise(signum); \ - })); \ - int keyPhraseCount = 0; \ - std::string keyPhrase = \ - std::string(name) + "(" + c10::to_string(signum) + ")"; \ - size_t loc = 0; \ - while ((loc = stderrBuffer.find(keyPhrase, loc)) != std::string::npos) { \ - keyPhraseCount += 1; \ - loc += 1; \ - } \ - EXPECT_GE(keyPhraseCount, expected); \ - } while (0) - -#define TEST_FATAL_SIGNAL(signum, name, threadCount) \ - _TEST_FATAL_SIGNAL(signum, name, threadCount, true, threadCount + 1) - -#define TEST_FATAL_SIGNAL_NO_PRINT(signum, name, threadCount) \ - _TEST_FATAL_SIGNAL(signum, name, threadCount, false, 0) - -TEST(fatalSignalTest, SIGABRT8) { - TEST_FATAL_SIGNAL(SIGABRT, "SIGABRT", 8); -} - -TEST(fatalSignalTest, SIGINT8) { - TEST_FATAL_SIGNAL(SIGINT, "SIGINT", 8); -} - -TEST(fatalSignalTest, SIGILL8) { - TEST_FATAL_SIGNAL(SIGILL, "SIGILL", 8); -} - -TEST(fatalSignalTest, SIGFPE8) { - TEST_FATAL_SIGNAL(SIGFPE, "SIGFPE", 8); -} - -TEST(fatalSignalTest, SIGBUS8) { - TEST_FATAL_SIGNAL(SIGBUS, "SIGBUS", 8); -} - -TEST(fatalSignalTest, SIGSEGV8) { - TEST_FATAL_SIGNAL(SIGSEGV, "SIGSEGV", 8); -} - -// Test that if we don't enable printing stack traces then we don't get any. -TEST(fatalSignalTest, SIGABRT8_NOPRINT) { - TEST_FATAL_SIGNAL_NO_PRINT(SIGABRT, "SIGABRT", 8); -} -#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) diff --git a/caffe2/utils/filler.h b/caffe2/utils/filler.h deleted file mode 100644 index 3d0e399ba73b..000000000000 --- a/caffe2/utils/filler.h +++ /dev/null @@ -1,140 +0,0 @@ -#ifndef CAFFE2_FILLER_H_ -#define CAFFE2_FILLER_H_ - -#include - -#include "caffe2/core/logging.h" -#include "caffe2/core/tensor.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -// TODO: replace filler distribution enum with a better abstraction -enum FillerDistribution { FD_UNIFORM, FD_FIXEDSUM, FD_SYNTHETIC }; - -class TensorFiller { - public: - template - void Fill(Tensor* tensor, Context* context) const { - CAFFE_ENFORCE(context, "context is null"); - CAFFE_ENFORCE(tensor, "tensor is null"); - auto min = (min_ < (double)std::numeric_limits::min()) - ? std::numeric_limits::min() - : static_cast(min_); - auto max = (max_ > (double)std::numeric_limits::max()) - ? 
std::numeric_limits::max() - : static_cast(max_); - CAFFE_ENFORCE_LE(min, max); - - Tensor temp_tensor(shape_, Context::GetDeviceType()); - std::swap(*tensor, temp_tensor); - Type* data = tensor->template mutable_data(); - - // select distribution - switch (dist_) { - case FD_UNIFORM: { - math::RandUniform( - tensor->numel(), min, max, data, context); - break; - } - case FD_FIXEDSUM: { - auto fixed_sum = static_cast(fixed_sum_); - CAFFE_ENFORCE_LE(min * tensor->numel(), fixed_sum); - CAFFE_ENFORCE_GE(max * tensor->numel(), fixed_sum); - math::RandFixedSum( - tensor->numel(), min, max, fixed_sum_, data, context); - break; - } - case FD_SYNTHETIC: { - math::RandSyntheticData( - tensor->numel(), min, max, data, context); - break; - } - } - } - - TensorFiller& Dist(FillerDistribution dist) { - dist_ = dist; - return *this; - } - - template - TensorFiller& Min(Type min) { - min_ = (double)min; - return *this; - } - - template - TensorFiller& Max(Type max) { - max_ = (double)max; - return *this; - } - - template - TensorFiller& FixedSum(Type fixed_sum) { - dist_ = FD_FIXEDSUM; - fixed_sum_ = (double)fixed_sum; - return *this; - } - - // A helper function to construct the lengths vector for sparse features - // We try to pad least one index per batch unless the total_length is 0 - template - TensorFiller& SparseLengths(Type total_length) { - return FixedSum(total_length) - .Min(std::min(static_cast(1), total_length)) - .Max(total_length); - } - - // a helper function to construct the segments vector for sparse features - template - TensorFiller& SparseSegments(Type max_segment) { - CAFFE_ENFORCE(dist_ != FD_FIXEDSUM); - return Min(0).Max(max_segment).Dist(FD_SYNTHETIC); - } - - TensorFiller& Shape(const std::vector& shape) { - shape_ = shape; - return *this; - } - - template - TensorFiller(const std::vector& shape, Type fixed_sum) - : shape_(shape), dist_(FD_FIXEDSUM), fixed_sum_((double)fixed_sum) {} - - TensorFiller(const std::vector& shape) - : shape_(shape), dist_(FD_UNIFORM), fixed_sum_(0) {} - - TensorFiller() : TensorFiller(std::vector()) {} - - std::string DebugString() const { - std::stringstream stream; - stream << "shape = [" << shape_ << "]; min = " << min_ - << "; max = " << max_; - switch (dist_) { - case FD_FIXEDSUM: - stream << "; dist = FD_FIXEDSUM"; - break; - case FD_SYNTHETIC: - stream << "; dist = FD_SYNTHETIC"; - break; - default: - stream << "; dist = FD_UNIFORM"; - break; - } - return stream.str(); - } - - private: - std::vector shape_; - // TODO: type is unknown until a user starts to fill data; - // cast everything to double for now. 
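TensorFiller above is a small fluent builder: callers chain Shape/Min/Max/Dist (or FixedSum, which also switches the distribution), and no concrete element type is committed until Fill() runs, which is why the bounds are stored as double. A self-contained miniature of that builder style follows; FillSpec and its members are illustrative stand-ins, not the removed caffe2 API:

    #include <cstdint>
    #include <vector>

    // Setters return *this so configuration reads as one chained expression,
    // and bounds stay double until a typed fill is requested.
    class FillSpec {
     public:
      FillSpec& Shape(std::vector<int64_t> shape) { shape_ = std::move(shape); return *this; }
      FillSpec& Min(double min) { min_ = min; return *this; }
      FillSpec& Max(double max) { max_ = max; return *this; }

      template <typename T>
      void FillUniform(std::vector<T>& out) const {
        int64_t numel = 1;
        for (int64_t d : shape_) numel *= d;
        // The midpoint stands in for a real RandUniform draw in [min_, max_].
        out.assign(static_cast<size_t>(numel), static_cast<T>((min_ + max_) / 2));
      }

     private:
      std::vector<int64_t> shape_;
      double min_ = 0.0;
      double max_ = 1.0;
    };

    // Usage mirrors the removed API:
    //   std::vector<float> values;
    //   FillSpec().Shape({5, 10}).Min(0).Max(1).FillUniform(values);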
- double min_ = 0.0; - double max_ = 1.0; - FillerDistribution dist_; - double fixed_sum_; -}; - -} // namespace caffe2 - -#endif // CAFFE2_FILLER_H_ diff --git a/caffe2/utils/fixed_divisor_test.cc b/caffe2/utils/fixed_divisor_test.cc deleted file mode 100644 index 6093bc764c39..000000000000 --- a/caffe2/utils/fixed_divisor_test.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "caffe2/utils/fixed_divisor.h" - -#include - -#include - -namespace caffe2 { - -namespace { - -void CompareDivMod(int32_t v, int32_t divisor) { - auto fixed = FixedDivisor(divisor); - - int native_q = v / divisor; - int native_r = v % divisor; - - int fixed_q = fixed.Div(v); - int fixed_r = fixed.Mod(v); - -#if !defined(USE_ROCM) - EXPECT_EQ(native_q, fixed_q) - << v << " / " << divisor << " magic " << fixed.magic() << " shift " - << fixed.shift() << " quot " << fixed_q << " " << native_q; - - EXPECT_EQ(native_r, fixed_r) - << v << " / " << divisor << " magic " << fixed.magic() << " shift " - << fixed.shift() << " rem " << fixed_r << " " << native_r; -#endif -} - -} // namespace - -TEST(FixedDivisorTest, FixedDivisorInt32Test) { - constexpr int32_t kMax = std::numeric_limits::max(); - - // divide by 1 - CompareDivMod(kMax, 1); - CompareDivMod(0, 1); - CompareDivMod(1, 1); - - // divide by max - CompareDivMod(kMax, kMax); - CompareDivMod(0, kMax); - CompareDivMod(1, kMax); - - // divide by random positive values - std::random_device rd; - std::uniform_int_distribution v_dist(0, kMax); - std::uniform_int_distribution q_dist(1, kMax); - - std::uniform_int_distribution v_small_dist(0, 1000); - std::uniform_int_distribution q_small_dist(1, 1000); - for (int i = 0; i < 10000; ++i) { - auto q = q_dist(rd); - auto v = v_dist(rd); - auto q_small = q_small_dist(rd); - auto v_small = v_small_dist(rd); - - // random value - CompareDivMod(v_small, q_small); - CompareDivMod(v_small, q); - CompareDivMod(v, q_small); - CompareDivMod(v, q); - - // special values - CompareDivMod(kMax, q_small); - CompareDivMod(0, q_small); - CompareDivMod(1, q_small); - CompareDivMod(kMax, q); - CompareDivMod(0, q); - CompareDivMod(1, q); - - CompareDivMod(v_small, 1); - CompareDivMod(v_small, kMax); - CompareDivMod(v, 1); - CompareDivMod(v, kMax); - } -} - -} // namespace caffe2 diff --git a/caffe2/utils/hip/math_blas_gpu_test.cc b/caffe2/utils/hip/math_blas_gpu_test.cc deleted file mode 100644 index 07d4bf11f5a4..000000000000 --- a/caffe2/utils/hip/math_blas_gpu_test.cc +++ /dev/null @@ -1,379 +0,0 @@ -#include -#include "caffe2/core/blob.h" -#include "caffe2/core/context.h" -#include "caffe2/core/hip/context_gpu.h" -#include "caffe2/core/tensor.h" -#include "caffe2/operators/utility_ops.h" -#include "caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math.h" - -namespace caffe2 { - -TEST(MathROCBLASTest, GemmNoTransNoTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobX = ws.CreateBlob("X"); - Blob* blobW = ws.CreateBlob("W"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeX{5, 10}; - vector shapeW{10, 6}; - vector shapeY{5, 6}; - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorW = BlobGetMutableTensor(blobW, HIP); - tensorW->Resize(shapeW); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - 
EXPECT_EQ(tensorX->size(), 50); - EXPECT_EQ(tensorW->size(), 60); - EXPECT_EQ(tensorY->size(), 30); - - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - math::Set( - tensorW->size(), 1, tensorW->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kPointFive, - tensorX->template data(), - tensorW->template data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemmNoTransTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobX = ws.CreateBlob("X"); - Blob* blobW = ws.CreateBlob("W"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeX{5, 10}; - vector shapeW{6, 10}; - vector shapeY{5, 6}; - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorW = BlobGetMutableTensor(blobW, HIP); - tensorW->Resize(shapeW); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorX->size(), 50); - EXPECT_EQ(tensorW->size(), 60); - EXPECT_EQ(tensorY->size(), 30); - - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - math::Set( - tensorW->size(), 1, tensorW->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - tensorX->template data(), - tensorW->template data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kPointFive, - 
tensorX->template data(), - tensorW->template data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - EXPECT_EQ(tensorY_host->size(), 30); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemvNoTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobA = ws.CreateBlob("A"); - Blob* blobX = ws.CreateBlob("X"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeA{5, 10}; - vector shapeX{10}; - vector shapeY{5}; - auto* tensorA = BlobGetMutableTensor(blobA, HIP); - tensorA->Resize(shapeA); - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorA->size(), 50); - EXPECT_EQ(tensorX->size(), 10); - EXPECT_EQ(tensorY->size(), 5); - math::Set( - tensorA->size(), 1, tensorA->mutable_data(), &context); - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kZero, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 10) << i; - } - - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 15) << i; - } - - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kPointFive, - tensorA->data(), - tensorX->data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 20) << i; - } -} - -TEST(MathROCBLASTest, GemvTrans) { - if (!HasHipGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_HIP); - HIPContext context(option); - - Blob* blobA = ws.CreateBlob("A"); - Blob* blobX = ws.CreateBlob("X"); - Blob* blobY = ws.CreateBlob("Y"); - Blob* blobY_host = ws.CreateBlob("Y_host"); - - vector shapeA{6, 10}; - vector shapeX{6}; - vector shapeY{10}; - auto* tensorA = BlobGetMutableTensor(blobA, HIP); - tensorA->Resize(shapeA); - auto* tensorX = BlobGetMutableTensor(blobX, HIP); - tensorX->Resize(shapeX); - auto* tensorY = BlobGetMutableTensor(blobY, HIP); - tensorY->Resize(shapeY); - auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU); - tensorY_host->Resize(shapeY); - - EXPECT_EQ(tensorA->size(), 60); - EXPECT_EQ(tensorX->size(), 6); - EXPECT_EQ(tensorY->size(), 10); - math::Set( - tensorA->size(), 1, tensorA->mutable_data(), &context); - math::Set( - tensorX->size(), 1, tensorX->mutable_data(), &context); - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kZero, - 
tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 6) << i; - } - - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - tensorA->data(), - tensorX->data(), - kPointFive, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 9) << i; - } - - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kPointFive, - tensorA->data(), - tensorX->data(), - kOne, - tensorY->mutable_data(), - &context); - context.FinishDeviceComputation(); - tensorY_host->CopyFrom(*tensorY); - for (int i = 0; i < tensorY_host->size(); ++i) { - TORCH_CHECK_EQ(tensorY_host->data()[i], 12) << i; - } -} -} // namespace caffe2 diff --git a/caffe2/utils/knob_patcher.cc b/caffe2/utils/knob_patcher.cc deleted file mode 100644 index e099ea61dd87..000000000000 --- a/caffe2/utils/knob_patcher.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and its affiliates. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include -#include -#include - -#include "caffe2/utils/knobs.h" -#include "caffe2/utils/knob_patcher.h" - -namespace caffe2 { -namespace detail { -std::map& getRegisteredKnobs(); -} // namespace detail - -namespace { -class PatchNode { - public: - PatchNode(c10::string_view name, bool value); - ~PatchNode(); - - std::string name; - bool oldValue{false}; - // Nodes to form a linked list of existing PatchState objects for this knob. - // This allows us to restore state correctly even if KnobPatcher objects - // are destroyed in any arbitrary order. - PatchNode* prev{nullptr}; - PatchNode* next{nullptr}; -}; -} // namespace - -class KnobPatcher::PatchState : public PatchNode { - using PatchNode::PatchNode; -}; - -KnobPatcher::KnobPatcher(c10::string_view name, bool value) - : state_{std::make_unique(name, value)} {} - -KnobPatcher::~KnobPatcher() = default; -KnobPatcher::KnobPatcher(KnobPatcher&&) noexcept = default; -KnobPatcher& KnobPatcher::operator=(KnobPatcher&&) noexcept = default; - -namespace { - -class Patcher { - public: - void patch(PatchNode* node, bool value) { - std::lock_guard lock{mutex_}; - - node->oldValue = setKnobValue(node->name, value); - auto ret = patches_.emplace(node->name, node); - if (!ret.second) { - // There was already another patcher for this knob - // Append the new node to the linked list. - node->prev = ret.first->second; - CHECK(!node->prev->next); - node->prev->next = node; - ret.first->second = node; - } - } - - void unpatch(PatchNode* node) { - std::lock_guard lock{mutex_}; - - // Remove this PatchNode from the linked list - if (node->prev) { - node->prev->next = node->next; - } - if (node->next) { - // There was another patch applied after this one. - node->next->prev = node->prev; - node->next->oldValue = node->oldValue; - } else { - // This was the most recently applied patch for this knob, - // so restore the knob value. - setKnobValue(node->name, node->oldValue); - - // The patches_ map should point to this node. - // Update it to point to the previous patch, if there is one. 
- auto iter = patches_.find(node->name); - if (iter == patches_.end()) { - LOG(FATAL) << "patch node not found when unpatching knob value"; - } - TORCH_CHECK_EQ(iter->second, node); - if (node->prev) { - iter->second = node->prev; - } else { - patches_.erase(iter); - } - } - } - - private: - bool setKnobValue(c10::string_view name, bool value) { - auto& knobs = caffe2::detail::getRegisteredKnobs(); - auto iter = knobs.find(name); - if (iter == knobs.end()) { - throw std::invalid_argument( - "attempted to patch unknown knob \"" + std::string(name) + "\""); - } - bool oldValue = *(iter->second); - *iter->second = value; - return oldValue; - } - - std::mutex mutex_; - std::map patches_; -}; - -Patcher& getPatcher() { - static Patcher patcher; - return patcher; -} - -PatchNode::PatchNode(c10::string_view knobName, bool value) - : name{knobName} { - getPatcher().patch(this, value); -} - -PatchNode::~PatchNode() { - try { - getPatcher().unpatch(this); - } catch (const std::exception& ex) { - // This shouldn't ever happen unless we have a programming bug, but it keeps - // clang-tidy happy if we put a catch block here to handle the theoretical - // error if unpatch() calls setKnobValue() and it throws due to not finding - // the knob by name. - LOG(FATAL) << "error removing knob patch: " << ex.what(); - } -} - -} // namespace -} // namespace caffe2 diff --git a/caffe2/utils/knob_patcher.h b/caffe2/utils/knob_patcher.h deleted file mode 100644 index ec2b6277760d..000000000000 --- a/caffe2/utils/knob_patcher.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#include - -#include - -namespace caffe2 { - -/** - * Patch the value of a knob during a unit test. - * - * This forces the knob to the specified value for as long as the KnobPatcher - * object exists. When the KnobPatcher object is destroyed the knob will revert - * to its previous value. - */ -class KnobPatcher { - public: - KnobPatcher(c10::string_view name, bool value); - ~KnobPatcher(); - - KnobPatcher(KnobPatcher&&) noexcept; - KnobPatcher& operator=(KnobPatcher&&) noexcept; - KnobPatcher(const KnobPatcher&) = delete; - KnobPatcher& operator=(const KnobPatcher&) = delete; - - private: - class PatchState; - - std::unique_ptr state_; -}; - -} // namespace caffe2 diff --git a/caffe2/utils/knobs.cc b/caffe2/utils/knobs.cc deleted file mode 100644 index 63941a573edf..000000000000 --- a/caffe2/utils/knobs.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// This is a very basic knob implementation that purely uses command line flags. -// This can be replaced with a more sophisticated implementation for use in -// other production environments. - -#include - -#include -#include - -#include "caffe2/utils/knobs.h" - -namespace caffe2 { - -namespace detail { -// Get the map of knob names to pointers to their command-line controlled -// boolean value. -std::map& getRegisteredKnobs() { - // It's safe to store the keys as string_view, since DEFINE_KNOB() ensures - // that these views always point to string literals. 
- static std::map registeredKnobs; - return registeredKnobs; -} -} // namespace detail - -bool CheckKnob(c10::string_view name) { - const auto& knobs = detail::getRegisteredKnobs(); - auto iter = knobs.find(name); - if (iter == knobs.end()) { - throw std::invalid_argument( - "attempted to check unknown knob \"" + std::string(name) + "\""); - } - return *iter->second; -} - -namespace { -class RegisterKnob { - public: - RegisterKnob(c10::string_view name, bool* cmdlineFlag) { - auto ret = caffe2::detail::getRegisteredKnobs().emplace(name, cmdlineFlag); - if (!ret.second) { - throw std::runtime_error("duplicate knob name: " + std::string(name)); - } - } -}; -} // namespace -} // namespace caffe2 - -/** - * Define a knob. - * - * This will define a --caffe2_knob_ command line flag to control the - * knob. - * - * The knob can be checked in code by calling CheckKnob(name) - * or CheckKnob() - */ -#define DEFINE_KNOB(name, check_fn_name, default_value, docstring) \ - C10_DEFINE_bool(caffe2_knob_##name, default_value, docstring); \ - namespace caffe2 { \ - bool CheckKnob##check_fn_name() { \ - return FLAGS_caffe2_knob_##name; \ - } \ - } \ - static caffe2::RegisterKnob _knob_##name(#name, &FLAGS_caffe2_knob_##name) - -/* - * Definitions of well-known knobs. - */ - -DEFINE_KNOB( - example_knob, - ExampleKnob, - false, - "An example knob, mainly intended for use in unit tests"); diff --git a/caffe2/utils/knobs.h b/caffe2/utils/knobs.h deleted file mode 100644 index fbebd90cf741..000000000000 --- a/caffe2/utils/knobs.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -// This file contains functions for checking rollout knobs to enable staged -// roll out of specific code functionality. - -#include - -#include - -namespace caffe2 { - -/** - * Check an arbitrary knob by name. - */ -bool CheckKnob(c10::string_view name); - -/* - * The following are functions for checking specific known knob values. - * - * These APIs are more efficient than checking by name. - */ - -// An example knob, just for use in unit tests. -bool CheckKnobExampleKnob(); - -} // namespace caffe2 diff --git a/caffe2/utils/knobs_test.cc b/caffe2/utils/knobs_test.cc deleted file mode 100644 index 95f29cff2471..000000000000 --- a/caffe2/utils/knobs_test.cc +++ /dev/null @@ -1,34 +0,0 @@ -#include - -#include "caffe2/utils/knobs.h" -#include "caffe2/utils/knob_patcher.h" - -using namespace caffe2; - -TEST(KnobsTest, TestKnob) { - auto p = KnobPatcher("example_knob", false); - EXPECT_FALSE(CheckKnobExampleKnob()); - EXPECT_FALSE(CheckKnob("example_knob")); - - p = KnobPatcher("example_knob", true); - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); - - // Test nested patchers - { - auto p2 = KnobPatcher("example_knob", false); - EXPECT_FALSE(CheckKnobExampleKnob()); - EXPECT_FALSE(CheckKnob("example_knob")); - - auto p3 = KnobPatcher("example_knob", true); - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); - } - EXPECT_TRUE(CheckKnobExampleKnob()); - EXPECT_TRUE(CheckKnob("example_knob")); -} - -TEST(KnobsTest, TestUnknownKnob) { - // Unknown knob names should throw an exception - EXPECT_THROW(CheckKnob("this_knob_does_not_exist"), std::exception); -} diff --git a/caffe2/utils/map_utils.h b/caffe2/utils/map_utils.h deleted file mode 100644 index ef8ff0cab707..000000000000 --- a/caffe2/utils/map_utils.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -namespace caffe2 { - -// Get value from map given key. 
Return supplied default value if not found -// This is a stripped down version from folly: -// https://github.com/facebook/folly/blob/5a07e203d79324b68d69f294fa38e43b9671e9b1/folly/MapUtil.h#L35-L45 -template < - class Map, - typename Key = typename Map::key_type, - typename Value = typename Map::mapped_type> -typename Map::mapped_type -get_default(const Map& map, const Key& key, Value&& dflt) { - using M = typename Map::mapped_type; - auto pos = map.find(key); - return (pos != map.end()) ? (pos->second) : M(std::forward(dflt)); -} - -} // namespace caffe2 diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h deleted file mode 100644 index f2ecc711995a..000000000000 --- a/caffe2/utils/math-detail.h +++ /dev/null @@ -1,90 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_DETAIL_H_ -#define CAFFE2_UTILS_MATH_DETAIL_H_ -namespace caffe2 { - -class CPUContext; - -namespace math { -namespace detail { - -// proxy to a class because of partial specialization limitations for functions - -template -struct ScaleImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - Scale(N, alpha, x, y, context); - } -}; - -// Put light-weight implementations in .h file to enable inlining -template -struct ScaleImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - CPUContext* /*context*/) { - TORCH_DCHECK_EQ(N, 1); - *y = *x * alpha; - } -}; - -template -struct AxpyImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - Axpy(N, alpha, x, y, context); - } -}; - -// Put light-weight implementations in .h file to enable inlining -template -struct AxpyImpl { - inline void operator()( - const int N, - const float alpha, - const T* x, - T* y, - CPUContext* /*context*/) { - TORCH_DCHECK_EQ(N, 1); - *y += *x * alpha; - } -}; - - -} // namespace detail - -template -inline void ScaleFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - detail::ScaleImpl()(N, alpha, x, y, context); -} - -template -inline void AxpyFixedSize( - const int N, - const float alpha, - const T* x, - T* y, - Context* context) { - detail::AxpyImpl()(N, alpha, x, y, context); -} - -} // namespace math -} // namespace caffe2 - -#endif // CAFFE2_UTILS_MATH_DETAIL_H_ diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h deleted file mode 100644 index 6acc50e8e748..000000000000 --- a/caffe2/utils/math.h +++ /dev/null @@ -1,467 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_H_ -#define CAFFE2_UTILS_MATH_H_ -// This is a simple translation from the old Caffe math interfaces. We aim to -// still keep it simple, so all platforms would be able to support it fairly -// easily. - -// We include the cblas header here so that we can obtain the macros from cblas. -extern "C" { -#include "caffe2/utils/cblas.h" -} - -#ifdef CAFFE2_USE_ACCELERATE -#include -#endif // CAFFE2_USE_ACCELERATE - -#include "caffe2/core/common.h" -#include "caffe2/core/types.h" -#include "caffe2/utils/math/broadcast.h" -#include "caffe2/utils/math/elementwise.h" -#include "caffe2/utils/math/reduce.h" -#include "caffe2/utils/math/transpose.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { - -// TODO: Change dims related arguments to int64_t? -class Tensor; - -// An empty class as a placeholder for a math function that has no specific -// engine specified. 
-class TORCH_API DefaultEngine {}; - -namespace math { - -#define C10_DECLARE_COMPARE_OP(Comp) \ - template \ - void Rowwise##Comp( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); \ - \ - template \ - void Colwise##Comp( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); \ - \ - template \ - void Comp( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - bool* C, \ - Context* context); - -C10_DECLARE_COMPARE_OP(EQ) -C10_DECLARE_COMPARE_OP(NE) -C10_DECLARE_COMPARE_OP(LT) -C10_DECLARE_COMPARE_OP(LE) -C10_DECLARE_COMPARE_OP(GT) -C10_DECLARE_COMPARE_OP(GE) - -#undef C10_DECLARE_COMPARE_OP - -#define C10_DECLARE_BINARY_OP(Func) \ - template \ - void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Colwise##Func( \ - const int rows, \ - const int cols, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); \ - \ - template \ - void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const T* A, \ - const T* B, \ - T* C, \ - Context* context); - -C10_DECLARE_BINARY_OP(Add) -C10_DECLARE_BINARY_OP(Sub) -C10_DECLARE_BINARY_OP(Mul) -C10_DECLARE_BINARY_OP(Div) - -C10_DECLARE_BINARY_OP(And) -C10_DECLARE_BINARY_OP(Or) -C10_DECLARE_BINARY_OP(Xor) - -C10_DECLARE_BINARY_OP(BitwiseAnd) -C10_DECLARE_BINARY_OP(BitwiseOr) -C10_DECLARE_BINARY_OP(BitwiseXor) - -#undef C10_DECLARE_BINARY_OP - -// Broadcasts X with X_dims to Y with Y_dims. -template -TORCH_API void Broadcast( - const int X_ndim, - const int* X_dims, - const int Y_ndim, - const int* Y_dims, - const T alpha, - const T* X, - T* Y, - Context* context, - bool allow_broadcast_fastpath=false); - -// Computes inv_std from variance. -template -TORCH_API void InvStd( - const int N, - const T epsilon, - const T* var, - T* inv_std, - Context* context); - -// Adds batch sub-tensors elementwise to output. Stripe is the stripe length -// and N is the number of elements to add (size of Y). -template -TORCH_API void AddStripedBatch( - const int N, - const T* first, - T* y, - const int stripe, - const int batch, - Context* context); - -// Compute the row-wise max of a N*D matrix X, and write it to a N -// dimensional vector y. -template -TORCH_API void -RowwiseMax(const int N, const int D, const T* x, T* y, Context* context); - -// Compute the column-wise max of a N*D matrix X, and write it to a D -// dimensional vector y. -template -TORCH_API void -ColwiseMax(const int N, const int D, const T* x, T* y, Context* context); - -// Elemwise maximum of vector x and scalar alpha. y[i] = max(x[i], alpha) -template -TORCH_API void -Maximum(const int N, const float alpha, const T* x, T* y, Context* context); - -// Decaf gemm provides a simpler interface to the gemm functions, with the -// limitation that the data has to be contiguous in memory. -template -TORCH_API void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const T* A, - const T* B, - const float beta, - T* C, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -// We also provide a gemm that has explicit lda, ldb and ldc specified. -// In most cases you probably want to use the function above, though. 
-template -TORCH_API void GemmEx( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc, - Context* context); - -// GemmBatched provides a simple abstraction into library routines -template -TORCH_API void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const T** A, - const T** B, - const float beta, - T** C, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -template -TORCH_API void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const T* A, - const int A_stride, - const T* B, - const int B_stride, - const float beta, - T* C, - const int C_stride, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -// Gemv always takes in a M*N matrix A, and depending on whether we set TransA -// to Trans, the output is: -// CblasNoTrans: x is an N dim vector and y is an M dim vector. -// CblasTrans: x is an M dim vector and y is an N dim vector. -template -TORCH_API void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const T* A, - const T* x, - const float beta, - T* y, - Context* context, - TensorProto::DataType math_type = TensorProto_DataType_FLOAT); - -template -TORCH_API void -RandUniform(const size_t n, const T a, const T b, T* r, Context* context); - -// Generate n values that sum up to a fixed sum -// and subject to a restriction a <= x <= b for each x generated -template -TORCH_API void RandFixedSum( - const size_t n, - const T a, - const T b, - const T sum, - T* r, - Context* context); - -template -TORCH_API void RandUniformUnique( - const size_t n, - const T a, - const T b, - T* r, - const size_t m, - const T* avoid, - Context* context); - -// Generate n values from synthetic data distribution, -// define by unique accesses and stack distances -template -TORCH_API void -RandSyntheticData(const size_t n, const T a, const T b, T* r, Context* context); - -template -TORCH_API void -RandGaussian(const size_t n, const T mean, const T std, T* r, Context* context); - -// Dot matrix of vector a and b, and writes the result to a single value y. -template -TORCH_API void -Dot(const int N, const T* a, const T* b, T* y, Context* context); - -// Sum of vector x, and writes the result to a single value y. -template -TORCH_API void Sum( - const int N, - const T* x, - T* y, - Context* context, - Tensor* scratch_ptr = nullptr); - -// Sum of squares of vector x, and writes the result to a single value y. -template -TORCH_API void SumSqr( - const int N, - const T* x, - T* y, - Context* context, - Tensor* scratch_ptr = nullptr); - -// Select does index selection of the rows a N*D matrix x, and gives the N -// dimensional vector y that contains the selected data. -template -TORCH_API void Select( - const int N, - const int D, - const T* x, - const int* idx, - T* y, - Context* context); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. 
-// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const T* img_data, - T* col_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -template -TORCH_API void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const T* img_data, - T* col_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. -// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Col2Im( - const int channels, - const int height, - const int width, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const T* col_data, - T* img_data, - Context* context, - const int groups = 1); - -// groups must be 1 for GPU -// For NHWC order with groups > 1, the result will be layout in -// NHW G RS C/G order to make data within the same group to be contiguous. -// For NCHW order, groups doesn't make any difference because we're doing Im2Col -// for each N and C is the slowest moving dimension among CHW. -template -TORCH_API void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const T* col_data, - T* img_data, - Context* context, - const int groups = 1); - -// Applies a per-channel bias value to each channel of the input -// image. 
image_size is H * W -template -TORCH_API void BiasCHW( - const T* bias, - const T* bias_multiplier, - const int bias_channels, - const int image_size, - T* image, - Context* context); - -template -TORCH_API void CopyMatrix( - const size_t item_size, - const int M, - const int N, - const void* A, - const int lda, - void* B, - const int ldb, - Context* context, - TypeMeta::Copy copy = nullptr); - -template -TORCH_API void CopyMatrix( - const int M, - const int N, - const T* A, - const int lda, - T* B, - const int ldb, - Context* context); - -template -TORCH_API void CopyMatrix( - const int M, - const int N, - const T* A, - const int A_outer_stride, - const int A_inner_stride, - T* B, - const int B_outer_stride, - const int B_inner_stride, - Context* context); - -template -TORCH_API void CopyVector(const int N, const T* A, T* B, Context* context); - -} // namespace math -} // namespace caffe2 - -#include "caffe2/utils/math-detail.h" -#endif // CAFFE2_UTILS_MATH_H_ diff --git a/caffe2/utils/math/broadcast.cu b/caffe2/utils/math/broadcast.cu deleted file mode 100644 index 8c0c57951926..000000000000 --- a/caffe2/utils/math/broadcast.cu +++ /dev/null @@ -1,110 +0,0 @@ -#include "caffe2/utils/math/broadcast.h" - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNCHWCUDAKernel( - const int C, - const int M, - const int HxW, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int nc = blockIdx.x / M; - const int c = nc % C; - const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (w < HxW) { - const int index = nc * HxW + w; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif - } -} - -template -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const T* X, - const T* scale, - const T* bias, - T* Y); - -template <> -__global__ void AffineChannelNHWCCUDAKernel( - const int C, - const float* X, - const float* scale, - const float* bias, - float* Y) { - const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (c < C) { - const int index = blockIdx.x * C + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[index] = fmaf(__ldg(X + index), __ldg(scale + c), __ldg(bias + c)); -#else - Y[index] = fmaf(X[index], scale[c], bias[c]); -#endif - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNCHWCUDAKernel \ - <<cuda_stream()>>>( \ - C, M, HxW, X, scale, bias, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void AffineChannel( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - const T* scale, \ - const T* bias, \ - T* Y, \ - CUDAContext* context) { \ - const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \ - AffineChannelNHWCCUDAKernel \ - <<cuda_stream()>>>(C, X, scale, bias, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL(float) 
-#undef CAFFE2_SPECIALIZED_CUDA_AFFINE_CHANNEL - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu deleted file mode 100644 index d1911ae4db4c..000000000000 --- a/caffe2/utils/math/elementwise.cu +++ /dev/null @@ -1,918 +0,0 @@ -#include "caffe2/utils/math/elementwise.h" - -#include - -#include -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math/half_utils.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void SinCosCUDAKernel(const int N, const T* X, T* S, T* C) { - const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < N) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - c10::cuda::compat::sincos(__ldg(X + i), S + i, C + i); -#else - c10::cuda::compat::sincos(X[i], S + i, C + i); -#endif - } -} - -#if defined(USE_ROCM) - -template -__global__ void AxpyCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - TData* Y) { - const int64_t index = static_cast(blockIdx.x) * - static_cast(CAFFE_CUDA_NUM_THREADS) + - static_cast(threadIdx.x); - if (index < N) { - Y[index] += static_cast(alpha) * __ldg(X + index); - } -} - -template -__global__ void AxpyCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - TData* Y) { - __shared__ TData a; - if (threadIdx.x == 0) { - a = static_cast(__ldg(alpha)); - } - __syncthreads(); - const int64_t index = static_cast(blockIdx.x) * - static_cast(CAFFE_CUDA_NUM_THREADS) + - static_cast(threadIdx.x); - if (index < N) { - Y[index] += a * __ldg(X + index); - } -} - -#define DELEGATE_HALF_AXPY_CUDA_KERNEL(TAlpha, FMAFunc) \ - template <> \ - __global__ void AxpyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - alpha, \ - convert::To(X[index]), \ - convert::To(Y[index]))); \ - } \ - } \ - template <> \ - __global__ void AxpyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - if (threadIdx.x == 0) { \ - a = __ldg(alpha); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - a, \ - convert::To(X[index]), \ - convert::To(Y[index]))); \ - } \ - } -DELEGATE_HALF_AXPY_CUDA_KERNEL(float, fmaf) -#undef DELEGATE_HALF_AXPY_CUDA_KERNEL - -#endif // USE_ROCM - -template -__global__ void AxpbyCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - const TAlpha beta, - TData* Y); - -template -__global__ void AxpbyCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - const TAlpha* beta, - TData* Y); - -#define DELEGATE_AXPBY_CUDA_KERNEL(TAlpha, TData, FMAFunc) \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - const TAlpha beta, \ - TData* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = FMAFunc( \ - static_cast(alpha), \ - X[index], \ - static_cast(beta) * Y[index]); \ - } \ - } \ - 
template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - const TAlpha* beta, \ - TData* Y) { \ - __shared__ TData a; \ - __shared__ TData b; \ - if (threadIdx.x == 0) { \ - a = static_cast(*alpha); \ - b = static_cast(*beta); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = FMAFunc(a, X[index], b * Y[index]); \ - } \ - } -DELEGATE_AXPBY_CUDA_KERNEL(float, float, fmaf) -DELEGATE_AXPBY_CUDA_KERNEL(float, double, fma) -#undef DELEGATE_AXPBY_CUDA_KERNEL - -#define DELEGATE_HALF_AXPBY_CUDA_KERNEL(TAlpha, FMAFunc) \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - const TAlpha beta, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - alpha, \ - convert::To(X[index]), \ - beta * convert::To(Y[index]))); \ - } \ - } \ - template <> \ - __global__ void AxpbyCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - const TAlpha* beta, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - __shared__ TAlpha b; \ - if (threadIdx.x == 0) { \ - a = *alpha; \ - b = *beta; \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To(FMAFunc( \ - a, \ - convert::To(X[index]), \ - b * convert::To(Y[index]))); \ - } \ - } -DELEGATE_HALF_AXPBY_CUDA_KERNEL(float, fmaf) -#undef DELEGATE_HALF_AXPBY_CUDA_KERNEL - -template -__global__ void ScaleCUDAKernel( - const std::int64_t N, - const TAlpha alpha, - const TData* X, - TData* Y); - -template -__global__ void ScaleCUDAKernel( - const std::int64_t N, - const TAlpha* alpha, - const TData* X, - TData* Y); - -#define CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(TAlpha, TData) \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, const TAlpha alpha, const TData* X, TData* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = static_cast(alpha) * X[index]; \ - } \ - } \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, const TAlpha* alpha, const TData* X, TData* Y) { \ - __shared__ TData a; \ - if (threadIdx.x == 0) { \ - a = static_cast(*alpha); \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = a * X[index]; \ - } \ - } -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(float, float) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(double, double) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(float, double) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL(std::int64_t, std::int64_t) -#undef CAFFE2_SPECIALIZED_SCALE_CUDA_KERNEL - -#define CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(TAlpha) \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { 
\ - Y[index] = convert::To( \ - alpha * convert::To(X[index])); \ - } \ - } \ - template <> \ - __global__ void ScaleCUDAKernel( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const at::Half* X, \ - at::Half* Y) { \ - __shared__ TAlpha a; \ - if (threadIdx.x == 0) { \ - a = *alpha; \ - } \ - __syncthreads(); \ - const int64_t index = static_cast(blockIdx.x) * \ - static_cast(CAFFE_CUDA_NUM_THREADS) + \ - static_cast(threadIdx.x); \ - if (index < N) { \ - Y[index] = convert::To( \ - a * convert::To(X[index])); \ - } \ - } -CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(float) -#undef CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_SET(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Set( \ - const std::int64_t N, const T alpha, T* Y, CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (alpha == T(0)) { \ - C10_CUDA_CHECK(cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream())); \ - } else { \ - thrust::fill( \ - thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SET(bool) -CAFFE2_SPECIALIZED_CUDA_SET(char) -CAFFE2_SPECIALIZED_CUDA_SET(std::int8_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int16_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::uint8_t) -CAFFE2_SPECIALIZED_CUDA_SET(std::uint16_t) -CAFFE2_SPECIALIZED_CUDA_SET(float) -CAFFE2_SPECIALIZED_CUDA_SET(double) -CAFFE2_SPECIALIZED_CUDA_SET(at::Half) -CAFFE2_SPECIALIZED_CUDA_SET(at::BFloat16) -#undef CAFFE2_SPECIALIZED_CUDA_SET - -#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(T, Func, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* X, T* Y, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - X, \ - X + N, \ - Y, \ - [] __device__(const T x) { return DeviceFunc(x); }); \ - } \ - } -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log1p, log1pf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sin, sinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Asin, asinf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cos, cosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Acos, acosf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tan, tanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Atan, atanf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sinh, sinhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cosh, coshf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Tanh, tanhf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Abs, fabsf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Inv, utils::Inv) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Inv, utils::Inv) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, utils::Square) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Cube, - utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, CdfNorm, normcdff) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, CdfNorm, 
normcdf) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Neg, - utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Neg, utils::Negate) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int32_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION( - std::int64_t, - Sign, - utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sign, utils::Sign) -DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sign, utils::Sign) -#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION - -#define DELEGATE_CUDA_POWX(T, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Powx( \ - const int N, const T* A, const T b, T* Y, CUDAContext* context) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - Y, \ - [b] __device__(const T x) { return DeviceFunc(x, b); }); \ - } -DELEGATE_CUDA_POWX(float, powf) -#undef DELEGATE_CUDA_POWX - -#define CAFFE2_SPECIALIZED_CUDA_SINCOS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SinCos( \ - const int N, const T* X, T* S, T* C, CUDAContext* context) { \ - if (N > 0) { \ - const int K = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - SinCosCUDAKernel \ - <<cuda_stream()>>>( \ - N, X, S, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SINCOS(float) -CAFFE2_SPECIALIZED_CUDA_SINCOS(double) -#undef CAFFE2_SPECIALIZED_CUDA_SINCOS - -#define DELEGATE_CUDA_SCALE(T, CuBLASFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, &alpha, Y, 1)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const T* alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(CuBLASFunc(context->cublas_handle(), N, alpha, Y, 1)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_CUDA_SCALE(float, cublasSscal) -DELEGATE_CUDA_SCALE(double, cublasDscal) -#undef DELEGATE_CUDA_SCALE - -#if !defined(USE_ROCM) - -#define DELEGATE_CUDA_SCALE_EX( \ - TAlpha, TData, kAlphaType, kDataType, kExecutionType) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(cublasScalEx( \ - context->cublas_handle(), \ - N, \ - &alpha, \ - kAlphaType, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - 
C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 0) { \ - return; \ - } \ - if (Y == X) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(cublasScalEx( \ - context->cublas_handle(), \ - N, \ - alpha, \ - kAlphaType, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } else { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_CUDA_SCALE_EX(float, double, CUDA_R_32F, CUDA_R_64F, CUDA_R_64F) -DELEGATE_CUDA_SCALE_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) -#undef DELEGATE_CUDA_SCALE_EX - -#endif // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_SCALE(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N > 0) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Scale( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N > 0) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - ScaleCUDAKernel \ - <<cuda_stream()>>>( \ - N, *alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_SCALE(std::int64_t, std::int64_t) - -#if defined(USE_ROCM) -CAFFE2_SPECIALIZED_CUDA_SCALE(float, double) -CAFFE2_SPECIALIZED_CUDA_SCALE(float, at::Half) -#endif // USE_ROCM -#undef CAFFE2_SPECIALIZED_CUDA_SCALE - -#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Func, DeviceFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* A, const T* B, T* C, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - B, \ - C, \ - DeviceFunc); \ - } \ - } -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Add, - thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Add, - thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Add, thrust::plus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Add, utils::HalfAddFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Sub, - thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Sub, - thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Sub, thrust::minus()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Sub, utils::HalfSubFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Mul, - thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Mul, - thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Mul, thrust::multiplies()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Mul, utils::HalfMulFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - Div, - thrust::divides()) 
-DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - Div, - thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Div, thrust::divides()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(at::Half, Div, utils::HalfDivFunctor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Min, thrust::minimum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Min, thrust::minimum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Max, thrust::maximum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Max, thrust::maximum()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, And, thrust::logical_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Or, thrust::logical_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, Xor, thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseAnd, thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseAnd, - thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseAnd, - thrust::bit_and()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseOr, thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseOr, - thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseOr, - thrust::bit_or()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(bool, BitwiseXor, thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int32_t, - BitwiseXor, - thrust::bit_xor()) -DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION( - std::int64_t, - BitwiseXor, - thrust::bit_xor()) -#undef DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION - -#define DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(T, Func, DeviceComp) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int N, const T* A, const T* B, bool* C, CUDAContext* context) { \ - if (N > 0) { \ - thrust::transform( \ - thrust::cuda::par.on(context->cuda_stream()), \ - A, \ - A + N, \ - B, \ - C, \ - DeviceComp); \ - } \ - } -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - EQ, - thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - EQ, - thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, EQ, thrust::equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, NE, thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, NE, thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - double, - NE, - thrust::not_equal_to()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - LT, - thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - LT, - thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LT, thrust::less()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - LE, - thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - LE, - thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, LE, thrust::less_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GT, thrust::greater()) 
-DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - GT, - thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - GT, - thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GT, thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(double, GT, thrust::greater()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(bool, GE, thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int32_t, - GE, - thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - std::int64_t, - GE, - thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION(float, GE, thrust::greater_equal()) -DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION( - double, - GE, - thrust::greater_equal()) -#undef DELEGATE_SIMPLE_CUDA_COMPARE_FUNCTION - -#define DELEGATE_CUDA_AXPY(T, CuBLASFunc) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE( \ - CuBLASFunc(context->cublas_handle(), N, &alpha, X, 1, Y, 1)); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const T* alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE( \ - cublasSaxpy(context->cublas_handle(), N, alpha, X, 1, Y, 1)); \ - } -DELEGATE_CUDA_AXPY(float, cublasSaxpy) -#undef DELEGATE_CUDA_AXPY - -#if !defined(USE_ROCM) - -#define DELEGATE_CUDA_AXPY_EX( \ - TAlpha, TData, kAlphaType, kDataType, kExecutionType) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(cublasAxpyEx( \ - context->cublas_handle(), \ - N, \ - &alpha, \ - kAlphaType, \ - X, \ - kDataType, \ - 1, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); \ - CUBLAS_ENFORCE(cublasAxpyEx( \ - context->cublas_handle(), \ - N, \ - alpha, \ - kAlphaType, \ - X, \ - kDataType, \ - 1, \ - Y, \ - kDataType, \ - 1, \ - kExecutionType)); \ - } -DELEGATE_CUDA_AXPY_EX(float, double, CUDA_R_32F, CUDA_R_64F, CUDA_R_64F) -DELEGATE_CUDA_AXPY_EX(float, at::Half, CUDA_R_32F, CUDA_R_16F, CUDA_R_32F) -#undef DELEGATE_CUDA_AXPY_EX - -#else // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_AXPY(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpy( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AXPY(float, double) 
-CAFFE2_SPECIALIZED_CUDA_AXPY(float, at::Half) -#undef CAFFE2_SPECIALIZED_CUDA_AXPY - -#endif // USE_ROCM - -#define CAFFE2_SPECIALIZED_CUDA_AXPBY(TAlpha, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpby( \ - const std::int64_t N, \ - const TAlpha alpha, \ - const TData* X, \ - const TAlpha beta, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpbyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, beta, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Axpby( \ - const std::int64_t N, \ - const TAlpha* alpha, \ - const TData* X, \ - const TAlpha* beta, \ - TData* Y, \ - CUDAContext* context) { \ - const std::int64_t M = DivUp(N, CAFFE_CUDA_NUM_THREADS); \ - AxpbyCUDAKernel \ - <<cuda_stream()>>>( \ - N, alpha, X, beta, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, float) -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, double) -CAFFE2_SPECIALIZED_CUDA_AXPBY(float, at::Half) -#undef CAFFE2_SPECIALIZED_CUDA_AXPBY - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu deleted file mode 100644 index d59cbd387753..000000000000 --- a/caffe2/utils/math/reduce.cu +++ /dev/null @@ -1,593 +0,0 @@ -#include "caffe2/utils/math/reduce.h" - -#include -#include -#include -#include -#include -#include "caffe2/utils/cub_namespace.cuh" -#include - -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/elementwise.h" -#include "caffe2/utils/math/reduce.cuh" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -template -__global__ void RowwiseReduceCUDAKernel( - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int r = blockIdx.x; - T val = init; - for (int c = threadIdx.x; c < cols; c += blockDim.x) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + r * cols + c)); -#else - val = reducer(val, X[r * cols + c]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[r] = val * alpha; - } -} - -template -__global__ void ColwiseReduceCUDAKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int c = blockIdx.x; - T val = init; - for (int r = threadIdx.x; r < rows; r += blockDim.x) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + r * cols + c)); -#else - val = reducer(val, X[r * cols + c]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[c] = val * alpha; - } -} - -template -__global__ void BothEndsReduceCUDAKernel( - const int M, - const int N, - const int K, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce2D::TempStorage - temp_storage; - const int n = blockIdx.x; - T val = init; - for (int m = threadIdx.x; m < M; m += blockDim.x) { - for (int k = threadIdx.y; k < K; k += blockDim.y) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + (m * N + n) * K + k)); -#else - val = reducer(val, X[(m * N + n) * K + k]); -#endif - } - } - val = BlockReduce2D(temp_storage) - .Reduce(val, reducer); - if (threadIdx.x == 0 && threadIdx.y == 0) { - Y[n] 
= val * alpha; - } -} - -template -__global__ void ReduceTensorCUDAKernel( - const int inner_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - const int x = blockIdx.x; - T val = init; - for (int y = threadIdx.x; y < inner_size; y += blockDim.x) { - int X_index = 0; - int Y_index = x * inner_size + y; -#pragma unroll - for (int d = D - 1; d >= 0; --d) { - X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; - Y_index /= Y_dims.data[d]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val = reducer(val, __ldg(X + X_index)); -#else - val = reducer(val, X[X_index]); -#endif - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[x] = val * alpha; - } -} - -template -void ReduceTensorCUDAImpl( - const int outer_size, - const int inner_size, - const int* dims, - const int* axes, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - } - ReduceTensorCUDAKernel - <<cuda_stream()>>>( - inner_size, X_strides, Y_dims, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void ReduceTensorCUDA( - const int ndim, - const int* X_dims, - const int* Y_dims, - const Reducer& reducer, - const T init, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims)); - const int X_size = - std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies()); - const int Y_size = - std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies()); - if (X_size == 0) { - Set(Y_size, init * alpha, Y, context); - return; - } - if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - Scale(X_size, alpha, X, Y, context); - return; - } - int rows; - int cols; - if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - RowwiseReduceCUDAKernel - <<cuda_stream()>>>( - cols, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - ColwiseReduceCUDAKernel - <<cuda_stream()>>>( - rows, cols, reducer, init, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int M; - int N; - int K; - if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( - K, - BothEndsReduceCUDAKernel, - T, - Reducer, - N, - context->cuda_stream(), - M, - N, - K, - reducer, - init, - alpha, - X, - Y); - return; - } - std::vector axes(ndim); - utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); - const int outer_size = Y_size; - const int inner_size = X_size / Y_size; - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( - ndim, - ReduceTensorCUDAImpl, - T, - Reducer, - outer_size, - inner_size, - X_dims, - axes.data(), - reducer, - init, - alpha, - X, - Y, - context); -} - -template -__global__ void -RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(cols); - const int r = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int c = threadIdx.x; c < cols; c += blockDim.x) { - const int X_index = r * cols + c; -#if 
__CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[r] = mu; - var[r] = v_val * scale - mu * mu; - } -} - -template -__global__ void ColwiseMomentsCUDAKernel( - const int rows, - const int cols, - const T* X, - T* mean, - T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(rows); - const int c = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int r = threadIdx.x; r < rows; r += blockDim.x) { - const int X_index = r * cols + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[c] = mu; - var[c] = v_val * scale - mu * mu; - } -} - -template -__global__ void BothEndsMomentsCUDAKernel( - const int M, - const int N, - const int K, - const T* X, - T* mean, - T* var) { - __shared__ - typename BlockReduce2D::TempStorage m_storage; - __shared__ - typename BlockReduce2D::TempStorage v_storage; - const T scale = T(1) / static_cast(M * K); - const int n = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int m = threadIdx.x; m < M; m += blockDim.x) { - for (int k = threadIdx.y; k < K; k += blockDim.y) { - const int X_index = (m * N + n) * K + k; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - } - m_val = BlockReduce2D(m_storage).Sum(m_val); - v_val = BlockReduce2D(v_storage).Sum(v_val); - if (threadIdx.x == 0 && threadIdx.y == 0) { - const T mu = m_val * scale; - mean[n] = mu; - var[n] = v_val * scale - mu * mu; - } -} - -template -__global__ void MomentsCUDAKernel( - const int inner_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const T* X, - T* mean, - T* var) { - __shared__ typename BlockReduce::TempStorage m_storage; - __shared__ typename BlockReduce::TempStorage v_storage; - const T scale = T(1) / static_cast(inner_size); - const int x = blockIdx.x; - T m_val = 0; - T v_val = 0; - for (int y = threadIdx.x; y < inner_size; y += blockDim.x) { - int X_index = 0; - int Y_index = x * inner_size + y; -#pragma unroll - for (int d = D - 1; d >= 0; --d) { - X_index += Y_index % Y_dims.data[d] * X_strides.data[d]; - Y_index /= Y_dims.data[d]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - m_val += __ldg(X + X_index); - v_val += __ldg(X + X_index) * __ldg(X + X_index); -#else - m_val += X[X_index]; - v_val += X[X_index] * X[X_index]; -#endif - } - m_val = BlockReduce(m_storage).Sum(m_val); - v_val = BlockReduce(v_storage).Sum(v_val); - if (threadIdx.x == 0) { - const T mu = m_val * scale; - mean[x] = mu; - var[x] = v_val * scale - mu * mu; - } -} - -template -void MomentsCUDAImpl( - const int outer_size, - const int inner_size, - const int* dims, - const int* axes, - const T* X, - T* mean, - T* var, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, 
dims, axes, X_strides.data); - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - } - MomentsCUDAKernel - <<cuda_stream()>>>( - inner_size, X_strides, Y_dims, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void MomentsCUDA( - const int ndim, - const int* X_dims, - const int* Y_dims, - const T* X, - T* mean, - T* var, - CUDAContext* context) { - CAFFE_ENFORCE(utils::CheckReduceDims(ndim, X_dims, Y_dims)); - const int X_size = - std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies()); - const int Y_size = - std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies()); - if (X_size == 0) { - Set(Y_size, T(0), mean, context); - Set(Y_size, T(0), var, context); - return; - } - if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - C10_CUDA_CHECK(cudaMemcpyAsync( - mean, - X, - sizeof(T) * X_size, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - Set(Y_size, T(0), var, context); - return; - } - int rows; - int cols; - if (utils::IsRowwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - RowwiseMomentsCUDAKernel - <<cuda_stream()>>>( - cols, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - if (utils::IsColwiseReduce(ndim, X_dims, Y_dims, &rows, &cols)) { - ColwiseMomentsCUDAKernel - <<cuda_stream()>>>( - rows, cols, X, mean, var); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int M; - int N; - int K; - if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) { - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - K, - BothEndsMomentsCUDAKernel, - T, - N, - context->cuda_stream(), - M, - N, - K, - X, - mean, - var); - return; - } - std::vector axes(ndim); - utils::ComputeTransposeAxesForReduceOp(ndim, Y_dims, axes.data()); - const int outer_size = Y_size; - const int inner_size = X_size / Y_size; - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - ndim, - MomentsCUDAImpl, - T, - outer_size, - inner_size, - X_dims, - axes.data(), - X, - mean, - var, - context); -} - -} // namespace - -#define DELEGATE_CUDA_REDUCE_FUNCTION(T, Func, Reducer, kInit) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - ReduceTensorCUDA( \ - ndim, X_dims, Y_dims, Reducer(), kInit, alpha, X, Y, context); \ - } -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int32_t, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int64_t, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - float, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - double, - ReduceMin, - cub::Min, - std::numeric_limits::max()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int32_t, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - std::int64_t, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - float, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION( - double, - ReduceMax, - cub::Max, - std::numeric_limits::lowest()) -DELEGATE_CUDA_REDUCE_FUNCTION(std::int32_t, ReduceSum, cub::Sum, 0) -DELEGATE_CUDA_REDUCE_FUNCTION(std::int64_t, ReduceSum, cub::Sum, 0LL) -DELEGATE_CUDA_REDUCE_FUNCTION(float, ReduceSum, cub::Sum, 0.0f) -DELEGATE_CUDA_REDUCE_FUNCTION(double, ReduceSum, cub::Sum, 0.0) -#undef DELEGATE_CUDA_REDUCE_FUNCTION - -#define CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ReduceMean( \ - 
const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - int scale = 1; \ - for (int i = 0; i < ndim; ++i) { \ - if (Y_dims[i] == 1) { \ - scale *= X_dims[i]; \ - } \ - } \ - ReduceTensorCUDA( \ - ndim, \ - X_dims, \ - Y_dims, \ - cub::Sum(), \ - T(0), \ - alpha / static_cast(scale), \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN(float) -#undef CAFFE2_SPECIALIZED_CUDA_REDUCE_MEAN - -#define CAFFE2_SPECIALIZED_CUDA_MOMENTS(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Moments( \ - const int ndim, \ - const int* X_dims, \ - const int* Y_dims, \ - const T* X, \ - T* mean, \ - T* var, \ - CUDAContext* context, \ - bool) { \ - MomentsCUDA(ndim, X_dims, Y_dims, X, mean, var, context); \ - } -CAFFE2_SPECIALIZED_CUDA_MOMENTS(float) -CAFFE2_SPECIALIZED_CUDA_MOMENTS(double) -#undef CAFFE2_SPECIALIZED_CUDA_MOMENTS - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh deleted file mode 100644 index 18bdca11b9de..000000000000 --- a/caffe2/utils/math/reduce.cuh +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef CAFFE2_UTILS_MATH_REDUCE_CUH_ -#define CAFFE2_UTILS_MATH_REDUCE_CUH_ - -#include "caffe2/utils/cub_namespace.cuh" -#include - -#include "caffe2/core/common_gpu.h" - -namespace caffe2 { - -template -using BlockReduce = cub::BlockReduce; - -template -using BlockReduce2D = cub:: - BlockReduce; - -#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( \ - size, Func, T, grid_dim, cuda_stream, ...) \ - do { \ - if (size >= 128) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 64) { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 32) { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - Func<<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } while (false) - -#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_2( \ - size, Func, T1, T2, grid_dim, cuda_stream, ...) \ - do { \ - if (size >= 128) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 64) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else if (size >= 32) { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - Func \ - <<>>(__VA_ARGS__); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } while (false) - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_MATH_REDUCE_CUH_ diff --git a/caffe2/utils/math/transpose.cu b/caffe2/utils/math/transpose.cu deleted file mode 100644 index c3e213190856..000000000000 --- a/caffe2/utils/math/transpose.cu +++ /dev/null @@ -1,233 +0,0 @@ -#include "caffe2/utils/math/transpose.h" - -#include -#include -#include - -#include "caffe2/core/common_gpu.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math/utils.h" - -namespace caffe2 { -namespace math { - -namespace { - -constexpr int kTileDim = 32; -constexpr int kBlockRows = 8; - -// Splits the original matrix into submatrices with size 32 * 32. -// Each block transposes one submatrix by loading it into shared memory. 
-// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/ -template -__global__ void BatchTranspose2DCUDAKernel( - const TIndex H, - const TIndex W, - const TIndex dh, - const TIndex dw, - const TData* X, - TData* Y) { - __shared__ TData tile[kTileDim][kTileDim + 1]; - const TIndex n = blockIdx.x / (dh * dw); - const TIndex k = blockIdx.x % (dh * dw); - const TIndex r = k / dw; - const TIndex c = k % dw; - const TIndex offset = n * H * W; - int x = c * kTileDim + threadIdx.x; - int y = r * kTileDim + threadIdx.y; - if (x < W) { - for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) { -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - tile[threadIdx.y + i][threadIdx.x] = __ldg(X + offset + (y + i) * W + x); -#else - tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x]; -#endif - } - } - __syncthreads(); - x = r * kTileDim + threadIdx.x; - y = c * kTileDim + threadIdx.y; - if (x < H) { - for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) { - Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i]; - } - } -} - -template -void BatchTranspose2DCUDAImpl( - const TIndex N, - const TIndex H, - const TIndex W, - const TData* X, - TData* Y, - CUDAContext* context) { - const TIndex dh = DivUp(H, kTileDim); - const TIndex dw = DivUp(W, kTileDim); - BatchTranspose2DCUDAKernel - <<cuda_stream()>>>( - H, W, dh, dw, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -#define DELEGATE_TRANSPOSE_2D_CUDA_IMPL(TIndex, TData, CuBLASFunc) \ - template <> \ - void BatchTranspose2DCUDAImpl( \ - const TIndex N, \ - const TIndex H, \ - const TIndex W, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - if (N == 1) { \ - const TData kAlpha = TData(1); \ - const TData kBeta = TData(0); \ - CUBLAS_ENFORCE(cublasSetPointerMode( \ - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); \ - CUBLAS_ENFORCE(CuBLASFunc( \ - context->cublas_handle(), \ - CUBLAS_OP_T, \ - CUBLAS_OP_N, \ - H, \ - W, \ - &kAlpha, \ - X, \ - W, \ - &kBeta, \ - Y, \ - H, \ - Y, \ - H)); \ - } else { \ - const TIndex dh = DivUp(H, kTileDim); \ - const TIndex dw = DivUp(W, kTileDim); \ - BatchTranspose2DCUDAKernel \ - <<cuda_stream()>>>(H, W, dh, dw, X, Y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, float, cublasSgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, float, cublasSgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, double, cublasDgeam) -DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, double, cublasDgeam) -#undef DELEGATE_TRANSPOSE_2D_CUDA_IMPL - -template -__global__ void TransposeCUDAKernel( - const TIndex size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const TData* X, - TData* Y) { - const int Y_index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (Y_index < size) { - TIndex X_index = 0; - TIndex v = Y_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - X_index += v % Y_dims.data[i] * X_strides.data[i]; - v /= Y_dims.data[i]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[Y_index] = __ldg(X + X_index); -#else - Y[Y_index] = X[X_index]; -#endif - } -} - -template -void TransposeCUDAImpl( - const TIndex* dims, - const int* axes, - const TData* X, - TData* Y, - CUDAContext* context) { - SimpleArray X_strides; - SimpleArray Y_dims; - utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); - TIndex size = 1; - for (int i = 0; i < D; ++i) { - Y_dims.data[i] = dims[axes[i]]; - size *= dims[i]; - } - const TIndex M = DivUp(size, 
CAFFE_CUDA_NUM_THREADS); - TransposeCUDAKernel - <<cuda_stream()>>>( - size, X_strides, Y_dims, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(TIndex, TData) \ - template <> \ - CAFFE2_CUDA_EXPORT void Transpose( \ - const int ndim, \ - const TIndex* dims, \ - const int* axes, \ - const TData* X, \ - TData* Y, \ - CUDAContext* context) { \ - const TIndex size = std::accumulate( \ - dims, dims + ndim, TIndex(1), std::multiplies()); \ - if (size == 0) { \ - return; \ - } \ - if (utils::IsIdentityPermutation(ndim, axes)) { \ - context->template CopySameDevice(size, X, Y); \ - return; \ - } \ - if (utils::IsBatchTranspose2D(ndim, axes)) { \ - const int H = dims[ndim - 2]; \ - const int W = dims[ndim - 1]; \ - const int N = size / (H * W); \ - BatchTranspose2DCUDAImpl(N, H, W, X, Y, context); \ - return; \ - } \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( \ - ndim, TransposeCUDAImpl, TIndex, TData, dims, axes, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, float) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, float) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, double) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, double) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int32_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int64_t) -CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int64_t) -#undef CAFFE2_SPECIALIZED_CUDA_TRANSPOSE - -#define CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void NCHW2NHWC( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - BatchTranspose2DCUDAImpl(N, C, HxW, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(float) -#undef CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC - -#define CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void NHWC2NCHW( \ - const int N, \ - const int C, \ - const int HxW, \ - const T* X, \ - T* Y, \ - CUDAContext* context) { \ - BatchTranspose2DCUDAImpl(N, HxW, C, X, Y, context); \ - } -CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(float) -#undef CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu deleted file mode 100644 index e6dfbf85039f..000000000000 --- a/caffe2/utils/math_gpu.cu +++ /dev/null @@ -1,2871 +0,0 @@ -// Implements the math functions for GPU. 
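// --- Illustrative aside, not part of the patch ---
// The deleted transpose.cu above describes splitting a matrix into 32x32
// submatrices and transposing each tile through shared memory. The sketch
// below is a minimal standalone version of that technique for a single
// row-major H x W float matrix; all names (kTile, kRowsPerBlock,
// TransposeTileKernel) are hypothetical and not taken from the patch.
#include <cuda_runtime.h>

constexpr int kTile = 32;
constexpr int kRowsPerBlock = 8;

__global__ void TransposeTileKernel(int H, int W, const float* X, float* Y) {
  // +1 padding on the inner dimension avoids shared-memory bank conflicts.
  __shared__ float tile[kTile][kTile + 1];
  int x = blockIdx.x * kTile + threadIdx.x;  // column in X
  int y = blockIdx.y * kTile + threadIdx.y;  // row in X
  for (int i = 0; i < kTile; i += kRowsPerBlock) {
    if (x < W && y + i < H) {
      tile[threadIdx.y + i][threadIdx.x] = X[(y + i) * W + x];
    }
  }
  __syncthreads();
  // Write the transposed tile: rows of Y correspond to columns of X.
  x = blockIdx.y * kTile + threadIdx.x;      // column in Y
  y = blockIdx.x * kTile + threadIdx.y;      // row in Y
  for (int i = 0; i < kTile; i += kRowsPerBlock) {
    if (x < H && y + i < W) {
      Y[(y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}
// Launch sketch:
//   dim3 block(kTile, kRowsPerBlock);
//   dim3 grid((W + kTile - 1) / kTile, (H + kTile - 1) / kTile);
//   TransposeTileKernel<<<grid, block>>>(H, W, d_X, d_Y);
// --- end aside ---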
- -#include "caffe2/utils/math.h" - -#include -#include -#include -#include - -#include -#include -#include "caffe2/utils/cub_namespace.cuh" - -#include -#include -#include - -#include "caffe2/core/context_gpu.h" -#include "caffe2/utils/GpuAtomics.cuh" -#include "caffe2/utils/conversions.h" - -#include "caffe2/utils/fixed_divisor.h" -// TODO: Move this to fixed_divisor.h -#if defined(USE_ROCM) -#define FIXED_DIVISOR int32_t -#define FIXED_DIVISOR_DIV(d, n) (n / d) -#define FIXED_DIVISOR_MOD(d, n) (n % d) -#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) \ - do { \ - const auto n_copy = n; \ - *q = n_copy / d; \ - *r = n_copy % d; \ - } while (0) -#else // USE_ROCM -#define FIXED_DIVISOR FixedDivisor -#define FIXED_DIVISOR_DIV(d, n) (d.Div(n)) -#define FIXED_DIVISOR_MOD(d, n) (d.Mod(n)) -#define FIXED_DIVISOR_DIV_MOD(d, n, q, r) (d.DivMod(n, q, r)) -#endif // USE_ROCM - -#if defined(USE_ROCM) -#define CUBLAS_HALF_TYPE hipblasHalf -#define HIPBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT -// until we use hipblas v2 -// hipify correctly maps things like CUDA_R_16F to HIP_R_16F, -// however hipblas v1 is still using its custom type -#ifndef HIPBLAS_V2 -#define HIP_R_16F HIPBLAS_R_16F -#define HIP_R_32F HIPBLAS_R_32F -#endif // HIPBLAS_V2 -#else // USE_ROCM -#define CUBLAS_HALF_TYPE __half -#endif // USE_ROCM - -#include "caffe2/utils/math/utils.h" - -#if THRUST_VERSION >= 100800 -#define THRUST_SUPPORTS_PER_THREAD -#endif // THRUST_VERSION >= 100800 - -namespace caffe2 { -namespace math { - -namespace { - -#define DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##Functor { \ - inline __host__ __device__ T \ - operator()(const T& lhs, const T& rhs) const { \ - return lhs expr rhs; \ - } \ - }; \ - template <> \ - struct Func##Functor { \ - inline __host__ __device__ at::Half operator()( \ - const at::Half& lhs, \ - const at::Half& rhs) const { \ - return convert::To(convert::To( \ - lhs) expr convert::To(rhs)); \ - } \ - }; -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Add, +) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Sub, -) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Mul, *) -DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Div, /) -#undef DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR - -template -__global__ void SimpleBinaryOpCUDAKernel( - const int N, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(i, N) { - C[i] = op(A[i], B[i]); - } -} - -template -__global__ void RowwiseBinaryOpCUDAKenel( - const int size, - const FIXED_DIVISOR cols, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - const int j = FIXED_DIVISOR_MOD(cols, C_index); - const int A_index = broadcast_1st ? j : C_index; - const int B_index = broadcast_1st ? C_index : j; - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -__global__ void ColwiseBinaryOpCUDAKenel( - const int size, - const FIXED_DIVISOR cols, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - const int i = FIXED_DIVISOR_DIV(cols, C_index); - const int A_index = broadcast_1st ? i : C_index; - const int B_index = broadcast_1st ? 
C_index : i; - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -__global__ void BroadcastBinaryOpCUDAKernel( - const int size, - const SimpleArray A_strides, - const SimpleArray B_strides, - const SimpleArray C_dims, - const BinaryOperator op, - const TIn* A, - const TIn* B, - TOut* C) { - CUDA_1D_KERNEL_LOOP(C_index, size) { - int A_index = 0; - int B_index = 0; - int C_index_val = C_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - int d; - FIXED_DIVISOR_DIV_MOD(C_dims.data[i], C_index_val, &C_index_val, &d); - A_index += d * A_strides.data[i]; - B_index += d * B_strides.data[i]; - } - C[C_index] = op(A[A_index], B[B_index]); - } -} - -template -CAFFE2_CUDA_EXPORT void BinaryOpWith2DBroadcasting( - const int rows, - const int cols, - const bool rowwise_broadcast, - const bool broadcast_1st, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - if (rows == 0 || cols == 0) { - return; - } - const int size = rows * cols; - const FIXED_DIVISOR cols_div(cols); - if (rowwise_broadcast) { - if (broadcast_1st) { - RowwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - RowwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } else { - if (broadcast_1st) { - ColwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - ColwiseBinaryOpCUDAKenel - <<cuda_stream()>>>(size, cols_div, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } -} - -template -CAFFE2_CUDA_EXPORT void BroadcastBinaryOpImpl( - const int* A_dims, - const int* B_dims, - const int* C_dims, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - SimpleArray A_strides_array; - SimpleArray B_strides_array; - SimpleArray C_dims_array; - int A_stride = 1; - int B_stride = 1; - for (int i = D - 1; i >= 0; --i) { - if (C_dims[i] == 0) { - return; - } - A_strides_array.data[i] = A_dims[i] == 1 ? 0 : A_stride; - B_strides_array.data[i] = B_dims[i] == 1 ? 
0 : B_stride; - A_stride *= A_dims[i]; - B_stride *= B_dims[i]; - C_dims_array.data[i] = FIXED_DIVISOR(C_dims[i]); - } - const int size = - std::accumulate(C_dims, C_dims + D, 1, std::multiplies()); - BroadcastBinaryOpCUDAKernel - <<cuda_stream()>>>( - size, A_strides_array, B_strides_array, C_dims_array, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -CAFFE2_CUDA_EXPORT void BroadcastBinaryOp( - const int A_ndim, - const int* A_dims, - const int B_ndim, - const int* B_dims, - const BinaryOperator& op, - const TIn* A, - const TIn* B, - TOut* C, - CUDAContext* context) { - const int ndim = std::max(A_ndim, B_ndim); - std::vector A_dims_array(ndim); - std::vector B_dims_array(ndim); - std::vector C_dims_array(ndim); - utils::ComputeBroadcastBinaryOpDims( - A_ndim, - A_dims, - B_ndim, - B_dims, - A_dims_array.data(), - B_dims_array.data(), - C_dims_array.data()); - if (A_dims_array == B_dims_array) { - const int size = std::accumulate( - C_dims_array.cbegin(), C_dims_array.cend(), 1, std::multiplies()); - SimpleBinaryOpCUDAKernel - <<cuda_stream()>>>(size, op, A, B, C); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return; - } - int rows; - int cols; - bool broadcast_1st; - if (utils::IsRowwiseBroadcastBinaryOp( - ndim, - A_dims_array.data(), - B_dims_array.data(), - &rows, - &cols, - &broadcast_1st)) { - BinaryOpWith2DBroadcasting( - rows, cols, true, broadcast_1st, op, A, B, C, context); - return; - } - if (utils::IsColwiseBroadcastBinaryOp( - ndim, - A_dims_array.data(), - B_dims_array.data(), - &rows, - &cols, - &broadcast_1st)) { - BinaryOpWith2DBroadcasting( - rows, cols, false, broadcast_1st, op, A, B, C, context); - return; - } - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3( - ndim, - BroadcastBinaryOpImpl, - TIn, - TOut, - BinaryOperator, - A_dims_array.data(), - B_dims_array.data(), - C_dims_array.data(), - op, - A, - B, - C, - context); -} - -} // namespace - -#define DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - CAFFE2_CUDA_EXPORT void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - RowwiseBinaryOpCUDAKenel, true> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Rowwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - RowwiseBinaryOpCUDAKenel, false> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Colwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FIXED_DIVISOR cols_div(cols); \ - ColwiseBinaryOpCUDAKenel, true> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - template <> \ - CAFFE2_CUDA_EXPORT void Colwise##Func( \ - const int rows, \ - const int cols, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const 
FIXED_DIVISOR cols_div(cols); \ - ColwiseBinaryOpCUDAKenel, false> \ - <<cuda_stream()>>>(size, cols_div, Op(), A, B, C); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - -#define DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(double, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(EQ, thrust::equal_to) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(LT, thrust::less) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(GT, thrust::greater) -DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal) - -#undef DEFINE_2D_BROADCAST_CUDA_COMPARE_FUNCTION - -#define DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(double, double, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Add, AddFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Sub, SubFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Mul, MulFunctor) -DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION(Div, DivFunctor) - -#undef DEFINE_2D_BROADCAST_CUDA_BINARY_FUNCTION - -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) -DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) - -#define DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) - -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or) -DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) - -#undef DEFINE_2D_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION - -#undef DELEGATE_2D_BROADCAST_CUDA_BINARY_FUNCTION - -#define DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - CAFFE2_CUDA_EXPORT void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CUDAContext* context) { \ - BroadcastBinaryOp>( \ - A_ndim, A_dims, B_ndim, B_dims, Op(), A, B, C, context); \ - } - -#define DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(double, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) - -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(EQ, 
thrust::equal_to) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(NE, thrust::not_equal_to) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(LT, thrust::less) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(LE, thrust::less_equal) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(GT, thrust::greater) -DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION(GE, thrust::greater_equal) - -#undef DEFINE_BROADCAST_CUDA_COMPARE_FUNCTION - -#define DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int64_t, std::int64_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(double, double, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(at::Half, at::Half, Func, Op) - -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Add, AddFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Sub, SubFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Mul, MulFunctor) -DEFINE_BROADCAST_CUDA_BINARY_FUNCTION(Div, DivFunctor) - -#undef DEFINE_BROADCAST_CUDA_BINARY_FUNCTION - -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) -DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) - -#define DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION( \ - std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) - -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or) -DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) - -#undef DEFINE_BROADCAST_CUDA_BITWISE_BINARY_FUNCTION - -#undef DELEGATE_BROADCAST_CUDA_BINARY_FUNCTION - -#define DELEGATE_REDUCTION_FUNCTION(T, Funcname, func) \ - template <> \ - CAFFE2_CUDA_EXPORT void Funcname( \ - const int N, \ - const T* src, \ - T* dst, \ - Tensor* scratch_ptr, \ - CUDAContext* context) { \ - size_t memRequired = 0; \ - cub::DeviceReduce::func( \ - nullptr, memRequired, src, dst, N, context->cuda_stream()); \ - auto buffer_size = \ - static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); \ - scratch_ptr->Resize(std::vector{buffer_size}); \ - cub::DeviceReduce::func( \ - static_cast(scratch_ptr->mutable_data()), \ - memRequired, \ - src, \ - dst, \ - N, \ - context->cuda_stream()); \ - } - -DELEGATE_REDUCTION_FUNCTION(float, ReduceMin, Min) -DELEGATE_REDUCTION_FUNCTION(float, ReduceMax, Max) -DELEGATE_REDUCTION_FUNCTION(int32_t, ReduceMax, Max) -DELEGATE_REDUCTION_FUNCTION(int64_t, ReduceMax, Max) - -#undef DELEGATE_REDUCTION_FUNCTION - -// Caffe2 gemm provides a simpler interface to the gemm functions, with the -// limitation that the data has to be contiguous in memory. -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? 
N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N)); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const at::Half* B, - const float beta, - at::Half* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - // hipblas doesn't support hipblasSgemmEx type API. - // It has more general hipblasGemmEx API which is more close to cublasGemmEx. - // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C, - // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C - HIPBLAS_ENFORCE(hipblasGemmEx( - context->hipblas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - HIPBLAS_R_16F, - ldb, - A, - HIPBLAS_R_16F, - lda, - &beta, - C, - HIPBLAS_R_16F, - N, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT)); -#else - CUBLAS_ENFORCE(cublasSgemmEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &beta, - C, - CUDA_R_16F, - N)); -#endif // USE_ROCM - } else if (math_type == TensorProto_DataType_FLOAT16) { - // convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - // call cublasHgemm - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B), - ldb, - reinterpret_cast(A), - lda, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C), - N)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void BiasCHW( - const float* bias, - const float* bias_multiplier, - const int bias_channels, - const int image_size, - float* image, - CUDAContext* context) { - Gemm( - CblasNoTrans, - CblasNoTrans, - bias_channels, - image_size, - 1, - 1, - bias, - bias_multiplier, - 1, - image, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float** A, - const float** B, - const float beta, - float** C, - CUDAContext* context, - TensorProto::DataType math_type) { -#if defined(USE_ROCM) - // loop over matrices in the batch - for (int i = 0; i < batch_size; ++i) { - Gemm( - trans_A, - trans_B, - M, - N, - K, - alpha, - A[i], - 
B[i], - beta, - C[i], - context, - math_type); - } -#else - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - thrust::device_vector A_device(A, A + batch_size); - thrust::device_vector B_device(B, B + batch_size); - thrust::device_vector C_device(C, C + batch_size); - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemmBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B_device.data().get(), - ldb, - A_device.data().get(), - lda, - &beta, - C_device.data().get(), - ldc, - batch_size)); -#endif -} - -template <> -CAFFE2_CUDA_EXPORT void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int A_stride, - const float* B, - const int B_stride, - const float beta, - float* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemmStridedBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - B_stride, - A, - lda, - A_stride, - &beta, - C, - ldc, - C_stride, - batch_size)); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half** A, - const at::Half** B, - const float beta, - at::Half** C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - thrust::device_vector A_device(A, A + batch_size); - thrust::device_vector B_device(B, B + batch_size); - thrust::device_vector C_device(C, C + batch_size); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - auto compute_type = HIPBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - CUBLAS_ENFORCE(cublasGemmBatchedEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B_device.data().get(), - CUDA_R_16F, - ldb, - A_device.data().get(), - CUDA_R_16F, - lda, - &beta, - C_device.data().get(), - CUDA_R_16F, - ldc, - batch_size, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - thrust::host_vector A_array(batch_size); - thrust::host_vector B_array(batch_size); - thrust::host_vector<__half*> C_array(batch_size); - for (int i = 0; i < batch_size; ++i) { - A_array[i] = reinterpret_cast(A[i]); - B_array[i] = reinterpret_cast(B[i]); - C_array[i] = reinterpret_cast<__half*>(C[i]); - } - thrust::device_vector A_device( - A_array.cbegin(), A_array.cend()); - thrust::device_vector B_device( - B_array.cbegin(), B_array.cend()); - thrust::device_vector<__half*> C_device(C_array.cbegin(), C_array.cend()); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemmBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B_device.data().get()), - ldb, - reinterpret_cast(A_device.data().get()), - lda, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C_device.data().get()), - ldc, - batch_size)); - } else { - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const int A_stride, - const at::Half* B, - const int B_stride, - const float beta, - at::Half* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const int ldc = N; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - auto compute_type = HIPBLAS_COMPUTE_32F; -#else - auto compute_type = CUDA_R_32F; -#endif - CUBLAS_ENFORCE(cublasGemmStridedBatchedEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - B_stride, - A, - CUDA_R_16F, - lda, - A_stride, - &beta, - C, - CUDA_R_16F, - ldc, - C_stride, - batch_size, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Convert alpha, beta from float -> __half - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemmStridedBatched( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(B), - ldb, - B_stride, - reinterpret_cast(A), - lda, - A_stride, - reinterpret_cast(&beta_fp16), - reinterpret_cast(C), - ldc, - C_stride, - batch_size)); - } else { - CAFFE_THROW("Unsupported math type"); - } -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemv( - context->cublas_handle(), - cu_trans_A, - N, - M, - &alpha, - A, - N, - x, - 1, - &beta, - y, - 1)); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const at::Half* A, - const at::Half* x, - const float beta, - at::Half* y, - CUDAContext* context, - TensorProto::DataType math_type) { - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; - - // sort out what we need to call cublasSgemmEx / cublasHgemm - const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M; - const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N; - const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k; - const int ldc = m; - - if (math_type == TensorProto_DataType_FLOAT) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); -#if defined(USE_ROCM) - // hipblas doesn't support hipblasSgemmEx type API. - // It has more general hipblasGemmEx API which is more close to cublasGemmEx. 
- // hipblasGemmEx does D = alpha*op( A )*op( B ) + beta*C, - // whereas cublasSgemmEx does C = alpha*op( A )*op( B ) + beta*C - HIPBLAS_ENFORCE(hipblasGemmEx( - context->hipblas_handle(), - cu_trans_A, - HIPBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - HIPBLAS_R_16F, - lda, - x, - HIPBLAS_R_16F, - k, - &beta, - y, - HIPBLAS_R_16F, - ldc, - HIPBLAS_COMPUTE_32F, - HIPBLAS_GEMM_DEFAULT)); -#else - CUBLAS_ENFORCE(cublasSgemmEx( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - &alpha, - A, - CUDA_R_16F, - lda, - x, - CUDA_R_16F, - k, - &beta, - y, - CUDA_R_16F, - ldc)); -#endif // USE_ROCM - } else if (math_type == TensorProto_DataType_FLOAT16) { - const __half alpha_fp16 = at::Half(alpha); - const __half beta_fp16 = at::Half(beta); - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasHgemm( - context->cublas_handle(), - cu_trans_A, - CUBLAS_OP_N, - m, - 1, - k, - reinterpret_cast(&alpha_fp16), - reinterpret_cast(A), - lda, - reinterpret_cast(x), - k, - reinterpret_cast(&beta_fp16), - reinterpret_cast(y), - ldc)); - } else { - // fail - CAFFE_THROW("Unsupported math type"); - } -} - -#if !defined(USE_ROCM) - -// No change, but required. Defer to default CUDA engine -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C, - CUDAContext* context, - TensorProto::DataType math_type) { - return Gemm( - trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemm( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const at::Half* B, - const float beta, - at::Half* C, - CUDAContext* context, - TensorProto::DataType math_type) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (trans_A == CblasNoTrans) ? K : M; - const int ldb = (trans_B == CblasNoTrans) ? N : K; - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; - - // enable TensorCore for this call on this handle - if (TensorCoreAvailable()) { - CUBLAS_ENFORCE( - cublasSetMathMode(context->cublas_handle(), CUBLAS_TENSOR_OP_MATH)); - } - - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasGemmEx( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - CUDA_R_16F, - ldb, - A, - CUDA_R_16F, - lda, - &beta, - C, - CUDA_R_16F, - N, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - // Now disable TensorCore math for subsequent calls to this handle - if (TensorCoreAvailable()) { - CUBLAS_ENFORCE( - cublasSetMathMode(context->cublas_handle(), CUBLAS_DEFAULT_MATH)); - } -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float** A, - const float** B, - const float beta, - float** C, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - B, - beta, - C, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void GemmBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half** A, - const at::Half** B, - const float beta, - at::Half** C, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - B, - beta, - C, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void -GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int A_stride, - const float* B, - const int B_stride, - const float beta, - float* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmStridedBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - A_stride, - B, - B_stride, - beta, - C, - C_stride, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void -GemmStridedBatched( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int batch_size, - const int M, - const int N, - const int K, - const float alpha, - const at::Half* A, - const int A_stride, - const at::Half* B, - const int B_stride, - const float beta, - at::Half* C, - const int C_stride, - CUDAContext* context, - TensorProto::DataType math_type) { - GemmStridedBatched( - trans_A, - trans_B, - batch_size, - M, - N, - K, - alpha, - A, - A_stride, - B, - B_stride, - beta, - C, - C_stride, - context, - math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y, - CUDAContext* context, - TensorProto::DataType math_type) { - Gemv( - trans_A, M, N, alpha, A, x, beta, y, context, math_type); -} - -template <> -CAFFE2_CUDA_EXPORT void Gemv( - const CBLAS_TRANSPOSE trans_A, - const int M, - const int N, - const float alpha, - const at::Half* A, - const at::Half* x, - const float beta, - at::Half* y, - CUDAContext* context, - TensorProto::DataType math_type) { - Gemv( - trans_A, M, N, alpha, A, x, beta, y, context, math_type); -} - -#endif - -template <> 
-CAFFE2_CUDA_EXPORT void GemmEx( - const CBLAS_TRANSPOSE trans_A, - const CBLAS_TRANSPOSE trans_B, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - CUDAContext* context) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const cublasOperation_t cu_trans_A = - (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - const cublasOperation_t cu_trans_B = - (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_ENFORCE( - cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST)); - CUBLAS_ENFORCE(cublasSgemm( - context->cublas_handle(), - cu_trans_B, - cu_trans_A, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc)); -} - -// Batched Add variants -namespace { - -template -__global__ void AddStripedBatchKernel( - const int N, - const T* first, - T* Y, - const int stripe, - const int batch) { - for (int j = 0; j < batch; j++) { - const T* x = first + j * stripe; - CUDA_1D_KERNEL_LOOP(i, N) { - float tmpY = convert::To(Y[i]); - tmpY += convert::To(x[i]); - Y[i] = convert::To(tmpY); - } - } -} -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void AddStripedBatch( \ - const int N, \ - const T* first, \ - T* Y, \ - const int stripe, \ - const int batch, \ - CUDAContext* context) { \ - AddStripedBatchKernel \ - <<cuda_stream()>>>(N, first, Y, stripe, batch); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - -CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float); -CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(at::Half); -#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH - -namespace { -template -__global__ void -UniformShift(const size_t N, const float min, const float max, T* x) { - float scale = max - min; - CUDA_1D_KERNEL_LOOP(i, N) { - x[i] = convert::To(convert::To(x[i]) * scale + min); - } -} - -__global__ void -UniformIntFit(const size_t N, const int min, const int max, unsigned int* x) { - int* x_int = reinterpret_cast(x); - int range = (max - min + 1); - CUDA_1D_KERNEL_LOOP(i, N) { - x_int[i] = min + static_cast(x[i] % range); - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const float min, - const float max, - float* r, - CUDAContext* context) { - CURAND_ENFORCE(curandGenerateUniform(context->curand_generator(), r, n)); - UniformShift - <<cuda_stream()>>>(n, min, max, r); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const double min, - const double max, - double* r, - CUDAContext* context) { - CURAND_ENFORCE( - curandGenerateUniformDouble(context->curand_generator(), r, n)); - UniformShift - <<cuda_stream()>>>(n, min, max, r); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void RandUniform( - const size_t n, - const int min, - const int max, - int* r, - CUDAContext* context) { - CURAND_ENFORCE(curandGenerate( - context->curand_generator(), reinterpret_cast(r), n)); - UniformIntFit<<< - CAFFE_GET_BLOCKS(n), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>( - n, min, max, reinterpret_cast(r)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -size_t HandleOddLengthRandGaussian( - const size_t n, - const T mean, - const T std, - T* r, - CUDAContext* context) { - if (n % 2 == 1) { - std::default_random_engine generator; - std::normal_distribution distribution(mean, std); - 
const T random_value = distribution(generator); - Set(1, random_value, r + (n - 1), context); - return n - 1; - } - return n; -} - -template <> -CAFFE2_CUDA_EXPORT void RandGaussian( - const size_t n, - const float mean, - const float std, - float* r, - CUDAContext* context) { - // If n is odd, we add a random Gaussian value at the end manually - // and generate n-1 random values using curandGenerateNormal. - // curandGenerateNormal requires n to be even. - const size_t even_n = - HandleOddLengthRandGaussian(n, mean, std, r, context); - CURAND_ENFORCE( - curandGenerateNormal(context->curand_generator(), r, even_n, mean, std)); -} - -template <> -CAFFE2_CUDA_EXPORT void RandGaussian( - const size_t n, - const double mean, - const double std, - double* r, - CUDAContext* context) { - const size_t even_n = - HandleOddLengthRandGaussian(n, mean, std, r, context); - CURAND_ENFORCE(curandGenerateNormalDouble( - context->curand_generator(), r, even_n, mean, std)); -} - -template <> -CAFFE2_CUDA_EXPORT void Dot( - const int n, - const float* a, - const float* b, - float* y, - CUDAContext* context) { - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasSdot(context->cublas_handle(), n, a, 1, b, 1, y)); -} - -template <> -CAFFE2_CUDA_EXPORT void Dot( - const int n, - const at::Half* a, - const at::Half* b, - at::Half* y, - CUDAContext* context) { - // execute with 32-bit math - CUBLAS_ENFORCE(cublasSetPointerMode( - context->cublas_handle(), CUBLAS_POINTER_MODE_DEVICE)); - CUBLAS_ENFORCE(cublasDotEx( - context->cublas_handle(), - n, - a, - CUDA_R_16F, - 1, - b, - CUDA_R_16F, - 1, - y, - CUDA_R_16F, - CUDA_R_32F)); -} - -// A previous version of caffe2 used Thrust but it turns out that thrust -// reduction has an implicit scratch space allocation and deallocation, which -// may interfere with NCCL and create a deadlock. Hence we are using a custom -// reduction here. -#define SUM_KERNEL_NTHREADS 128 -template -__global__ void SumKernel(const int N, const T* X, T* Y, bool square) { - const int idx = threadIdx.x; - __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS]; - - reduction_buffer[idx] = 0; - - // A multilevel reduction. - // N -> 128 - if (!square) { - for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - reduction_buffer[idx] += convert::To(X[i]); - } - } else { - for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) { - float Xi = convert::To(X[i]); - reduction_buffer[idx] += Xi * Xi; - } - } - __syncthreads(); - // 128 -> 32 - if (idx < 32) { - reduction_buffer[idx] += reduction_buffer[idx + 32] + - reduction_buffer[idx + 64] + reduction_buffer[idx + 96]; - } - __syncthreads(); - // 32 -> 1 - if (idx == 0) { - float tmp = 0; - for (int i = 0; i < 32; ++i) { - tmp += reduction_buffer[i]; - } - *Y = convert::To(tmp); - } -} - -// According to the benchmarks script -// caffe2/caffe2/experiments/python/device_reduce_sum_bench.py, -// device reduce is slower for N <= 10000. 
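Note: the deleted Sum/SumSqr helpers that follow rely on CUB's two-pass reduction so the scratch buffer is owned by the caller (a caffe2 Tensor) instead of being allocated implicitly the way Thrust does. A minimal standalone sketch of that pattern, with illustrative names (d_in, d_out, SumWithExplicitScratch are not part of the deleted code):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Two-pass cub::DeviceReduce::Sum: the first call only reports the required
// temporary-storage size, the second call runs the reduction using
// caller-managed scratch, so no hidden allocation happens on the stream.
void SumWithExplicitScratch(const float* d_in, float* d_out, int n, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);  // size query only
  cudaMalloc(&d_temp, temp_bytes);  // the deleted code reuses a scratch Tensor instead
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);  // actual reduction
  cudaFree(d_temp);
}

Keeping the scratch explicit is what avoids the hidden allocate/free that the comment above blames for NCCL deadlocks; below the 10000-element threshold the single-block SumKernel path is taken instead.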
-#define DEVICE_REDUCE_SIZE_THRESHOLD 10000 - -namespace { - -template -__global__ void SumConvertKernel(float* sum, T* dest) { - *dest = convert::To(*sum); -} - -template -CAFFE2_CUDA_EXPORT void SumGenericIter( - const int N, - IterT it, - T*& dest, - CUDAContext* context, - Tensor* scratch_ptr) { - size_t memRequired = 0; - cub::DeviceReduce::Sum( - nullptr, memRequired, it, dest, N, context->cuda_stream()); - auto buffer_size = - static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); - if (!dest) { - // allocate one more T at the end of scratch for dest - scratch_ptr->Resize(std::vector{buffer_size + 1}); - dest = scratch_ptr->template mutable_data() + buffer_size; - } else { - scratch_ptr->Resize(std::vector{buffer_size}); - } - cub::DeviceReduce::Sum( - static_cast(scratch_ptr->template mutable_data()), - memRequired, - it, - dest, - N, - context->cuda_stream()); -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Sum( - const int N, - const float* x, - float* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SumGenericIter(N, x, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, false); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -template <> -CAFFE2_CUDA_EXPORT void Sum( - const int N, - const int32_t* x, - int32_t* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SumGenericIter(N, x, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, false); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -namespace { -template -struct FloatTransform { - inline __host__ __device__ float operator()(const T v) const { - return convert::To(v); - } -}; -} // namespace - -#define CAFFE2_MATH_SUM_FUNC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Sum( \ - const int N, \ - const T* x, \ - T* y, \ - CUDAContext* context, \ - Tensor* scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform transform; \ - cub::TransformInputIterator, const T*> it( \ - x, transform); \ - float* sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - SumConvertKernel<<<1, 1, 0, context->cuda_stream()>>>(sum, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( \ - N, x, y, false); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } - -CAFFE2_MATH_SUM_FUNC(at::Half) -#undef CAFFE2_MATH_SUM_FUNC - -namespace { -template -struct SqrTransform { - inline __host__ __device__ T operator()(const T v) const { - return v * v; - } -}; -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void SumSqr( - const int N, - const float* x, - float* y, - CUDAContext* context, - Tensor* scratch_ptr) { - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { - SqrTransform transform; - cub::TransformInputIterator, const float*> it( - x, transform); - SumGenericIter(N, it, y, context, scratch_ptr); - } else { - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( - N, x, y, true); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } -} - -#define CAFFE2_MATH_SUMSQR_FUNC(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void SumSqr( \ - const int N, \ - const T* x, \ - T* y, \ - CUDAContext* context, \ - Tensor* scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform float_transform; \ - cub::TransformInputIterator, const T*> \ - float_it(x, 
float_transform); \ - SqrTransform sqr_transform; \ - cub::TransformInputIterator< \ - float, \ - SqrTransform, \ - decltype(float_it)> \ - it(float_it, sqr_transform); \ - float* sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - SumConvertKernel<<<1, 1, 0, context->cuda_stream()>>>(sum, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } else { \ - SumKernel<<<1, SUM_KERNEL_NTHREADS, 0, context->cuda_stream()>>>( \ - N, x, y, true); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } \ - } - -CAFFE2_MATH_SUMSQR_FUNC(at::Half) -#undef CAFFE2_MATH_SUMSQR_FUNC -#undef DEVICE_REDUCE_SIZE_THRESHOLD - -namespace { -template -__global__ void -SelectKernel(const int N, const int D, const T* x, const int* idx, T* y) { - CUDA_1D_KERNEL_LOOP(i, N) { - y[i] = x[i * D + idx[i]]; - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Select( - const int N, - const int D, - const float* x, - const int* idx, - float* y, - CUDAContext* context) { - SelectKernel - <<cuda_stream()>>>(N, D, x, idx, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Select( - const int N, - const int D, - const at::Half* x, - const int* idx, - at::Half* y, - CUDAContext* context) { - SelectKernel - <<cuda_stream()>>>(N, D, x, idx, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -namespace { - -template -__global__ void Im2ColNCHWCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* img_data, - T* col_data) { - CUDA_1D_KERNEL_LOOP(index, n) { - const int w_out = index % output_w; - const int h_index = index / output_w; - const int h_out = h_index % output_h; - const int channel_in = h_index / output_h; - const int channel_out = channel_in * kernel_h * kernel_w; - const int h_in = h_out * stride_h - pad_t; - const int w_in = w_out * stride_w - pad_l; - const int output_size = output_h * output_w; - T* col_data_ptr = - col_data + (channel_out * output_h + h_out) * output_w + w_out; - const T* img_data_ptr = - img_data + (channel_in * input_h + h_in) * input_w + w_in; - int dh = 0; - for (int i = 0; i < kernel_h; ++i) { - int dw = 0; - for (int j = 0; j < kernel_w; ++j) { - const int h = h_in + dh; - const int w = w_in + dw; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? __ldg(img_data_ptr + dh * input_w + dw) - : 0; -#else - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? 
img_data_ptr[dh * input_w + dw] - : 0; -#endif - col_data_ptr += output_size; - dw += dilation_w; - } - dh += dilation_h; - } - } -} - -template -__global__ void Im2ColNHWCCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_w, - const int channels, - const T* img_data, - T* col_data) { - CUDA_1D_KERNEL_LOOP(index, n) { - const int channel_in = index % channels; - const int w_out = index / channels % output_w; - const int h_out = index / channels / output_w; - const int h_in = h_out * stride_h - pad_t; - const int w_in = w_out * stride_w - pad_l; - T* col_data_ptr = col_data + - (h_out * output_w + w_out) * channels * kernel_h * kernel_w + - channel_in; - int dh = 0; - for (int i = 0; i < kernel_h; ++i) { - int dw = 0; - for (int j = 0; j < kernel_w; ++j) { - const int h = h_in + dh; - const int w = w_in + dw; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? __ldg(img_data + (h * input_w + w) * channels + channel_in) - : 0; -#else - *col_data_ptr = utils::IsAGeZeroAndALtB(h, input_h) && - utils::IsAGeZeroAndALtB(w, input_w) - ? img_data[(h * input_w + w) * channels + channel_in] - : 0; -#endif - col_data_ptr += channels; - dw += dilation_w; - } - dh += dilation_h; - } - } -} - -template -__global__ void Col2ImNCHWCUDAKernel( - const int n, - const int input_h, - const int input_w, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* col_data, - T* img_data) { - const int dpatch_h = dilation_h * (patch_h - 1) + 1; - const int dpatch_w = dilation_w * (patch_w - 1) + 1; - - CUDA_1D_KERNEL_LOOP(index, n) { - T val = 0; - const int w = index % input_w + pad_l; - const int h = index / input_w % input_h + pad_t; - const int c = index / (input_h * input_w); - - // compute the start and end of the output - const int w_col_start = (w < dpatch_w) ? 0 : (w - dpatch_w) / stride_w + 1; - const int w_col_end = min(w / stride_w + 1, output_w); - const int h_col_start = (h < dpatch_h) ? 
0 : (h - dpatch_h) / stride_h + 1; - const int h_col_end = min(h / stride_h + 1, output_h); - - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - int h_k = (h - h_col * stride_h); - int w_k = (w - w_col * stride_w); - if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { - h_k /= dilation_h; - w_k /= dilation_w; - const int col_data_index = - (((c * patch_h + h_k) * patch_w + w_k) * output_h + h_col) * - output_w + - w_col; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val += __ldg(col_data + col_data_index); -#else - val += col_data[col_data_index]; -#endif - } - } - } - img_data[index] = val; - } -} - -template -__global__ void Col2ImNHWCCUDAKernel( - const int n, - const int input_w, - const int channels, - const int patch_h, - const int patch_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int stride_h, - const int stride_w, - const int output_h, - const int output_w, - const T* col_data, - T* img_data) { - const int dpatch_h = dilation_h * (patch_h - 1) + 1; - const int dpatch_w = dilation_w * (patch_w - 1) + 1; - - CUDA_1D_KERNEL_LOOP(index, n) { - T val = 0; - const int c = index % channels; - const int w = index / channels % input_w + pad_l; - const int h = index / channels / input_w + pad_t; - // compute the start and end of the output - const int w_col_start = (w < dpatch_w) ? 0 : (w - dpatch_w) / stride_w + 1; - const int w_col_end = min(w / stride_w + 1, output_w); - const int h_col_start = (h < dpatch_h) ? 0 : (h - dpatch_h) / stride_h + 1; - const int h_col_end = min(h / stride_h + 1, output_h); - const int channels_col = patch_h * patch_w * channels; - - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - int h_k = h - h_col * stride_h; - int w_k = w - w_col * stride_w; - if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { - h_k /= dilation_h; - w_k /= dilation_w; - const int c_col = (h_k * patch_w + w_k) * channels + c; -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - val += __ldg( - col_data + (h_col * output_w + w_col) * channels_col + c_col); -#else - val += col_data[(h_col * output_w + w_col) * channels_col + c_col]; -#endif - } - } - } - img_data[index] = val; - } -} - -template -__global__ void Im2ColNdNCHWCUDAKernel( - const int outer_size, - const int inner_size, - const int kernel_size, - SimpleArray img_shape, - SimpleArray col_shape, - SimpleArray kernel_shape, - SimpleArray stride, - SimpleArray dilation, - SimpleArray pad, - const T* X_data, - T* Y_data) { - int d_offset[N]; - int d_iter[N]; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - int offset_i = i; -#pragma unroll - for (int d_i = N - 1; d_i >= 0; --d_i) { - d_offset[d_i] = offset_i % kernel_shape.data[d_i]; - offset_i /= kernel_shape.data[d_i]; - } - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - int offset_j = j; -#pragma unroll - for (int d_i = N - 1; d_i >= 0; --d_i) { - d_iter[d_i] = offset_j % col_shape.data[d_i + 1]; - offset_j /= col_shape.data[d_i + 1]; - } - const int col_index = i * inner_size + j; - int img_index = i / kernel_size; - bool is_padding = false; -#pragma unroll - for (int d_i = 0; d_i < N; ++d_i) { - const int d_img = d_iter[d_i] * stride.data[d_i] - pad.data[d_i] + - d_offset[d_i] * dilation.data[d_i]; - is_padding |= !utils::IsAGeZeroAndALtB(d_img, img_shape.data[d_i + 1]); - img_index = img_index * img_shape.data[d_i + 1] + d_img; - } -#if 
__CUDA_ARCH__ >= 350 || defined(USE_ROCM) - if (!kCol2Im) { - Y_data[col_index] = is_padding ? 0 : __ldg(X_data + img_index); - } else if (!is_padding) { - gpu_atomic_add(Y_data + img_index, __ldg(X_data + col_index)); - } -#else - if (!kCol2Im) { - Y_data[col_index] = is_padding ? 0 : X_data[img_index]; - } else if (!is_padding) { - gpu_atomic_add(Y_data + img_index, X_data[col_index]); - } -#endif - } - } -} - -template -CAFFE2_CUDA_EXPORT void Im2ColNdNCHWCUDAImpl( - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context) { - const int outer_size = col_shape[0]; - const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate( - kernel_shape, kernel_shape + N, 1, std::multiplies()); - SimpleArray img_shape_array; - SimpleArray col_shape_array; - SimpleArray kernel_shape_array; - SimpleArray stride_array; - SimpleArray dilation_array; - SimpleArray pad_array; - std::memcpy(img_shape_array.data, img_shape, (N + 1) * sizeof(int)); - std::memcpy(col_shape_array.data, col_shape, (N + 1) * sizeof(int)); - std::memcpy(kernel_shape_array.data, kernel_shape, N * sizeof(int)); - std::memcpy(stride_array.data, stride, N * sizeof(int)); - std::memcpy(dilation_array.data, dilation, N * sizeof(int)); - std::memcpy(pad_array.data, pad, N * sizeof(int)); - Im2ColNdNCHWCUDAKernel - <<cuda_stream()>>>( - outer_size, - inner_size, - kernel_size, - img_shape_array, - col_shape_array, - kernel_shape_array, - stride_array, - dilation_array, - pad_array, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -CAFFE2_CUDA_EXPORT void Col2ImNdNCHWCUDAImpl( - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context) { - const int outer_size = col_shape[0]; - const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate( - kernel_shape, kernel_shape + N, 1, std::multiplies()); - SimpleArray img_shape_array; - SimpleArray col_shape_array; - SimpleArray kernel_shape_array; - SimpleArray stride_array; - SimpleArray dilation_array; - SimpleArray pad_array; - std::memcpy(img_shape_array.data, img_shape, (N + 1) * sizeof(int)); - std::memcpy(col_shape_array.data, col_shape, (N + 1) * sizeof(int)); - std::memcpy(kernel_shape_array.data, kernel_shape, N * sizeof(int)); - std::memcpy(stride_array.data, stride, N * sizeof(int)); - std::memcpy(dilation_array.data, dilation, N * sizeof(int)); - std::memcpy(pad_array.data, pad, N * sizeof(int)); - Set(img_size, 0, img_data, context); - Im2ColNdNCHWCUDAKernel - <<cuda_stream()>>>( - outer_size, - inner_size, - kernel_size, - img_shape_array, - col_shape_array, - kernel_shape_array, - stride_array, - dilation_array, - pad_array, - col_data, - img_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* img_data, - float* col_data, - CUDAContext* context, - const int /* groups */) { - const int 
dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = channels * output_h * output_w; - Im2ColNCHWCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2Col( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* img_data, - float* col_data, - CUDAContext* context, - const int groups) { - CAFFE_ENFORCE_EQ(groups, 1, "groups must be 1 for GPU NHWC Im2Col"); - - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = output_h * output_w * channels; - Im2ColNHWCCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_w, - channels, - img_data, - col_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2Im( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* col_data, - float* img_data, - CUDAContext* context, - const int /* groups */) { - // In NCHW, the number of groups doesn't affect Col2Im. 
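Note: every Im2Col/Col2Im wrapper in this file repeats the same output-size arithmetic for dilated kernels. A small sketch of that formula (ConvOutputSize is a hypothetical helper, not part of the deleted code, which inlines the computation at each call site):

#include <cassert>

// Effective (dilated) kernel extent is dilation * (kernel - 1) + 1; the output
// extent then follows the usual "valid" convolution formula.
inline int ConvOutputSize(int in, int kernel, int dilation,
                          int pad_before, int pad_after, int stride) {
  const int dkernel = dilation * (kernel - 1) + 1;
  assert(in + pad_before + pad_after >= dkernel);
  return (in + pad_before + pad_after - dkernel) / stride + 1;
}

// Example: a 7-wide input, 3-wide kernel, dilation 2, pad 2 on both sides, stride 1:
// dkernel = 5, so ConvOutputSize(7, 3, 2, 2, 2, 1) == (7 + 4 - 5) / 1 + 1 == 7.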
- const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = channels * height * width; - Col2ImNCHWCUDAKernel - <<cuda_stream()>>>( - num_kernels, - height, - width, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - col_data, - img_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2Im( - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int dilation_h, - const int dilation_w, - const int pad_t, - const int pad_l, - const int pad_b, - const int pad_r, - const int stride_h, - const int stride_w, - const float* col_data, - float* img_data, - CUDAContext* context, - const int groups) { - CAFFE_ENFORCE_EQ(groups, 1, "groups must be 1 for GPU NHWC Col2Im"); - - const int dkernel_h = dilation_h * (kernel_h - 1) + 1; - const int dkernel_w = dilation_w * (kernel_w - 1) + 1; - const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; - const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; - const int num_kernels = height * width * channels; - Col2ImNHWCCUDAKernel - <<cuda_stream()>>>( - num_kernels, - width, - channels, - kernel_h, - kernel_w, - dilation_h, - dilation_w, - pad_t, - pad_l, - stride_h, - stride_w, - output_h, - output_w, - col_data, - img_data); -C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context, - const int /* groups */) { - // In NCHW, the number of groups doesn't affect Im2Col. - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, - Im2ColNdNCHWCUDAImpl, - float, - img_size, - col_size, - img_shape, - col_shape, - kernel_shape, - stride, - dilation, - pad, - img_data, - col_data, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void Im2ColNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* img_data, - float* col_data, - CUDAContext* context, - const int groups) { - CAFFE_NOT_IMPLEMENTED; -} - -template <> -CAFFE2_CUDA_EXPORT void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context, - int /* groups */) { - // In NCHW, the number of groups doesn't affect Col2Im. 
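Note: the DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1 macro used by the Nd entry points (defined elsewhere in caffe2) turns the runtime spatial rank N into a compile-time template argument so the kernel can keep its per-dimension index arrays in registers and unroll its loops. A rough sketch of that dispatch shape, under the assumption that only a handful of ranks are supported; the names below are illustrative, not the deleted implementation:

#include <stdexcept>

template <int N>
void Col2ImNdImpl(const float* col, float* img) {
  // Placeholder body; the real Col2ImNdNCHWCUDAImpl<T, N> launches a CUDA kernel
  // parameterized on N.
  (void)col;
  (void)img;
}

inline void DispatchByRank(int n, const float* col, float* img) {
  switch (n) {  // runtime value selects a compile-time specialization
    case 1: Col2ImNdImpl<1>(col, img); break;
    case 2: Col2ImNdImpl<2>(col, img); break;
    case 3: Col2ImNdImpl<3>(col, img); break;
    default: throw std::runtime_error("unsupported spatial rank");
  }
}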
- DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, - Col2ImNdNCHWCUDAImpl, - float, - img_size, - col_size, - img_shape, - col_shape, - kernel_shape, - stride, - dilation, - pad, - col_data, - img_data, - context); -} - -template <> -CAFFE2_CUDA_EXPORT void Col2ImNd( - const int N, - const int img_size, - const int col_size, - const int* img_shape, - const int* col_shape, - const int* kernel_shape, - const int* stride, - const int* dilation, - const int* pad, - const float* col_data, - float* img_data, - CUDAContext* context, - int groups) { - CAFFE_NOT_IMPLEMENTED; -} - -template <> -CAFFE2_CUDA_EXPORT void CopyMatrix( - const size_t itemsize, - const int M, - const int N, - const void* A, - const int lda, - void* B, - const int ldb, - CUDAContext* context, - TypeMeta::Copy copy) { - CAFFE_ENFORCE(!copy, "Copy constructor is not supported in CUDA context"); - cudaMemcpy2DAsync( - B, - ldb * itemsize, - A, - lda * itemsize, - N * itemsize, - M, - cudaMemcpyDeviceToDevice, - context->cuda_stream()); -} - -#define CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(T) \ - template <> \ - void CopyMatrix( \ - const int M, \ - const int N, \ - const T* A, \ - const int lda, \ - T* B, \ - const int ldb, \ - CUDAContext* context) { \ - if (M == 0 || N == 0) { \ - return; \ - } \ - cudaMemcpy2DAsync( \ - B, \ - sizeof(T) * ldb, \ - A, \ - sizeof(T) * lda, \ - sizeof(T) * N, \ - M, \ - cudaMemcpyDeviceToDevice, \ - context->cuda_stream()); \ - } -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(float) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(double) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(int) -CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX(int64_t) -#undef CAFFE2_SPECIALIZED_CUDA_COPY_MATRIX - -template <> -CAFFE2_CUDA_EXPORT void CopyVector( - const int N, - const float* src, - float* dst, - CUDAContext* context) { - if (src != dst && N > 0) { - C10_CUDA_CHECK(cudaMemcpyAsync( - dst, - src, - sizeof(float) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - } -} - -template <> -CAFFE2_CUDA_EXPORT void CopyVector( - const int N, - const int* src, - int* dst, - CUDAContext* context) { - if (src != dst && N > 0) { - C10_CUDA_CHECK(cudaMemcpyAsync( - dst, - src, - sizeof(int) * N, - cudaMemcpyDeviceToDevice, - context->cuda_stream())); - } -} - -namespace { - -template -using BlockReduce = cub::BlockReduce; - -template -__global__ void RowwiseReduceKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int i = blockIdx.x; i < rows; i += gridDim.x) { - T val = init; - for (int j = threadIdx.x; j < cols; j += blockDim.x) { - val = reducer(X[i * cols + j], val); - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[i] = val * alpha; - } - __syncthreads(); - } -} - -template -__global__ void ColwiseReduceKernel( - const int rows, - const int cols, - const Reducer reducer, - const T init, - const T alpha, - const T* X, - T* Y) { - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int i = blockIdx.x; i < cols; i += gridDim.x) { - T val = init; - for (int j = threadIdx.x; j < rows; j += blockDim.x) { - val = reducer(X[j * cols + i], val); - } - val = BlockReduce(temp_storage).Reduce(val, reducer); - if (threadIdx.x == 0) { - Y[i] = val * alpha; - } - __syncthreads(); - } -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void RowwiseMax( \ - const int N, const int D, const T* x, T* y, CUDAContext* 
context) { \ - RowwiseReduceKernel<<< \ - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>( \ - N, D, cub::Max(), std::numeric_limits::lowest(), T(1), x, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX(float) -#undef CAFFE2_SPECIALIZED_CUDA_ROWWISE_MAX - -#define CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void ColwiseMax( \ - const int N, const int D, const T* x, T* y, CUDAContext* context) { \ - ColwiseReduceKernel<<< \ - std::min(D, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_CUDA_NUM_THREADS, \ - 0, \ - context->cuda_stream()>>>( \ - N, D, cub::Max(), std::numeric_limits::lowest(), T(1), x, y); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX(float) -#undef CAFFE2_SPECIALIZED_CUDA_COLWISE_MAX - -namespace { -__global__ void -maximum_kernel(const int N, const float alpha, const float* x, float* y) { - CUDA_1D_KERNEL_LOOP(i, N) { - y[i] = fmaxf(x[i], alpha); - } -} -} // namespace - -template <> -CAFFE2_CUDA_EXPORT void Maximum( - const int N, - const float alpha, - const float* x, - float* y, - CUDAContext* context) { - maximum_kernel<<< - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(N, alpha, x, y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -namespace { - -template -__global__ void BroadcastCUDAKernel( - const int Y_size, - const SimpleArray X_strides, - const SimpleArray Y_dims, - const T alpha, - const T* X, - T* Y) { - CUDA_1D_KERNEL_LOOP(Y_index, Y_size) { - int X_index = 0; - int Y_index_val = Y_index; -#pragma unroll - for (int i = D - 1; i >= 0; --i) { - int d; - FIXED_DIVISOR_DIV_MOD(Y_dims.data[i], Y_index_val, &Y_index_val, &d); - X_index += d * X_strides.data[i]; - } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - Y[Y_index] = __ldg(X + X_index) * alpha; -#else - Y[Y_index] = X[X_index] * alpha; -#endif - } -} - -template -CAFFE2_CUDA_EXPORT void BroadcastCUDAImpl( - const int X_ndim, - const int* X_dims, - const int* Y_dims, - const T alpha, - const T* X, - T* Y, - CUDAContext* context) { - SimpleArray X_strides_array; - SimpleArray Y_dims_array; - const int d = D - X_ndim; - std::fill(X_strides_array.data, X_strides_array.data + d, 0); - int cur_stride = 1; - for (int i = D - 1; i >= d; --i) { - CAFFE_ENFORCE(X_dims[i - d] == 1 || X_dims[i - d] == Y_dims[i]); - X_strides_array.data[i] = X_dims[i - d] == 1 ? 
0 : cur_stride; - cur_stride *= X_dims[i - d]; - } - for (int i = 0; i < D; ++i) { - if (Y_dims[i] == 0) { - return; - } - Y_dims_array.data[i] = FIXED_DIVISOR(Y_dims[i]); - } - const int Y_size = - std::accumulate(Y_dims, Y_dims + D, 1, std::multiplies()); - BroadcastCUDAKernel - <<cuda_stream()>>>( - Y_size, X_strides_array, Y_dims_array, alpha, X, Y); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_BROADCAST(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void Broadcast( \ - const int X_ndim, \ - const int* X_dims, \ - const int Y_ndim, \ - const int* Y_dims, \ - const T alpha, \ - const T* X, \ - T* Y, \ - CUDAContext* context, \ - bool) { \ - CAFFE_ENFORCE_LE(X_ndim, Y_ndim); \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( \ - Y_ndim, \ - BroadcastCUDAImpl, \ - T, \ - X_ndim, \ - X_dims, \ - Y_dims, \ - alpha, \ - X, \ - Y, \ - context); \ - } -CAFFE2_SPECIALIZED_CUDA_BROADCAST(std::int32_t) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(std::int64_t) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(float) -CAFFE2_SPECIALIZED_CUDA_BROADCAST(double) -#undef CAFFE2_SPECIALIZED_CUDA_BROADCAST - -namespace { - -template -__global__ void -InvStdCUDAKernel(const int N, const T epsilon, const T* var, T* inv_std); - -#define DELEGATE_INV_STD_KERNEL_FUNCTION(T, Func) \ - template <> \ - __global__ void InvStdCUDAKernel( \ - const int N, const T epsilon, const T* var, T* inv_std) { \ - CUDA_1D_KERNEL_LOOP(i, N) { \ - inv_std[i] = Func(var[i] + epsilon); \ - } \ - } -DELEGATE_INV_STD_KERNEL_FUNCTION(float, rsqrtf) -#undef DELEGATE_INV_STD_KERNEL_FUNCTION - -} // namespace - -#define CAFFE2_SPECIALIZED_CUDA_INV_STD(T) \ - template <> \ - CAFFE2_CUDA_EXPORT void InvStd( \ - const int N, \ - const T epsilon, \ - const T* var, \ - T* inv_std, \ - CUDAContext* context) { \ - InvStdCUDAKernel \ - <<cuda_stream()>>>(N, epsilon, var, inv_std); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } -CAFFE2_SPECIALIZED_CUDA_INV_STD(float) -#undef CAFFE2_SPECIALIZED_CUDA_INV_STD - -} // namespace math -} // namespace caffe2 diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc deleted file mode 100644 index 330a724162cd..000000000000 --- a/caffe2/utils/math_gpu_test.cc +++ /dev/null @@ -1,429 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "caffe2/core/context.h" -#include "caffe2/core/context_gpu.h" -#include "caffe2/core/flags.h" -#include "caffe2/operators/utility_ops.h" -#include "caffe2/utils/math.h" - -C10_DECLARE_string(caffe_test_root); - -namespace caffe2 { - -void executeGpuBinaryOpTest( - int shapex0, - int shapex1, - int shapey, - std::function input0, - std::function input1, - std::function operation, - std::function correct_output) { - if (!HasCudaGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_CUDA); - CUDAContext context(option); - - Blob* blobx0 = ws.CreateBlob("X0"); - Blob* blobx1 = ws.CreateBlob("X1"); - Blob* bloby = ws.CreateBlob("Y"); - Blob* bloby_host = ws.CreateBlob("Y_host"); - - auto* tensorx0 = BlobGetMutableTensor(blobx0, CUDA); - auto* tensorx1 = BlobGetMutableTensor(blobx1, CUDA); - auto* tensory = BlobGetMutableTensor(bloby, CUDA); - - vector shapex0_vector{shapex0}; - vector shapex1_vector{shapex1}; - vector shapey_vector{shapey}; - - tensorx0->Resize(shapex0_vector); - tensorx1->Resize(shapex1_vector); - tensory->Resize(shapey_vector); - - for (int i = 0; i < shapex0; i++) { - math::Set( - 1, input0(i), tensorx0->mutable_data() + i, &context); - } - for (int i = 0; i < shapex1; i++) { - 
math::Set( - 1, input1(i), tensorx1->mutable_data() + i, &context); - } - operation( - shapex0, - shapex1, - tensorx0->template data(), - tensorx1->template data(), - tensory->mutable_data(), - &context); - context.FinishDeviceComputation(); - - // Copy result to CPU so we can inspect it - auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU); - tensory_host->CopyFrom(*tensory); - - for (int i = 0; i < shapey; ++i) { - EXPECT_EQ(tensory_host->data()[i], correct_output(i)); - } -} - -TEST(MathUtilGPUTest, testAddStripedBatch) { - if (!HasCudaGPU()) - return; - Workspace ws; - DeviceOption option; - option.set_device_type(PROTO_CUDA); - CUDAContext context(option); - Blob* blobx = ws.CreateBlob("X"); - Blob* bloby = ws.CreateBlob("Y"); - Blob* bloby_host = ws.CreateBlob("Y_host"); - - vector shapex{33 * 9, 25}; - vector shapey{33, 25}; - - auto* tensorx = BlobGetMutableTensor(blobx, CUDA); - tensorx->Resize(shapex); - int stripe = 33 * 25; - vector tot(33, 0.0); - for (int j = 0; j < 9; j++) { - // Have different values for each line - for (int k = 0; k < 33; k++) { - math::Set( - 33, - 1.0 + j + k, - tensorx->mutable_data() + j * stripe + k * 25, - &context); - tot[k] += 1.0 + j + k; - } - } - - auto* tensory = BlobGetMutableTensor(bloby, CUDA); - tensory->Resize(shapey); - math::Set( - stripe, 0.0, tensory->mutable_data(), &context); - - math::AddStripedBatch( - stripe, - tensorx->template data(), - tensory->mutable_data(), - stripe, - 9, - &context); - context.FinishDeviceComputation(); - - // Copy result to CPU so we can inspect it - auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU); - tensory_host->CopyFrom(*tensory); - - for (int k = 0; k < 33; k++) { - for (int i = 0; i < 25; i++) { - EXPECT_EQ(tensory_host->data()[k * 25 + i], tot[k]); - } - } -} - -TEST(MathUtilGPUTest, testReduceMin) { - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int /*i*/) { return 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMin(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int i) { return i == 3 ? 11.0f : 17.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMin(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); -} - -TEST(MathUtilGPUTest, testReduceMax) { - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int /*i*/) { return 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMax(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 11.0f; }); - executeGpuBinaryOpTest( - 6, - 1, - 1, - [](int i) { return i == 3 ? 
17.0f : 11.0f; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - Tensor aux(CUDA); - math::ReduceMax(N0, src0, dst, &aux, context); - }, - [](int /*i*/) { return 17.0f; }); -} - -TEST(MathUtilGPUTest, testCopyVector) { - executeGpuBinaryOpTest( - 6, - 1, - 6, - [](int i) { return 5.0f - i; }, - [](int /*i*/) { return 0.0f; }, - [](int N0, - int /*N1*/, - const float* src0, - const float* /*src1*/, - float* dst, - CUDAContext* context) { - math::CopyVector(N0, src0, dst, context); - }, - [](int i) { return 5.0f - i; }); -} - -namespace { - -class GemmBatchedGPUTest - : public testing::TestWithParam> { - protected: - void SetUp() override { - if (!HasCudaGPU()) { - return; - } - option_.set_device_type(PROTO_CUDA); - cuda_context_ = make_unique(option_); - Blob* X_blob = ws_.CreateBlob("X"); - Blob* W_blob = ws_.CreateBlob("W"); - Blob* Y_blob = ws_.CreateBlob("Y"); - X_ = BlobGetMutableTensor(X_blob, CUDA); - W_ = BlobGetMutableTensor(W_blob, CUDA); - Y_ = BlobGetMutableTensor(Y_blob, CUDA); - X_->Resize(std::vector{3, 5, 10}); - W_->Resize(std::vector{3, 6, 10}); - Y_->Resize(std::vector{3, 5, 6}); - math::Set( - X_->numel(), 1.0f, X_->mutable_data(), cuda_context_.get()); - math::Set( - W_->numel(), 1.0f, W_->mutable_data(), cuda_context_.get()); - trans_X_ = std::get<0>(GetParam()); - trans_W_ = std::get<1>(GetParam()); - } - - void RunGemmBatched(const float alpha, const float beta) { - const float* X_data = X_->template data(); - const float* W_data = W_->template data(); - float* Y_data = Y_->template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - std::array X_array = { - X_data, X_data + X_stride, X_data + 2 * X_stride}; - std::array W_array = { - W_data, W_data + W_stride, W_data + 2 * W_stride}; - std::array Y_array = { - Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; - math::GemmBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_array.data(), - W_array.data(), - beta, - Y_array.data(), - cuda_context_.get()); - } - - void RunGemmStridedBatched(const float alpha, const float beta) { - const float* X_data = X_->template data(); - const float* W_data = W_->template data(); - float* Y_data = Y_->template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - math::GemmStridedBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? 
CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_data, - X_stride, - W_data, - W_stride, - beta, - Y_data, - Y_stride, - cuda_context_.get()); - } - - void VerifyOutput(const float value) const { - Tensor Y_cpu(*Y_, CPU); - for (int i = 0; i < Y_cpu.numel(); ++i) { - EXPECT_FLOAT_EQ(value, Y_cpu.template data()[i]); - } - } - - Workspace ws_; - DeviceOption option_; - std::unique_ptr cuda_context_; - Tensor* X_ = nullptr; - Tensor* W_ = nullptr; - Tensor* Y_ = nullptr; - bool trans_X_; - bool trans_W_; -}; - -TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunGemmBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -TEST_P(GemmBatchedGPUTest, GemmStridedBatchedGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunGemmStridedBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmStridedBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmStridedBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -INSTANTIATE_TEST_CASE_P( - GemmBatchedGPUTrans, - GemmBatchedGPUTest, - testing::Combine(testing::Bool(), testing::Bool())); - -class BroadcastGPUTest : public testing::Test { - protected: - void SetUp() override { - if (!HasCudaGPU()) { - return; - } - option_.set_device_type(PROTO_CUDA); - cuda_context_ = make_unique(option_); - Blob* blob_x = ws_.CreateBlob("X"); - Blob* blob_y = ws_.CreateBlob("Y"); - X_ = BlobGetMutableTensor(blob_x, CUDA); - Y_ = BlobGetMutableTensor(blob_y, CUDA); - } - - void SetUpData( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data) { - X_->Resize(X_dims); - Y_->Resize(Y_dims); - ASSERT_EQ(X_data.size(), X_->numel()); - cuda_context_->CopyFromCPU( - X_data.size(), X_data.data(), X_->mutable_data()); - } - - void VerifyResult(const std::vector& expected_output) { - Blob* blob_y_host = ws_.CreateBlob("Y_host"); - auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU); - Y_host->CopyFrom(*Y_); - ASSERT_EQ(expected_output.size(), Y_host->numel()); - for (std::size_t i = 0; i < expected_output.size(); ++i) { - EXPECT_FLOAT_EQ(expected_output[i], Y_host->data()[i]); - } - } - - void RunBroadcastTest( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data, - const std::vector& Y_data) { - SetUpData(X_dims, Y_dims, X_data); - math::Broadcast( - X_dims.size(), - X_dims.data(), - Y_dims.size(), - Y_dims.data(), - 1.0f, - X_->data(), - Y_->mutable_data(), - cuda_context_.get()); - VerifyResult(Y_data); - } - - Workspace ws_; - DeviceOption option_; - std::unique_ptr cuda_context_; - Tensor* X_ = nullptr; - Tensor* Y_ = nullptr; -}; - -TEST_F(BroadcastGPUTest, BroadcastGPUFloatTest) { - if (!HasCudaGPU()) { - return; - } - RunBroadcastTest({2}, {2}, {1.0f, 2.0f}, {1.0f, 2.0f}); - RunBroadcastTest({1}, {2}, {1.0f}, {1.0f, 1.0f}); - RunBroadcastTest({1}, {2, 2}, {1.0f}, {1.0f, 1.0f, 1.0f, 1.0f}); - RunBroadcastTest({2, 1}, {2, 2}, {1.0f, 2.0f}, {1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest( - {2, 1}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f}); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc deleted file mode 100644 index 0389a10f29e0..000000000000 --- a/caffe2/utils/math_test.cc +++ /dev/null @@ -1,523 +0,0 @@ -#include -#include -#include - -#include - -#include "caffe2/core/blob.h" -#include "caffe2/core/context.h" -#include "caffe2/core/tensor.h" -#include 
"caffe2/proto/caffe2_pb.h" -#include "caffe2/utils/conversions.h" -#include "caffe2/utils/math.h" - -#include - -namespace caffe2 { - -TEST(MathTest, GemmNoTransNoTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor X(std::vector{5, 10}, CPU); - Tensor W(std::vector{10, 6}, CPU); - Tensor Y(std::vector{5, 6}, CPU); - EXPECT_EQ(X.numel(), 50); - EXPECT_EQ(W.numel(), 60); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - math::Set( - W.numel(), 1, W.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - for (int i = 0; i < W.numel(); ++i) { - TORCH_CHECK_EQ(W.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kZero, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasNoTrans, - 5, - 6, - 10, - kPointFive, - X.data(), - W.data(), - kOne, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -TEST(MathTest, GemmNoTransTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor X(std::vector{5, 10}, CPU); - Tensor W(std::vector{6, 10}, CPU); - Tensor Y(std::vector{5, 6}, CPU); - EXPECT_EQ(X.numel(), 50); - EXPECT_EQ(W.numel(), 60); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - math::Set( - W.numel(), 1, W.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - for (int i = 0; i < W.numel(); ++i) { - TORCH_CHECK_EQ(W.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kZero, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kOne, - X.data(), - W.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - math::Gemm( - CblasNoTrans, - CblasTrans, - 5, - 6, - 10, - kPointFive, - X.data(), - W.data(), - kOne, - Y.mutable_data(), - &cpu_context); - EXPECT_EQ(Y.numel(), 30); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -namespace { - -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class GemmBatchedTest - : public testing::TestWithParam> { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - ReinitializeTensor( - &X_, std::vector{3, 5, 10}, at::dtype().device(CPU)); - ReinitializeTensor( - &W_, std::vector{3, 6, 10}, at::dtype().device(CPU)); - ReinitializeTensor( - &Y_, std::vector{3, 5, 6}, at::dtype().device(CPU)); - math::Set( - X_.numel(), 1, X_.mutable_data(), 
cpu_context_.get()); - math::Set( - W_.numel(), 1, W_.mutable_data(), cpu_context_.get()); - trans_X_ = std::get<0>(GetParam()); - trans_W_ = std::get<1>(GetParam()); - } - - void RunGemmBatched(const float alpha, const float beta) { - const float* X_data = X_.template data(); - const float* W_data = W_.template data(); - float* Y_data = Y_.template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - std::array X_array = { - X_data, X_data + X_stride, X_data + 2 * X_stride}; - std::array W_array = { - W_data, W_data + W_stride, W_data + 2 * W_stride}; - std::array Y_array = { - Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; - math::GemmBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_array.data(), - W_array.data(), - beta, - Y_array.data(), - cpu_context_.get()); - } - - void RunGemmStridedBatched(const float alpha, const float beta) { - const float* X_data = X_.template data(); - const float* W_data = W_.template data(); - float* Y_data = Y_.template mutable_data(); - const int X_stride = 5 * 10; - const int W_stride = 6 * 10; - const int Y_stride = 5 * 6; - math::GemmStridedBatched( - trans_X_ ? CblasTrans : CblasNoTrans, - trans_W_ ? CblasTrans : CblasNoTrans, - 3, - 5, - 6, - 10, - alpha, - X_data, - X_stride, - W_data, - W_stride, - beta, - Y_data, - Y_stride, - cpu_context_.get()); - } - - void VerifyOutput(const float value) const { - for (int i = 0; i < Y_.numel(); ++i) { - EXPECT_FLOAT_EQ(value, Y_.template data()[i]); - } - } - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor W_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor Y_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - bool trans_X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - bool trans_W_; -}; - -TEST_P(GemmBatchedTest, GemmBatchedFloatTest) { - RunGemmBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -TEST_P(GemmBatchedTest, GemmStridedBatchedFloatTest) { - RunGemmStridedBatched(1.0f, 0.0f); - VerifyOutput(10.0f); - RunGemmStridedBatched(1.0f, 0.5f); - VerifyOutput(15.0f); - RunGemmStridedBatched(0.5f, 1.0f); - VerifyOutput(20.0f); -} - -INSTANTIATE_TEST_CASE_P( - GemmBatchedTrans, - GemmBatchedTest, - testing::Combine(testing::Bool(), testing::Bool())); - -} // namespace - -TEST(MathTest, GemvNoTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor A(std::vector{5, 10}, CPU); - Tensor X(std::vector{10}, CPU); - Tensor Y(std::vector{5}, CPU); - EXPECT_EQ(A.numel(), 50); - EXPECT_EQ(X.numel(), 10); - math::Set( - A.numel(), 1, A.mutable_data(), &cpu_context); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 5); - for (int i = 0; i < A.numel(); ++i) { - TORCH_CHECK_EQ(A.data()[i], 1); - } - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasNoTrans, - 
5, - 10, - kOne, - A.data(), - X.data(), - kZero, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 10) << i; - } - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kOne, - A.data(), - X.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 15) << i; - } - // Test Accumulate - math::Gemv( - CblasNoTrans, - 5, - 10, - kPointFive, - A.data(), - X.data(), - kOne, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 20) << i; - } -} - -TEST(MathTest, GemvTrans) { - DeviceOption option; - CPUContext cpu_context(option); - Tensor A(std::vector{6, 10}, CPU); - Tensor X(std::vector{6}, CPU); - Tensor Y(std::vector{10}, CPU); - EXPECT_EQ(A.numel(), 60); - EXPECT_EQ(X.numel(), 6); - math::Set( - A.numel(), 1, A.mutable_data(), &cpu_context); - math::Set( - X.numel(), 1, X.mutable_data(), &cpu_context); - EXPECT_EQ(Y.numel(), 10); - for (int i = 0; i < A.numel(); ++i) { - TORCH_CHECK_EQ(A.data()[i], 1); - } - for (int i = 0; i < X.numel(); ++i) { - TORCH_CHECK_EQ(X.data()[i], 1); - } - - const float kOne = 1.0; - const float kPointFive = 0.5; - const float kZero = 0.0; - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - A.data(), - X.data(), - kZero, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 6) << i; - } - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kOne, - A.data(), - X.data(), - kPointFive, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 9) << i; - } - // Test Accumulate - math::Gemv( - CblasTrans, - 6, - 10, - kPointFive, - A.data(), - X.data(), - kOne, - Y.mutable_data(), - &cpu_context); - for (int i = 0; i < Y.numel(); ++i) { - TORCH_CHECK_EQ(Y.data()[i], 12) << i; - } -} - -TEST(MathTest, FloatToHalfConversion) { - float a = 1.0f; - float b = 1.75f; - float c = 128.125f; - - float converted_a = static_cast(at::Half(a)); - float converted_b = static_cast(at::Half(b)); - float converted_c = static_cast(at::Half(c)); - - TORCH_CHECK_EQ(a, converted_a); - TORCH_CHECK_EQ(b, converted_b); - TORCH_CHECK_EQ(c, converted_c); -} - -namespace { - -class BroadcastTest : public testing::Test { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - } - - void RunBroadcastTest( - const std::vector& X_dims, - const std::vector& Y_dims, - const std::vector& X_data, - const std::vector& Y_data) { - std::vector X_dims_64; - std::vector Y_dims_64; - std::copy(X_dims.cbegin(), X_dims.cend(), std::back_inserter(X_dims_64)); - std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64)); - ReinitializeTensor(&X_, X_dims_64, at::dtype().device(CPU)); - ReinitializeTensor(&Y_, Y_dims_64, at::dtype().device(CPU)); - ASSERT_EQ(X_data.size(), X_.numel()); - cpu_context_->CopyFromCPU( - X_data.size(), X_data.data(), X_.mutable_data()); - for (bool allow_broadcast_fastpath : {false, true}) { - math::Broadcast( - X_dims.size(), - X_dims.data(), - Y_dims.size(), - Y_dims.data(), - 1.0f, - X_.data(), - Y_.mutable_data(), - cpu_context_.get(), - allow_broadcast_fastpath); - ASSERT_EQ(Y_data.size(), Y_.numel()); - for (const auto i : c10::irange(Y_data.size())) { - EXPECT_FLOAT_EQ(Y_data[i], Y_.data()[i]); - } - } - } - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // 
NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; - - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor X_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - Tensor Y_; -}; - -TEST_F(BroadcastTest, BroadcastFloatTest) { - RunBroadcastTest({2}, {2}, {1.0f, 2.0f}, {1.0f, 2.0f}); - RunBroadcastTest({1}, {2}, {1.0f}, {1.0f, 1.0f}); - RunBroadcastTest({1}, {2, 2}, {1.0f}, {1.0f, 1.0f, 1.0f, 1.0f}); - RunBroadcastTest({2, 1}, {2, 2}, {1.0f, 2.0f}, {1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest({1, 2}, {2, 2}, {1.0f, 2.0f}, {1.0f, 2.0f, 1.0f, 2.0f}); - RunBroadcastTest( - {2, 1}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 1.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 2.0f}); - RunBroadcastTest( - {1, 2}, - {2, 2, 2}, - {1.0f, 2.0f}, - {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}); -} - -class RandFixedSumTest : public testing::Test { - protected: - void SetUp() override { - cpu_context_ = make_unique(option_); - } - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - DeviceOption option_; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unique_ptr cpu_context_; -}; - -TEST_F(RandFixedSumTest, UpperBound) { - std::vector l(20); - math::RandFixedSum( - 20, 1, 1000, 1000, l.data(), cpu_context_.get()); -} - -} // namespace - -} // namespace caffe2 diff --git a/caffe2/utils/murmur_hash3.cc b/caffe2/utils/murmur_hash3.cc deleted file mode 100644 index 68cce1fdd34e..000000000000 --- a/caffe2/utils/murmur_hash3.cc +++ /dev/null @@ -1,450 +0,0 @@ -//----------------------------------------------------------------------------- -// MurmurHash3 was written by Austin Appleby, and is placed in the public -// domain. The author hereby disclaims copyright to this source code. - -// Note - The x86 and x64 versions do _not_ produce the same results, as the -// algorithms are optimized for their respective platforms. You can still -// compile and run any of them on any platform, but your performance with the -// non-native version will be less than optimal. 
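Note: the 32-bit variant implemented below is the simplest to exercise. A hedged usage sketch against the deleted caffe2/utils/murmur_hash3.h declaration (the key string and seed value are arbitrary examples):

#include <cstdint>
#include <cstring>
#include <iostream>

#include "caffe2/utils/murmur_hash3.h"

int main() {
  const char key[] = "caffe2";
  const uint32_t seed = 0x9747b28c;  // any fixed seed; same inputs give the same hash
  uint32_t hash = 0;
  caffe2::MurmurHash3_x86_32(key, static_cast<int>(std::strlen(key)), seed, &hash);
  std::cout << std::hex << hash << std::endl;
  return 0;
}

The 128-bit variants work the same way but write 16 bytes to the out pointer; as the header comment notes, the x86 and x64 versions intentionally produce different results.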
- -#include "caffe2/utils/murmur_hash3.h" - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) - -#define FORCE_INLINE __forceinline - -#include - -#define ROTL32(x, y) _rotl(x, y) -#define ROTL64(x, y) _rotl64(x, y) - -#define BIG_CONSTANT(x) (x) - -// Other compilers - -#else // defined(_MSC_VER) - -#define FORCE_INLINE inline __attribute__((__always_inline__)) - -inline uint32_t rotl32(uint32_t x, int8_t r) { - return (x << r) | (x >> (32 - r)); -} - -inline uint64_t rotl64(uint64_t x, int8_t r) { - return (x << r) | (x >> (64 - r)); -} - -#define ROTL32(x, y) rotl32(x, y) -#define ROTL64(x, y) rotl64(x, y) - -#define BIG_CONSTANT(x) (x##LLU) - -#endif // !defined(_MSC_VER) - -//----------------------------------------------------------------------------- -// Block read - if your platform needs to do endian-swapping or can only -// handle aligned reads, do the conversion here - -FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { - return p[i]; -} - -FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { - return p[i]; -} - -//----------------------------------------------------------------------------- -// Finalization mix - force all bits of a hash block to avalanche - -FORCE_INLINE uint32_t fmix32(uint32_t h) { - h ^= h >> 16; - h *= 0x85ebca6b; - h ^= h >> 13; - h *= 0xc2b2ae35; - h ^= h >> 16; - - return h; -} - -//---------- - -FORCE_INLINE uint64_t fmix64(uint64_t k) { - k ^= k >> 33; - k *= BIG_CONSTANT(0xff51afd7ed558ccd); - k ^= k >> 33; - k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); - k ^= k >> 33; - - return k; -} - -namespace caffe2 { - -void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 4; - - uint32_t h1 = seed; - - const uint32_t c1 = 0xcc9e2d51; - const uint32_t c2 = 0x1b873593; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - - h1 ^= k1; - h1 = ROTL32(h1, 13); - h1 = h1 * 5 + 0xe6546b64; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 4); - - uint32_t k1 = 0; - - switch (len & 3) { - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0]; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - - h1 = fmix32(h1); - - *(uint32_t*)out = h1; -} - -//----------------------------------------------------------------------------- - -void MurmurHash3_x86_128( - const void* key, - const int len, - uint32_t seed, - void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint32_t h1 = seed; - uint32_t h2 = seed; - uint32_t h3 = seed; - uint32_t h4 = seed; - - const uint32_t c1 = 0x239b961b; - const uint32_t c2 = 0xab0e9789; - const uint32_t c3 = 0x38b34ae5; - const uint32_t c4 = 0xa1e38b93; - - //---------- - // body - - const uint32_t* blocks = (const uint32_t*)(data + nblocks * 16); - - for (int i = -nblocks; i; i++) { - uint32_t k1 = getblock32(blocks, i * 4 + 0); - uint32_t k2 = getblock32(blocks, i * 4 + 1); - uint32_t k3 = getblock32(blocks, i * 4 + 2); - uint32_t k4 = getblock32(blocks, i * 4 + 3); - - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - - 
h1 = ROTL32(h1, 19); - h1 += h2; - h1 = h1 * 5 + 0x561ccd1b; - - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - - h2 = ROTL32(h2, 17); - h2 += h3; - h2 = h2 * 5 + 0x0bcaa747; - - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - - h3 = ROTL32(h3, 15); - h3 += h4; - h3 = h3 * 5 + 0x96cd1c35; - - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - - h4 = ROTL32(h4, 13); - h4 += h1; - h4 = h4 * 5 + 0x32ac3b17; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 16); - - uint32_t k1 = 0; - uint32_t k2 = 0; - uint32_t k3 = 0; - uint32_t k4 = 0; - - switch (len & 15) { - case 15: - k4 ^= tail[14] << 16; - [[fallthrough]]; - case 14: - k4 ^= tail[13] << 8; - [[fallthrough]]; - case 13: - k4 ^= tail[12] << 0; - k4 *= c4; - k4 = ROTL32(k4, 18); - k4 *= c1; - h4 ^= k4; - [[fallthrough]]; - - case 12: - k3 ^= tail[11] << 24; - [[fallthrough]]; - case 11: - k3 ^= tail[10] << 16; - [[fallthrough]]; - case 10: - k3 ^= tail[9] << 8; - [[fallthrough]]; - case 9: - k3 ^= tail[8] << 0; - k3 *= c3; - k3 = ROTL32(k3, 17); - k3 *= c4; - h3 ^= k3; - [[fallthrough]]; - - case 8: - k2 ^= tail[7] << 24; - [[fallthrough]]; - case 7: - k2 ^= tail[6] << 16; - [[fallthrough]]; - case 6: - k2 ^= tail[5] << 8; - [[fallthrough]]; - case 5: - k2 ^= tail[4] << 0; - k2 *= c2; - k2 = ROTL32(k2, 16); - k2 *= c3; - h2 ^= k2; - [[fallthrough]]; - - case 4: - k1 ^= tail[3] << 24; - [[fallthrough]]; - case 3: - k1 ^= tail[2] << 16; - [[fallthrough]]; - case 2: - k1 ^= tail[1] << 8; - [[fallthrough]]; - case 1: - k1 ^= tail[0] << 0; - k1 *= c1; - k1 = ROTL32(k1, 15); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - h2 ^= len; - h3 ^= len; - h4 ^= len; - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - h1 = fmix32(h1); - h2 = fmix32(h2); - h3 = fmix32(h3); - h4 = fmix32(h4); - - h1 += h2; - h1 += h3; - h1 += h4; - h2 += h1; - h3 += h1; - h4 += h1; - - ((uint32_t*)out)[0] = h1; - ((uint32_t*)out)[1] = h2; - ((uint32_t*)out)[2] = h3; - ((uint32_t*)out)[3] = h4; -} - -//----------------------------------------------------------------------------- - -void MurmurHash3_x64_128( - const void* key, - const int len, - const uint32_t seed, - void* out) { - const uint8_t* data = (const uint8_t*)key; - const int nblocks = len / 16; - - uint64_t h1 = seed; - uint64_t h2 = seed; - - const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); - const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); - - //---------- - // body - - const uint64_t* blocks = (const uint64_t*)(data); - - for (int i = 0; i < nblocks; i++) { - uint64_t k1 = getblock64(blocks, i * 2 + 0); - uint64_t k2 = getblock64(blocks, i * 2 + 1); - - k1 *= c1; - k1 = ROTL64(k1, 31); - k1 *= c2; - h1 ^= k1; - - h1 = ROTL64(h1, 27); - h1 += h2; - h1 = h1 * 5 + 0x52dce729; - - k2 *= c2; - k2 = ROTL64(k2, 33); - k2 *= c1; - h2 ^= k2; - - h2 = ROTL64(h2, 31); - h2 += h1; - h2 = h2 * 5 + 0x38495ab5; - } - - //---------- - // tail - - const uint8_t* tail = (const uint8_t*)(data + nblocks * 16); - - uint64_t k1 = 0; - uint64_t k2 = 0; - - switch (len & 15) { - case 15: - k2 ^= ((uint64_t)tail[14]) << 48; - [[fallthrough]]; - case 14: - k2 ^= ((uint64_t)tail[13]) << 40; - [[fallthrough]]; - case 13: - k2 ^= ((uint64_t)tail[12]) << 32; - [[fallthrough]]; - case 12: - k2 ^= ((uint64_t)tail[11]) << 24; - [[fallthrough]]; - case 11: - k2 ^= ((uint64_t)tail[10]) << 16; - [[fallthrough]]; - case 10: - k2 ^= ((uint64_t)tail[9]) << 8; - [[fallthrough]]; - case 9: - k2 ^= 
((uint64_t)tail[8]) << 0; - k2 *= c2; - k2 = ROTL64(k2, 33); - k2 *= c1; - h2 ^= k2; - [[fallthrough]]; - - case 8: - k1 ^= ((uint64_t)tail[7]) << 56; - [[fallthrough]]; - case 7: - k1 ^= ((uint64_t)tail[6]) << 48; - [[fallthrough]]; - case 6: - k1 ^= ((uint64_t)tail[5]) << 40; - [[fallthrough]]; - case 5: - k1 ^= ((uint64_t)tail[4]) << 32; - [[fallthrough]]; - case 4: - k1 ^= ((uint64_t)tail[3]) << 24; - [[fallthrough]]; - case 3: - k1 ^= ((uint64_t)tail[2]) << 16; - [[fallthrough]]; - case 2: - k1 ^= ((uint64_t)tail[1]) << 8; - [[fallthrough]]; - case 1: - k1 ^= ((uint64_t)tail[0]) << 0; - k1 *= c1; - k1 = ROTL64(k1, 31); - k1 *= c2; - h1 ^= k1; - }; - - //---------- - // finalization - - h1 ^= len; - h2 ^= len; - - h1 += h2; - h2 += h1; - - h1 = fmix64(h1); - h2 = fmix64(h2); - - h1 += h2; - h2 += h1; - - ((uint64_t*)out)[0] = h1; - ((uint64_t*)out)[1] = h2; -} - -} // namespace caffe2 diff --git a/caffe2/utils/murmur_hash3.h b/caffe2/utils/murmur_hash3.h deleted file mode 100644 index ea67e7151c0b..000000000000 --- a/caffe2/utils/murmur_hash3.h +++ /dev/null @@ -1,34 +0,0 @@ -//----------------------------------------------------------------------------- -// MurmurHash3 was written by Austin Appleby, and is placed in the public -// domain. The author hereby disclaims copyright to this source code. - -#pragma once - -//----------------------------------------------------------------------------- -// Platform-specific functions and macros - -// Microsoft Visual Studio - -#if defined(_MSC_VER) && (_MSC_VER < 1600) - -typedef unsigned char uint8_t; -typedef unsigned int uint32_t; -typedef unsigned __int64 uint64_t; - -// Other compilers - -#else // defined(_MSC_VER) - -#include - -#endif // !defined(_MSC_VER) - -namespace caffe2 { - -void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out); - -void MurmurHash3_x86_128(const void* key, int len, uint32_t seed, void* out); - -void MurmurHash3_x64_128(const void* key, int len, uint32_t seed, void* out); - -} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc deleted file mode 100644 index 8fc81586f3ca..000000000000 --- a/caffe2/utils/proto_utils.cc +++ /dev/null @@ -1,715 +0,0 @@ -#include "caffe2/utils/proto_utils.h" - -#include - -#include -#include -#include -#include - -#if defined(_MSC_VER) -#include -#else -#include -#endif - -#include - -#ifndef CAFFE2_USE_LITE_PROTO -#include -#include -#else -#include -#endif // !CAFFE2_USE_LITE_PROTO - -#include - -using ::google::protobuf::MessageLite; - -namespace caffe2 { - -C10_EXPORT std::string DeviceTypeName(const int32_t& d) { - return at::DeviceTypeName(static_cast(d)); -} - -void setTotalBytesLimit(::google::protobuf::io::CodedInputStream& stream, int bytes_limit, int warning_threshold) { - #if GOOGLE_PROTOBUF_VERSION >= 3011000 - // Only take one parameter since protobuf 3.11 - stream.SetTotalBytesLimit(bytes_limit); - #else - stream.SetTotalBytesLimit(bytes_limit, warning_threshold); - #endif -} - -C10_EXPORT int DeviceId(const DeviceOption& option) { - switch (option.device_type()) { - case PROTO_CPU: - return option.numa_node_id(); - case PROTO_CUDA: - case PROTO_HIP: - return option.device_id(); - case PROTO_MKLDNN: - return option.numa_node_id(); - default: - CAFFE_THROW("Unknown device id for device type: ", option.device_type()); - } -} - -C10_EXPORT bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) { - return ( - lhs.device_type() == rhs.device_type() && - lhs.device_id() == rhs.device_id() && - 
lhs.node_name() == rhs.node_name() && - lhs.numa_node_id() == rhs.numa_node_id()); -} - -C10_EXPORT bool IsCPUDeviceType(int device_type) { - static const std::unordered_set cpu_types{ - PROTO_CPU, - PROTO_MKLDNN, - PROTO_IDEEP, - }; - return cpu_types.count(device_type); -} - -C10_EXPORT bool IsGPUDeviceType(int device_type) { - static const std::unordered_set gpu_types{ - PROTO_CUDA, - PROTO_HIP, - }; - return gpu_types.count(device_type); -} - -C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) { - std::ifstream ifs(filename, std::ios::in); - if (!ifs) { - VLOG(1) << "File cannot be opened: " << filename - << " error: " << ifs.rdstate(); - return false; - } - ifs.seekg(0, std::ios::end); - size_t n = ifs.tellg(); - str->resize(n); - ifs.seekg(0); - ifs.read(&(*str)[0], n); - return true; -} - -C10_EXPORT bool WriteStringToFile(const string& str, const char* filename) { - std::ofstream ofs(filename, std::ios::out | std::ios::trunc); - if (!ofs.is_open()) { - VLOG(1) << "File cannot be created: " << filename - << " error: " << ofs.rdstate(); - return false; - } - ofs << str; - return true; -} - -// IO-specific proto functions: we will deal with the protocol buffer lite and -// full versions differently. - -#ifdef CAFFE2_USE_LITE_PROTO - -// Lite runtime. - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const string& filename) - : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { - ifs_.close(); - } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -C10_EXPORT string ProtoDebugString(const MessageLite& proto) { - string serialized = proto.SerializeAsString(); - for (char& c : serialized) { - if (c < 0x20 || c >= 0x7f) { - c = '?'; - } - } - return serialized; -} - -C10_EXPORT bool ParseProtoFromLargeString( - const string& str, - MessageLite* proto) { - ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size()); - ::google::protobuf::io::CodedInputStream coded_stream(&input_stream); - // Set PlanDef message size limit to 2G. - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT bool ReadProtoFromBinaryFile( - const char* filename, - MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(filename)); - stream.SetOwnsCopyingStream(true); - // Total bytes hard limit / warning limit are set to 2GB and 512MB - // respectively. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT void WriteProtoToBinaryFile( - const MessageLite& /*proto*/, - const char* /*filename*/) { - LOG(FATAL) << "Not implemented yet."; -} - -#else // CAFFE2_USE_LITE_PROTO - -// Full protocol buffer. 
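An aside on the helpers being removed above: setTotalBytesLimit() exists only to paper over a protobuf API change, since CodedInputStream::SetTotalBytesLimit dropped its warning-threshold parameter in protobuf 3.11. A rough standalone sketch of the same pattern against the public protobuf API follows; the SetLargeBytesLimit/ParseLargeBinaryFile names are invented for illustration and are not part of this patch.

    #include <fstream>
    #include <iterator>
    #include <string>

    #include <google/protobuf/io/coded_stream.h>
    #include <google/protobuf/io/zero_copy_stream_impl_lite.h>

    // Raise the parse limit to 2 GB before parsing a large serialized message,
    // handling the protobuf 3.11 signature change the same way the removed helper did.
    inline void SetLargeBytesLimit(google::protobuf::io::CodedInputStream& stream) {
    #if GOOGLE_PROTOBUF_VERSION >= 3011000
      stream.SetTotalBytesLimit(2147483647);             // single-argument form since 3.11
    #else
      stream.SetTotalBytesLimit(2147483647, 512 << 20);  // hard limit 2 GB, warn at 512 MB
    #endif
    }

    template <typename Message>
    bool ParseLargeBinaryFile(const char* filename, Message* proto) {
      std::ifstream ifs(filename, std::ios::in | std::ios::binary);
      if (!ifs) {
        return false;
      }
      const std::string data((std::istreambuf_iterator<char>(ifs)),
                             std::istreambuf_iterator<char>());
      google::protobuf::io::ArrayInputStream array(data.data(), static_cast<int>(data.size()));
      google::protobuf::io::CodedInputStream coded(&array);
      SetLargeBytesLimit(coded);
      return proto->ParseFromCodedStream(&coded);
    }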
- -using ::google::protobuf::Message; -using ::google::protobuf::io::CodedInputStream; -using ::google::protobuf::io::CodedOutputStream; -using ::google::protobuf::io::FileInputStream; -using ::google::protobuf::io::FileOutputStream; -using ::google::protobuf::io::ZeroCopyInputStream; -using ::google::protobuf::io::ZeroCopyOutputStream; - -namespace TextFormat { -C10_EXPORT bool ParseFromString(const string& spec, Message* proto) { - string bc_spec = spec; - - { - auto num_replaced = c10::ReplaceAll(bc_spec, "cuda_gpu_id", "device_id"); - if (num_replaced) { - LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and " - << "it has " << num_replaced - << " places using the deprecated field name 'cuda_gpu_id'!\n" - << spec - << "\nPlease re-export your model in Protobuf binary format " - << "to make it backward compatible for field renaming."; - } - } - - return ::google::protobuf::TextFormat::ParseFromString( - // NOLINTNEXTLINE(performance-move-const-arg) - std::move(bc_spec), proto); -} -} // namespace TextFormat - -C10_EXPORT string ProtoDebugString(const Message& proto) { - return proto.ShortDebugString(); -} - -C10_EXPORT bool ParseProtoFromLargeString(const string& str, Message* proto) { - ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size()); - ::google::protobuf::io::CodedInputStream coded_stream(&input_stream); - // Set PlanDef message size limit to 2G. - setTotalBytesLimit(coded_stream, 2147483647, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) { - int fd = open(filename, O_RDONLY); - CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename); - FileInputStream* input = new FileInputStream(fd); - bool success = google::protobuf::TextFormat::Parse(input, proto); - delete input; - close(fd); - return success; -} - -C10_EXPORT void WriteProtoToTextFile( - const Message& proto, - const char* filename, - bool throwIfError) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - FileOutputStream* output = new FileOutputStream(fd); - if(!google::protobuf::TextFormat::Print(proto, output)) { - if (throwIfError) { - CAFFE_THROW("Cannot write proto to text file: ", filename); - } else { - LOG(ERROR) << "Cannot write proto to text file: " << filename; - } - } - delete output; - close(fd); -} - -C10_EXPORT bool ReadProtoFromBinaryFile( - const char* filename, - MessageLite* proto) { -#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified - int fd = open(filename, O_RDONLY | O_BINARY); -#else - int fd = open(filename, O_RDONLY); -#endif - CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input( - new CodedInputStream(raw_input.get())); - // A hack to manually allow using very large protocol buffers. - #if GOOGLE_PROTOBUF_VERSION >= 3011000 - // Only take one parameter since protobuf 3.11 - coded_input->SetTotalBytesLimit(2147483647); - #else - // Total bytes hard limit / warning limit are set to 2GB and 512MB respectively. 
- coded_input->SetTotalBytesLimit(2147483647, 536870912); - #endif - bool success = proto->ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - close(fd); - return success; -} - -C10_EXPORT void WriteProtoToBinaryFile( - const MessageLite& proto, - const char* filename) { - int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); - CAFFE_ENFORCE_NE( - fd, -1, "File cannot be created: ", filename, " error number: ", errno); - std::unique_ptr raw_output(new FileOutputStream(fd)); - std::unique_ptr coded_output( - new CodedOutputStream(raw_output.get())); - CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get())); - coded_output.reset(); - raw_output.reset(); - close(fd); -} - -#endif // CAFFE2_USE_LITE_PROTO - -C10_EXPORT ArgumentHelper::ArgumentHelper(const OperatorDef& def) { - for (auto& arg : def.arg()) { - if (arg_map_.count(arg.name())) { - if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) { - // If there are two arguments of the same name but different contents, - // we will throw an error. - CAFFE_THROW( - "Found argument of the same name ", - arg.name(), - "but with different contents.", - ProtoDebugString(def)); - } else { - LOG(WARNING) << "Duplicated argument name [" << arg.name() - << "] found in operator def: " << ProtoDebugString(def); - } - } - arg_map_[arg.name()] = arg; - } -} - -C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) { - for (auto& arg : netdef.arg()) { - CAFFE_ENFORCE( - arg_map_.count(arg.name()) == 0, - "Duplicated argument name [", - arg.name(), - "] found in net def: ", - ProtoDebugString(netdef)); - arg_map_[arg.name()] = arg; - } -} - -C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const { -#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP - return arg_map_.count(name); -#else - return arg_map_.count(std::string(name)); -#endif -} - -namespace { -// Helper function to verify that conversion between types won't loose any -// significant bit. 
-template -bool SupportsLosslessConversion(const InputType& value) { - return static_cast(static_cast(value)) == value; -} -} // namespace -bool operator==(const TensorProto& l, const TensorProto& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const TensorProto& n) { - output << n.SerializeAsString(); - return output; -} -bool operator==(const QTensorProto& l, const QTensorProto& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const QTensorProto& n) { - output << n.SerializeAsString(); - return output; -} -bool operator==(const NetDef& l, const NetDef& r) { - return l.SerializeAsString() == r.SerializeAsString(); -} - -std::ostream& operator<<(std::ostream& output, const NetDef& n) { - output << n.SerializeAsString(); - return output; -} - -#define INSTANTIATE_GET_SINGLE_ARGUMENT( \ - T, fieldname, enforce_lossless_conversion) \ - template <> \ - C10_EXPORT T ArgumentHelper::GetSingleArgument( \ - c10::string_view name, const T& default_value) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - VLOG(1) << "Using default parameter value " << default_value \ - << " for parameter " << name; \ - return default_value; \ - } \ - CAFFE_ENFORCE( \ - it->second.has_##fieldname(), \ - "Argument ", \ - name, \ - " does not have the right field: expected field " #fieldname); \ - const auto& value = it->second.fieldname(); \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(value); \ - CAFFE_ENFORCE( \ - supportsConversion, \ - "Value", \ - value, \ - " of argument ", \ - name, \ - "cannot be represented correctly in a target type"); \ - } \ - return static_cast(value); \ - } \ - template <> \ - C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType( \ - c10::string_view name) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - return false; \ - } \ - return it->second.has_##fieldname(); \ - } - -INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true) -INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false) -INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false) -#undef INSTANTIATE_GET_SINGLE_ARGUMENT - -#define INSTANTIATE_GET_REPEATED_ARGUMENT( \ - T, fieldname, enforce_lossless_conversion) \ - template <> \ - C10_EXPORT std::vector ArgumentHelper::GetRepeatedArgument( \ - c10::string_view name, const std::vector& default_value) const { \ - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ - if (it == arg_map_.end()) { \ - return default_value; \ - } \ - std::vector values; \ - for (const auto& v : it->second.fieldname()) { \ - if (enforce_lossless_conversion) { \ - auto supportsConversion = \ - SupportsLosslessConversion(v); \ - CAFFE_ENFORCE( \ - supportsConversion, \ - "Value", \ - v, \ - " of argument ", \ - name, \ - "cannot be represented correctly in a target type"); \ - } \ - values.push_back(static_cast(v)); \ - } \ - return values; \ - } - -INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) 
-INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true) -INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(TensorProto, tensors, false) -INSTANTIATE_GET_REPEATED_ARGUMENT(QTensorProto, qtensors, false) -#undef INSTANTIATE_GET_REPEATED_ARGUMENT - -#define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ - template <> \ - C10_EXPORT Argument MakeArgument(const string& name, const T& value) { \ - Argument arg; \ - arg.set_name(name); \ - arg.set_##fieldname(value); \ - return arg; \ - } - -CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int16_t, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i) -CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s) -#undef CAFFE2_MAKE_SINGULAR_ARGUMENT - -template <> -C10_EXPORT Argument MakeArgument(const string& name, const NetDef& value) { - Argument arg; - arg.set_name(name); - *arg.mutable_n() = value; - return arg; -} - -template <> -C10_EXPORT bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index); -template <> -bool ArgumentHelper::RemoveArgument(NetDef& def, int index); - -template <> -C10_EXPORT Argument MakeArgument(const string& name, const MessageLite& value) { - Argument arg; - arg.set_name(name); - arg.set_s(value.SerializeAsString()); - return arg; -} - -#define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \ - template <> \ - C10_EXPORT Argument MakeArgument( \ - const string& name, const std::vector& value) { \ - Argument arg; \ - arg.set_name(name); \ - for (const auto& v : value) { \ - arg.add_##fieldname(v); \ - } \ - return arg; \ - } - -CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats) -CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints) -CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints) -CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings) -#undef CAFFE2_MAKE_REPEATED_ARGUMENT - -C10_EXPORT bool HasOutput(const OperatorDef& op, const std::string& output) { - for (const auto& outp : op.output()) { - if (outp == output) { - return true; - } - } - return false; -} - -C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) { - for (const auto& inp : op.input()) { - if (inp == input) { - return true; - } - } - return false; -} - -// Return the argument index or -1 if it does not exist. 
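The integer instantiations above pass enforce_lossless_conversion = true, which funnels each value through the SupportsLosslessConversion round-trip check before narrowing it. A minimal standalone sketch of that check (FitsLosslessly is an invented name and the sample values are arbitrary):

    #include <cstdint>
    #include <iostream>

    // Round-trip cast check: a value converts losslessly to TargetType exactly when
    // casting there and back reproduces the original value.
    template <typename TargetType, typename InputType>
    bool FitsLosslessly(const InputType& value) {
      return static_cast<InputType>(static_cast<TargetType>(value)) == value;
    }

    int main() {
      std::cout << FitsLosslessly<int8_t>(int64_t{127}) << "\n";   // 1: representable
      std::cout << FitsLosslessly<int8_t>(int64_t{300}) << "\n";   // 0: out of range for int8_t
      std::cout << FitsLosslessly<uint16_t>(int64_t{-1}) << "\n";  // 0: negative value
    }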
-C10_EXPORT int GetArgumentIndex( - const google::protobuf::RepeatedPtrField& args, - c10::string_view name) { - int index = 0; - for (const Argument& arg : args) { - if (arg.name() == name) { - return index; - } - index++; - } - return -1; -} - -C10_EXPORT const Argument& GetArgument( - const OperatorDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return def.arg(index); - } else { - CAFFE_THROW( - "Argument named ", - name, - " does not exist in operator ", - ProtoDebugString(def)); - } -} - -C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return def.arg(index); - } else { - CAFFE_THROW( - "Argument named ", - name, - " does not exist in net ", - ProtoDebugString(def)); - } -} - -C10_EXPORT const Argument* GetArgumentPtr( - const OperatorDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return &def.arg(index); - } else { - return nullptr; - } -} - -C10_EXPORT const Argument* GetArgumentPtr( - const NetDef& def, - c10::string_view name) { - int index = GetArgumentIndex(def.arg(), name); - if (index != -1) { - return &def.arg(index); - } else { - return nullptr; - } -} - -C10_EXPORT bool GetFlagArgument( - const google::protobuf::RepeatedPtrField& args, - c10::string_view name, - bool default_value) { - int index = GetArgumentIndex(args, name); - if (index != -1) { - // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) - auto arg = args.Get(index); - CAFFE_ENFORCE( - arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg)); - return arg.i(); - } - return default_value; -} - -C10_EXPORT bool GetFlagArgument( - const OperatorDef& def, - c10::string_view name, - bool default_value) { - return GetFlagArgument(def.arg(), name, default_value); -} - -C10_EXPORT bool -GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) { - return GetFlagArgument(def.arg(), name, default_value); -} - -template -Argument* GetMutableArgumentImpl( - const string& name, - const bool create_if_missing, - Def* def) { - for (int i = 0; i < def->arg_size(); ++i) { - if (def->arg(i).name() == name) { - return def->mutable_arg(i); - } - } - // If no argument of the right name is found... - if (create_if_missing) { - Argument* arg = def->add_arg(); - arg->set_name(name); - return arg; - } else { - return nullptr; - } -} - -C10_EXPORT Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - OperatorDef* def) { - return GetMutableArgumentImpl(name, create_if_missing, def); -} - -C10_EXPORT Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - NetDef* def) { - return GetMutableArgumentImpl(name, create_if_missing, def); -} - -C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) { - std::vector oldExternalInputs; - for (const auto& input : net->external_input()) { - oldExternalInputs.emplace_back(input); - } - std::vector oldExternalOutputs; - for (const auto& output : net->external_output()) { - oldExternalOutputs.emplace_back(output); - } - - net->clear_external_input(); - net->clear_external_output(); - - std::set inputSet; - for (const auto& input : oldExternalInputs) { - if (inputSet.count(input)) { - // Prevent duplicate external inputs. 
- continue; - } - inputSet.insert(input); - net->add_external_input(input); - } - - // Set of blobs that are external inputs or outputs of some operators. - std::set allOutputs(inputSet.begin(), inputSet.end()); - for (const auto& op : net->op()) { - for (const auto& input : op.input()) { - if (inputSet.count(input) || allOutputs.count(input)) { - continue; - } - // Add missing external inputs. - inputSet.insert(input); - net->add_external_input(input); - } - for (const auto& output : op.output()) { - allOutputs.insert(output); - } - } - - std::set outputSet; - for (const auto& output : oldExternalOutputs) { - if (!allOutputs.count(output)) { - continue; - } - if (outputSet.count(output)) { - continue; - } - outputSet.insert(output); - net->add_external_output(output); - } -} - -} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h deleted file mode 100644 index a6903425ab4e..000000000000 --- a/caffe2/utils/proto_utils.h +++ /dev/null @@ -1,383 +0,0 @@ -#ifndef CAFFE2_UTILS_PROTO_UTILS_H_ -#define CAFFE2_UTILS_PROTO_UTILS_H_ - -#ifdef CAFFE2_USE_LITE_PROTO -#include -#else // CAFFE2_USE_LITE_PROTO -#include -#endif // !CAFFE2_USE_LITE_PROTO - -#include -#include -#include - -#include "caffe2/utils/proto_wrap.h" -#include "caffe2/proto/caffe2_pb.h" - -#ifndef C10_ANDROID -#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP -#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key) -#else -#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key)) -#endif - -namespace caffe2 { - -using std::string; -using ::google::protobuf::MessageLite; - -// A wrapper function to return device name string for use in blob serialization -// / deserialization. This should have one to one correspondence with -// caffe2/proto/caffe2.proto: enum DeviceType. -// -// Note that we can't use DeviceType_Name, because that is only available in -// protobuf-full, and some platforms (like mobile) may want to use -// protobuf-lite instead. -TORCH_API std::string DeviceTypeName(const int32_t& d); - -TORCH_API int DeviceId(const DeviceOption& option); - -// Returns if the two DeviceOptions are pointing to the same device. -TORCH_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs); - -TORCH_API bool IsCPUDeviceType(int device_type); -TORCH_API bool IsGPUDeviceType(int device_type); - -// Common interfaces that reads file contents into a string. -TORCH_API bool ReadStringFromFile(const char* filename, string* str); -TORCH_API bool WriteStringToFile(const string& str, const char* filename); - -// Common interfaces that are supported by both lite and full protobuf. 
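The CAFFE2_ARG_MAP_FIND macro above lets the non-Android build probe the argument map with a string_view without materializing a std::string, which presumably relies on the map using a transparent comparator such as std::less<>. A small self-contained illustration of that heterogeneous-lookup idiom (keys and values made up for the example):

    #include <iostream>
    #include <map>
    #include <string>
    #include <string_view>

    int main() {
      // std::less<> is a transparent comparator, so find() accepts any type that
      // compares against std::string -- including std::string_view -- without
      // constructing a temporary key.
      std::map<std::string, int, std::less<>> args{{"kernel", 3}, {"stride", 1}};

      std::string_view key = "stride";
      auto it = args.find(key);  // heterogeneous lookup, no std::string allocated
      if (it != args.end()) {
        std::cout << it->first << " = " << it->second << "\n";
      }
    }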
-TORCH_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); -inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) { - return ReadProtoFromBinaryFile(filename.c_str(), proto); -} - -TORCH_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); -inline void WriteProtoToBinaryFile(const MessageLite& proto, - const string& filename) { - return WriteProtoToBinaryFile(proto, filename.c_str()); -} - -#ifdef CAFFE2_USE_LITE_PROTO - -namespace TextFormat { -inline bool ParseFromString(const string& spec, MessageLite* proto) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; -} -} // namespace TextFormat - - -TORCH_API string ProtoDebugString(const MessageLite& proto); - -TORCH_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto); - -// Text format MessageLite wrappers: these functions do nothing but just -// allowing things to compile. It will produce a runtime error if you are using -// MessageLite but still want text support. -inline bool ReadProtoFromTextFile( - const char* /*filename*/, - MessageLite* /*proto*/) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; - return false; // Just to suppress compiler warning. -} -inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -inline void WriteProtoToTextFile( - const MessageLite& /*proto*/, - const char* /*filename*/, - bool throwIfError = true) { - LOG(FATAL) << "If you are running lite version, you should not be " - << "calling any text-format protobuffers."; -} -inline void WriteProtoToTextFile(const MessageLite& proto, - const string& filename, - bool throwIfError = true) { - return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); -} - -inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#else // CAFFE2_USE_LITE_PROTO - -using ::google::protobuf::Message; - -namespace TextFormat { -TORCH_API bool ParseFromString(const string& spec, Message* proto); -} // namespace TextFormat - -TORCH_API string ProtoDebugString(const Message& proto); - -TORCH_API bool ParseProtoFromLargeString(const string& str, Message* proto); - -TORCH_API bool ReadProtoFromTextFile(const char* filename, Message* proto); -inline bool ReadProtoFromTextFile(const string filename, Message* proto) { - return ReadProtoFromTextFile(filename.c_str(), proto); -} - -TORCH_API void WriteProtoToTextFile(const Message& proto, const char* filename, bool throwIfError = true); -inline void WriteProtoToTextFile(const Message& proto, const string& filename, bool throwIfError = true) { - return WriteProtoToTextFile(proto, filename.c_str(), throwIfError); -} - -// Read Proto from a file, letting the code figure out if it is text or binary. 
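The comment above describes a binary-then-text fallback. Outside caffe2, the same behaviour can be approximated directly with the public protobuf API (full runtime only, since TextFormat is not available in protobuf-lite); ReadProtoAuto below is an invented name and is not part of this patch.

    #include <fstream>
    #include <sstream>
    #include <string>

    #include <google/protobuf/message.h>
    #include <google/protobuf/text_format.h>

    // Try the binary wire format first; if that fails, fall back to text format.
    bool ReadProtoAuto(const std::string& filename, google::protobuf::Message* proto) {
      std::ifstream ifs(filename, std::ios::in | std::ios::binary);
      if (!ifs) {
        return false;
      }
      std::ostringstream buffer;
      buffer << ifs.rdbuf();
      const std::string data = buffer.str();

      if (proto->ParseFromString(data)) {
        return true;  // parsed as binary wire format
      }
      proto->Clear();  // drop any partial state before retrying as text
      return google::protobuf::TextFormat::ParseFromString(data, proto);
    }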
-inline bool ReadProtoFromFile(const char* filename, Message* proto) { - return (ReadProtoFromBinaryFile(filename, proto) || - ReadProtoFromTextFile(filename, proto)); -} - -inline bool ReadProtoFromFile(const string& filename, Message* proto) { - return ReadProtoFromFile(filename.c_str(), proto); -} - -#endif // CAFFE2_USE_LITE_PROTO - -template < - class IterableInputs = std::initializer_list, - class IterableOutputs = std::initializer_list, - class IterableArgs = std::initializer_list> -OperatorDef CreateOperatorDef( - const string& type, - const string& name, - const IterableInputs& inputs, - const IterableOutputs& outputs, - const IterableArgs& args, - const DeviceOption& device_option = DeviceOption(), - const string& engine = "") { - OperatorDef def; - def.set_type(type); - def.set_name(name); - for (const string& in : inputs) { - def.add_input(in); - } - for (const string& out : outputs) { - def.add_output(out); - } - for (const Argument& arg : args) { - def.add_arg()->CopyFrom(arg); - } - if (device_option.has_device_type()) { - def.mutable_device_option()->CopyFrom(device_option); - } - if (engine.size()) { - def.set_engine(engine); - } - return def; -} - -// A simplified version compared to the full CreateOperator, if you do not need -// to specify args. -template < - class IterableInputs = std::initializer_list, - class IterableOutputs = std::initializer_list> -inline OperatorDef CreateOperatorDef( - const string& type, - const string& name, - const IterableInputs& inputs, - const IterableOutputs& outputs, - const DeviceOption& device_option = DeviceOption(), - const string& engine = "") { - return CreateOperatorDef( - type, - name, - inputs, - outputs, - std::vector(), - device_option, - engine); -} - -TORCH_API bool HasOutput(const OperatorDef& op, const std::string& output); -TORCH_API bool HasInput(const OperatorDef& op, const std::string& input); - -/** - * @brief A helper class to index into arguments. - * - * This helper helps us to more easily index into a set of arguments - * that are present in the operator. To save memory, the argument helper - * does not copy the operator def, so one would need to make sure that the - * lifetime of the OperatorDef object outlives that of the ArgumentHelper. 
- */ -class C10_EXPORT ArgumentHelper { - public: - template - static bool HasArgument(const Def& def, c10::string_view name) { - return ArgumentHelper(def).HasArgument(name); - } - - template - static T GetSingleArgument( - const Def& def, - c10::string_view name, - const T& default_value) { - return ArgumentHelper(def).GetSingleArgument(name, default_value); - } - - template - static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) { - return ArgumentHelper(def).HasSingleArgumentOfType(name); - } - - template - static std::vector GetRepeatedArgument( - const Def& def, - c10::string_view name, - const std::vector& default_value = std::vector()) { - return ArgumentHelper(def).GetRepeatedArgument(name, default_value); - } - - template - static MessageType GetMessageArgument(const Def& def, c10::string_view name) { - return ArgumentHelper(def).GetMessageArgument(name); - } - - template - static std::vector GetRepeatedMessageArgument( - const Def& def, - c10::string_view name) { - return ArgumentHelper(def).GetRepeatedMessageArgument(name); - } - - template - static bool RemoveArgument(Def& def, int index) { - if (index >= def.arg_size()) { - return false; - } - if (index < def.arg_size() - 1) { - def.mutable_arg()->SwapElements(index, def.arg_size() - 1); - } - def.mutable_arg()->RemoveLast(); - return true; - } - - explicit ArgumentHelper(const OperatorDef& def); - explicit ArgumentHelper(const NetDef& netdef); - bool HasArgument(c10::string_view name) const; - - template - T GetSingleArgument(c10::string_view name, const T& default_value) const; - template - bool HasSingleArgumentOfType(c10::string_view name) const; - template - std::vector GetRepeatedArgument( - c10::string_view name, - const std::vector& default_value = std::vector()) const; - - template - MessageType GetMessageArgument(c10::string_view name) const { - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); - CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); - MessageType message; - if (it->second.has_s()) { - CAFFE_ENFORCE( - message.ParseFromString(it->second.s()), - "Failed to parse content from the string"); - } else { - VLOG(1) << "Return empty message for parameter " << name; - } - return message; - } - - template - std::vector GetRepeatedMessageArgument(c10::string_view name) const { - auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); - CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); - std::vector messages(it->second.strings_size()); - for (int i = 0; i < messages.size(); ++i) { - CAFFE_ENFORCE( - messages[i].ParseFromString(it->second.strings(i)), - "Failed to parse content from the string"); - } - return messages; - } - - private: - std::map -#endif - > arg_map_; -}; - -// **** Arguments Utils ***** - -// Helper methods to get an argument from OperatorDef or NetDef given argument -// name. Throws if argument does not exist. -TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name); -TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name); -// Helper methods to get an argument from OperatorDef or NetDef given argument -// name. Returns nullptr if argument does not exist. -TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name); -TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name); - -// Helper methods to query a boolean argument flag from OperatorDef or NetDef -// given argument name. If argument does not exist, return default value. 
-// Throws if argument exists but the type is not boolean. -TORCH_API bool GetFlagArgument( - const OperatorDef& def, - c10::string_view name, - bool default_value = false); -TORCH_API bool GetFlagArgument( - const NetDef& def, - c10::string_view name, - bool default_value = false); - -TORCH_API Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - OperatorDef* def); -TORCH_API Argument* GetMutableArgument( - const string& name, - const bool create_if_missing, - NetDef* def); - -template -TORCH_API Argument MakeArgument(const string& name, const T& value); - -template -inline void AddArgument(const string& name, const T& value, Def* def) { - GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value)); -} -// **** End Arguments Utils ***** - -bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) { - return IsSameDevice(dl, dr); -} - -// Given a net, modify the external inputs/outputs if necessary so that -// the following conditions are met -// - No duplicate external inputs -// - No duplicate external outputs -// - Going through list of ops in order, all op inputs must be outputs -// from other ops, or registered as external inputs. -// - All external outputs must be outputs of some operators. -TORCH_API void cleanupExternalInputsAndOutputs(NetDef* net); - -} // namespace caffe2 - -namespace std { -template <> -struct hash { - typedef caffe2::DeviceOption argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const& device_option) const { - std::string serialized; - CAFFE_ENFORCE(device_option.SerializeToString(&serialized)); - return std::hash{}(serialized); - } -}; -} // namespace std - -#endif // CAFFE2_UTILS_PROTO_UTILS_H_ diff --git a/caffe2/utils/proto_utils_test.cc b/caffe2/utils/proto_utils_test.cc deleted file mode 100644 index 1a687690c69f..000000000000 --- a/caffe2/utils/proto_utils_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include - -#include "caffe2/core/test_utils.h" -#include "caffe2/utils/proto_utils.h" - -namespace caffe2 { - -TEST(ProtoUtilsTest, IsSameDevice) { - DeviceOption a; - DeviceOption b; - EXPECT_TRUE(IsSameDevice(a, b)); - a.set_node_name("my_node"); - EXPECT_FALSE(IsSameDevice(a, b)); - b.set_node_name("my_node"); - EXPECT_TRUE(IsSameDevice(a, b)); - b.set_device_id(2); - EXPECT_FALSE(IsSameDevice(a, b)); - a.set_device_id(2); - EXPECT_TRUE(IsSameDevice(a, b)); - a.set_device_type(DeviceTypeProto::PROTO_CUDA); - b.set_device_type(DeviceTypeProto::PROTO_CPU); - EXPECT_FALSE(IsSameDevice(a, b)); -} - -TEST(ProtoUtilsTest, SimpleReadWrite) { - string content("The quick brown fox jumps over the lazy dog."); - string name = std::tmpnam(nullptr); - EXPECT_TRUE(WriteStringToFile(content, name.c_str())); - string read_back; - EXPECT_TRUE(ReadStringFromFile(name.c_str(), &read_back)); - EXPECT_EQ(content, read_back); -} - -TEST(ProtoUtilsTest, CleanupExternalInputsAndOutputs) { - caffe2::NetDef net; - caffe2::testing::NetMutator(&net) - .newOp("op1", {"X1", "X2"}, {"Y"}) - .newOp("op2", {"W", "Y"}, {"Z1", "Z2"}) - .newOp("op3", {"Z2", "W"}, {"O"}) - .externalInputs({"X1", "X3", "X1", "W"}) - .externalOutputs({"O", "Z2", "Z3", "O", "X3"}); - cleanupExternalInputsAndOutputs(&net); - - std::vector externalInputs; - for (const auto& inputName : net.external_input()) { - externalInputs.emplace_back(inputName); - } - // The 2nd X1 is removed because of duplication. - // X2 is added because it should be a missing external input. 
- std::vector expectedExternalInputs{"X1", "X3", "W", "X2"}; - EXPECT_EQ(externalInputs, expectedExternalInputs); - - std::vector externalOutputs; - for (const auto& outputName : net.external_output()) { - externalOutputs.emplace_back(outputName); - } - // Z3 is removed because it's not an output of any operator in the net. - // The 2nd O is removed because of duplication. - std::vector expectedexternalOutputs{"O", "Z2", "X3"}; - EXPECT_EQ(externalOutputs, expectedexternalOutputs); -} - -} // namespace caffe2 diff --git a/caffe2/utils/signal_handler.h b/caffe2/utils/signal_handler.h deleted file mode 100644 index 14d93a0df670..000000000000 --- a/caffe2/utils/signal_handler.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include - -namespace caffe2 { - -#if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) -class TORCH_API C2FatalSignalHandler : public c10::FatalSignalHandler { - public: - void fatalSignalHandlerPostProcess() override; - static C2FatalSignalHandler& getInstance(); - - private: - explicit C2FatalSignalHandler(); -}; - -// This works by setting up certain fatal signal handlers. Previous fatal -// signal handlers will still be called when the signal is raised. Defaults -// to being off. -TORCH_API void setPrintStackTracesOnFatalSignal(bool print); -TORCH_API bool printStackTracesOnFatalSignal(); -#endif // defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLER) - -} // namespace caffe2 diff --git a/caffe2/utils/simple_queue.h b/caffe2/utils/simple_queue.h deleted file mode 100644 index c16f55223eed..000000000000 --- a/caffe2/utils/simple_queue.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef CAFFE2_UTILS_SIMPLE_QUEUE_H_ -#define CAFFE2_UTILS_SIMPLE_QUEUE_H_ - -#include // NOLINT -#include // NOLINT -#include - -#include - -namespace caffe2 { - -// This is a very simple queue that Yangqing wrote when bottlefeeding the baby, -// so don't take it seriously. What it does is a minimal thread-safe queue that -// allows me to run network as a DAG. -// -// A usual work pattern looks like this: one or multiple producers push jobs -// into this queue, and one or multiple workers pops jobs from this queue. If -// nothing is in the queue but NoMoreJobs() is not called yet, the pop calls -// will wait. If NoMoreJobs() has been called, pop calls will return false, -// which serves as a message to the workers that they should exit. -template -class SimpleQueue { - public: - SimpleQueue() : no_more_jobs_(false) {} - - // Pops a value and writes it to the value pointer. If there is nothing in the - // queue, this will wait till a value is inserted to the queue. If there are - // no more jobs to pop, the function returns false. Otherwise, it returns - // true. - bool Pop(T* value) { - std::unique_lock mutex_lock(mutex_); - while (queue_.size() == 0 && !no_more_jobs_) cv_.wait(mutex_lock); - if (queue_.size() == 0 && no_more_jobs_) return false; - *value = queue_.front(); - queue_.pop(); - return true; - } - - int size() { - std::unique_lock mutex_lock(mutex_); - return queue_.size(); - } - - // Push pushes a value to the queue. - void Push(const T& value) { - { - std::lock_guard mutex_lock(mutex_); - CAFFE_ENFORCE(!no_more_jobs_, "Cannot push to a closed queue."); - queue_.push(value); - } - cv_.notify_one(); - } - - // NoMoreJobs() marks the close of this queue. It also notifies all waiting - // Pop() calls so that they either check out remaining jobs, or return false. 
- // After NoMoreJobs() is called, this queue is considered closed - no more - // Push() functions are allowed, and once existing items are all checked out - // by the Pop() functions, any more Pop() function will immediately return - // false with nothing set to the value. - void NoMoreJobs() { - { - std::lock_guard mutex_lock(mutex_); - no_more_jobs_ = true; - } - cv_.notify_all(); - } - - private: - std::mutex mutex_; - std::condition_variable cv_; - std::queue queue_; - bool no_more_jobs_{}; - // We do not allow copy constructors. - SimpleQueue(const SimpleQueue& /*src*/) {} -}; - -} // namespace caffe2 - -#endif // CAFFE2_UTILS_SIMPLE_QUEUE_H_ diff --git a/caffe2/utils/simple_queue_test.cc b/caffe2/utils/simple_queue_test.cc deleted file mode 100644 index e59f699cd15a..000000000000 --- a/caffe2/utils/simple_queue_test.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include // NOLINT - -#include "caffe2/utils/simple_queue.h" -#include - -namespace caffe2 { - -static std::unique_ptr > gQueue; - -static void ConsumerFunction(int thread_idx) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int value; - while (true) { - if (!gQueue->Pop(&value)) return; - VLOG(1) << "Emitting " << value << " from thread " << thread_idx; - } -} - -static void ProducerFunction(int thread_idx, int start, int count) { - for (int i = 0; i < count; ++i) { - VLOG(1) << "Pushing " << i + start << " from thread " << thread_idx; - gQueue->Push(i + start); - } -} - - -TEST(SimpleQueueTest, SingleProducerSingleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread consumer(ConsumerFunction, 0); - for (int i = 0; i < 10; ++i) { - gQueue->Push(i); - } - gQueue->NoMoreJobs(); - consumer.join(); -} - -TEST(SimpleQueueTest, SingleProducerDoubleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread consumer0(ConsumerFunction, 0); - std::thread consumer1(ConsumerFunction, 1); - for (int i = 0; i < 10; ++i) { - gQueue->Push(i); - } - gQueue->NoMoreJobs(); - consumer0.join(); - consumer1.join(); -} - - -TEST(SimpleQueueTest, DoubleProducerDoubleConsumer) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - std::thread producer0(ProducerFunction, 0, 0, 10); - std::thread producer1(ProducerFunction, 0, 10, 10); - std::thread consumer0(ConsumerFunction, 2); - std::thread consumer1(ConsumerFunction, 3); - producer0.join(); - producer1.join(); - gQueue->NoMoreJobs(); - consumer0.join(); - consumer1.join(); -} - -TEST(SimpleQueueDeathTest, CannotAddAfterQueueFinished) { - // NOLINTNEXTLINE(modernize-make-unique) - gQueue.reset(new SimpleQueue()); - gQueue->Push(0); - gQueue->NoMoreJobs(); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_THROW(gQueue->Push(0), EnforceNotMet); -} - - -} // namespace caffe2 diff --git a/caffe2/utils/smart_tensor_printer.h b/caffe2/utils/smart_tensor_printer.h deleted file mode 100644 index e6d96ef37ae0..000000000000 --- a/caffe2/utils/smart_tensor_printer.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "caffe2/core/tensor.h" - -namespace caffe2 { - -// This is a wrapper around the TensorPrinter that doesn't require the user to -// explicit specify the type of the tensor while calling the Print() method. -// It also supports a convenience function with a default constructed printer as -// a static method. 
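Before the tensor printer below: the producer/consumer pattern documented in the removed SimpleQueue above can be reproduced with only the standard library. A minimal sketch with invented names, in which Pop() blocks until a value arrives or the queue is closed (unlike the removed class, this sketch does not reject Push() after closing):

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    template <typename T>
    class ClosableQueue {
     public:
      // Blocks until a value is available or the queue is closed and drained.
      bool Pop(T* value) {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
        if (queue_.empty()) {
          return false;  // closed and drained: signal the consumer to exit
        }
        *value = queue_.front();
        queue_.pop();
        return true;
      }

      void Push(const T& value) {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          queue_.push(value);
        }
        cv_.notify_one();
      }

      // Marks the queue closed and wakes every waiting Pop().
      void Close() {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          closed_ = true;
        }
        cv_.notify_all();
      }

     private:
      std::mutex mutex_;
      std::condition_variable cv_;
      std::queue<T> queue_;
      bool closed_ = false;
    };

    int main() {
      ClosableQueue<int> q;
      std::thread consumer([&] {
        int v;
        while (q.Pop(&v)) {
          std::cout << "got " << v << "\n";
        }
      });
      for (int i = 0; i < 5; ++i) {
        q.Push(i);
      }
      q.Close();
      consumer.join();
    }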
-class TORCH_API SmartTensorPrinter { - public: - // The proliferation of constructors is to give the feature parity with - // TensorPrinter - // yet not repeat the default arguments explicitly in case they change in the - // future. - SmartTensorPrinter() = default; - - explicit SmartTensorPrinter(const std::string& tensor_name); - - SmartTensorPrinter( - const std::string& tensor_name, - const std::string& file_name); - - SmartTensorPrinter( - const std::string& tensor_name, - const std::string& file_name, - int limit); - - void Print(const Tensor& tensor); - - void PrintMeta(const Tensor& tensor) { - tensorPrinter_.PrintMeta(tensor); - } - - // Uses a default constructed SmartTensorPrinter - static void PrintTensor(const Tensor& tensor); - - // Uses a default constructed SmartTensorPrinter - void PrintTensorMeta(const Tensor& tensor) { - DefaultTensorPrinter().PrintMeta(tensor); - } - - private: - // Returns a thread local default constructed TensorPrinter - static SmartTensorPrinter& DefaultTensorPrinter(); - - TensorPrinter tensorPrinter_; -}; -} diff --git a/caffe2/utils/smart_tensor_printer_test.cc b/caffe2/utils/smart_tensor_printer_test.cc deleted file mode 100644 index a45573001c6e..000000000000 --- a/caffe2/utils/smart_tensor_printer_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "caffe2/utils/smart_tensor_printer.h" - -#include "caffe2/core/common.h" - -#include - -namespace caffe2 { - -template -std::string my_to_string(const T& value) { - return to_string(value); -} - -template <> -std::string my_to_string(const std::string& value) { - return value; -} - -template -void expect_stderr_contains(const std::vector& values) { - std::string captured_stderr = testing::internal::GetCapturedStderr(); - for (const auto& value : values) { - std::string stringValue = my_to_string(value); - EXPECT_TRUE(captured_stderr.find(stringValue) != std::string::npos); - } -} - -template -void printTensorAndCheck(const std::vector& values) { - testing::internal::CaptureStderr(); - - Tensor tensor = - TensorCPUFromValues({static_cast(values.size())}, values); - - SmartTensorPrinter::PrintTensor(tensor); - expect_stderr_contains(values); -} - -// We need real glog for this test to pass -#ifdef CAFFE2_USE_GOOGLE_GLOG - -#if !(__APPLE__) // TODO(janusz): thread_local does not work under mac. 
- -TEST(SmartTensorPrinterTest, SimpleTest) { - printTensorAndCheck(std::vector{1, 2, 3, 4, 5}); - printTensorAndCheck(std::vector{"bob", "alice", "facebook"}); -} - -#endif // !(__APPLE__) - -#endif // CAFFE2_USE_GOOGLE_GLOG - -} // namespace caffe2 diff --git a/caffe2/utils/threadpool/ThreadPool.cc b/caffe2/utils/threadpool/ThreadPool.cc index 27ade275672d..298fbe9ef4fa 100644 --- a/caffe2/utils/threadpool/ThreadPool.cc +++ b/caffe2/utils/threadpool/ThreadPool.cc @@ -1,6 +1,5 @@ #include "caffe2/utils/threadpool/ThreadPool.h" #include "WorkersPool.h" -#include "caffe2/core/logging.h" #if !defined(__s390x__) && !defined(__powerpc__) #include diff --git a/caffe2/utils/threadpool/WorkersPool.h b/caffe2/utils/threadpool/WorkersPool.h index b6bbc60f2099..23a72b02465e 100644 --- a/caffe2/utils/threadpool/WorkersPool.h +++ b/caffe2/utils/threadpool/WorkersPool.h @@ -5,8 +5,7 @@ #include #include "c10/util/thread_name.h" #include -#include "caffe2/core/common.h" -#include "caffe2/core/logging.h" +#include #if defined(_MSC_VER) #include diff --git a/caffe2/utils/threadpool/pthreadpool.cc b/caffe2/utils/threadpool/pthreadpool.cc index 44c758db5cb1..b8c6c7cebb8e 100644 --- a/caffe2/utils/threadpool/pthreadpool.cc +++ b/caffe2/utils/threadpool/pthreadpool.cc @@ -4,6 +4,7 @@ #include #include #include +#include #ifdef _MSC_VER #include @@ -14,10 +15,10 @@ #endif /* Library header */ -#include "caffe2/core/logging.h" #include "caffe2/utils/fixed_divisor.h" #include "caffe2/utils/threadpool/pthreadpool.h" +#include static inline size_t divide_round_up(size_t dividend, size_t divisor) { if (dividend % divisor == 0) { diff --git a/caffe2/utils/zmq_helper.h b/caffe2/utils/zmq_helper.h deleted file mode 100644 index 05bc22a73c4e..000000000000 --- a/caffe2/utils/zmq_helper.h +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_ -#define CAFFE2_UTILS_ZMQ_HELPER_H_ - -#include - -#include "caffe2/core/logging.h" - -namespace caffe2 { - -class ZmqContext { - public: - explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) { - CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context."); - int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads); - CAFFE_ENFORCE_EQ(rc, 0); - rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT); - CAFFE_ENFORCE_EQ(rc, 0); - } - ~ZmqContext() { - int rc = zmq_ctx_destroy(ptr_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void* ptr() { return ptr_; } - - private: - void* ptr_; - - C10_DISABLE_COPY_AND_ASSIGN(ZmqContext); -}; - -class ZmqMessage { - public: - ZmqMessage() { - int rc = zmq_msg_init(&msg_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - ~ZmqMessage() { - int rc = zmq_msg_close(&msg_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - zmq_msg_t* msg() { return &msg_; } - - void* data() { return zmq_msg_data(&msg_); } - size_t size() { return zmq_msg_size(&msg_); } - - private: - zmq_msg_t msg_; - C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage); -}; - -class ZmqSocket { - public: - explicit ZmqSocket(int type) - : context_(1), ptr_(zmq_socket(context_.ptr(), type)) { - CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq socket."); - } - - ~ZmqSocket() { - int rc = zmq_close(ptr_); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Bind(const string& addr) { - int rc = zmq_bind(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Unbind(const string& addr) { - int rc = zmq_unbind(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Connect(const string& addr) { - int rc = zmq_connect(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - void Disconnect(const string& addr) { - 
int rc = zmq_disconnect(ptr_, addr.c_str()); - CAFFE_ENFORCE_EQ(rc, 0); - } - - int Send(const string& msg, int flags) { - int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags); - if (nbytes) { - return nbytes; - } else if (zmq_errno() == EAGAIN) { - return 0; - } else { - LOG(FATAL) << "Cannot send zmq message. Error number: " - << zmq_errno(); - return 0; - } - } - - int SendTillSuccess(const string& msg, int flags) { - CAFFE_ENFORCE(msg.size(), "You cannot send an empty message."); - int nbytes = 0; - do { - nbytes = Send(msg, flags); - } while (nbytes == 0); - return nbytes; - } - - int Recv(ZmqMessage* msg) { - int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0); - if (nbytes >= 0) { - return nbytes; - } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { - return 0; - } else { - LOG(FATAL) << "Cannot receive zmq message. Error number: " - << zmq_errno(); - return 0; - } - } - - int RecvTillSuccess(ZmqMessage* msg) { - int nbytes = 0; - do { - nbytes = Recv(msg); - } while (nbytes == 0); - return nbytes; - } - - private: - ZmqContext context_; - void* ptr_; -}; - -} // namespace caffe2 - - -#endif // CAFFE2_UTILS_ZMQ_HELPER_H_ diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a582a3e6ec05..f1f2eb7cec31 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -89,8 +89,8 @@ endif() if(USE_XPU) include(${CMAKE_CURRENT_LIST_DIR}/public/xpu.cmake) if(NOT PYTORCH_FOUND_XPU) - # message(WARNING "Not compiling with XPU. Could NOT find SYCL." - # "Suppress this warning with -DUSE_XPU=OFF.") + message(WARNING "Not compiling with XPU. Could NOT find SYCL." + "Suppress this warning with -DUSE_XPU=OFF.") caffe2_update_option(USE_XPU OFF) endif() endif() @@ -144,6 +144,8 @@ endif() # ---[ BLAS set(AT_MKLDNN_ACL_ENABLED 0) +set(AT_MKLDNN_ENABLED 0) +set(AT_MKL_ENABLED 0) # setting default preferred BLAS options if not already present. if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") @@ -235,7 +237,6 @@ else() endif() if(NOT INTERN_BUILD_MOBILE) - set(AT_MKL_ENABLED 0) set(AT_MKL_SEQUENTIAL 0) set(USE_BLAS 1) if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND)) @@ -834,71 +835,49 @@ else() endif() include_directories(SYSTEM ${EIGEN3_INCLUDE_DIR}) -# ---[ Python + Numpy -if(BUILD_PYTHON) - # If not given a Python installation, then use the current active Python - if(NOT Python_EXECUTABLE) - execute_process( - COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) - if(${_exitcode} EQUAL 0) - if(NOT MSVC) - string(STRIP ${_py_exe} Python_EXECUTABLE) - endif() - message(STATUS "Setting Python to ${Python_EXECUTABLE}") - endif() - endif() - # Check that Python works - set(PYTHON_VERSION) - if(DEFINED Python_EXECUTABLE) - execute_process( - COMMAND "${Python_EXECUTABLE}" "--version" - RESULT_VARIABLE _exitcode OUTPUT_VARIABLE PYTHON_VERSION) - if(NOT _exitcode EQUAL 0) - message(FATAL_ERROR "The Python executable ${Python_EXECUTABLE} cannot be run. 
Make sure that it is an absolute path.") - endif() - if(PYTHON_VERSION) - string(REGEX MATCH "([0-9]+)\\.([0-9]+)" PYTHON_VERSION ${PYTHON_VERSION}) +# ---[ Python Interpreter +# If not given a Python installation, then use the current active Python +if(NOT Python_EXECUTABLE) + execute_process( + COMMAND "which" "python3" RESULT_VARIABLE _exitcode OUTPUT_VARIABLE _py_exe) + if(${_exitcode} EQUAL 0) + if(NOT MSVC) + string(STRIP ${_py_exe} Python_EXECUTABLE) endif() + message(STATUS "Setting Python to ${Python_EXECUTABLE}") endif() +endif() - # These should fill in the rest of the variables, like versions, but resepct - # the variables we set above - find_package(Python COMPONENTS Interpreter Development) - - if(NOT Python_Development_FOUND) - message(FATAL_ERROR - "Python development libraries could not be found.") - endif() - - if(${Python_VERSION} VERSION_LESS 3.8) - message(FATAL_ERROR - "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.") +if(BUILD_PYTHON) + set(PYTHON_COMPONENTS Development) + if(USE_NUMPY) + list(APPEND PYTHON_COMPONENTS NumPy) endif() + find_package(Python COMPONENTS Interpreter OPTIONAL_COMPONENTS ${PYTHON_COMPONENTS}) +else() + find_package(Python COMPONENTS Interpreter) +endif() - # When building pytorch, we pass this in directly from setup.py, and - # don't want to overwrite it because we trust python more than cmake - if(NUMPY_INCLUDE_DIR) - set(NUMPY_FOUND ON) - elseif(USE_NUMPY) - find_package(NumPy) - if(NOT NUMPY_FOUND) - message(WARNING "NumPy could not be found. Not building with NumPy. Suppress this warning with -DUSE_NUMPY=OFF") - endif() - endif() +if(NOT Python_Interpreter_FOUND) + message(FATAL_ERROR "Python3 could not be found.") +endif() - if(Python_Interpreter_FOUND AND Python_Development_FOUND) - add_library(python::python INTERFACE IMPORTED) - target_include_directories(python::python SYSTEM INTERFACE ${Python_INCLUDE_DIRS}) - if(WIN32) - target_link_libraries(python::python INTERFACE ${Python_LIBRARIES}) - endif() +if(${Python_VERSION} VERSION_LESS 3.8) + message(FATAL_ERROR + "Found Python libraries version ${Python_VERSION}. Python < 3.8 is no longer supported by PyTorch.") +endif() - caffe2_update_option(USE_NUMPY OFF) - if(NUMPY_FOUND) - caffe2_update_option(USE_NUMPY ON) - add_library(numpy::numpy INTERFACE IMPORTED) - target_include_directories(numpy::numpy SYSTEM INTERFACE ${NUMPY_INCLUDE_DIR}) +# ---[ Python + Numpy +if(BUILD_PYTHON) + if(Python_Development_FOUND) + if(USE_NUMPY) + if(NOT Python_NumPy_FOUND) + message(WARNING "NumPy could not be found. Not building with NumPy. 
Suppress this warning with -DUSE_NUMPY=OFF") + caffe2_update_option(USE_NUMPY OFF) + else() + caffe2_update_option(USE_NUMPY ON) + endif() endif() # Observers are required in the python build caffe2_update_option(USE_OBSERVERS ON) @@ -1303,8 +1282,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) add_definitions(-DONNX_ML=1) endif() add_definitions(-DONNXIFI_ENABLE_EXT=1) - # Add op schemas in "ai.onnx.pytorch" domain - add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/../caffe2/onnx/torch_ops") if(NOT USE_SYSTEM_ONNX) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx EXCLUDE_FROM_ALL) if(NOT MSVC) @@ -1485,8 +1462,6 @@ if(NOT INTERN_BUILD_MOBILE) set(AT_ROCM_ENABLED 1) endif() - set(AT_MKLDNN_ENABLED 0) - set(AT_MKLDNN_ACL_ENABLED 0) if(USE_MKLDNN) if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) message(WARNING @@ -1704,3 +1679,7 @@ endif() # Include google/FlatBuffers include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake) + +# Include cpp-httplib +add_library(httplib INTERFACE IMPORTED) +target_include_directories(httplib SYSTEM INTERFACE ${PROJECT_SOURCE_DIR}/third_party/cpp-httplib) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index de64370b37a2..ec6f09b60533 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -4,25 +4,38 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src") set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build") set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") - ExternalProject_Add(aotriton_external - GIT_REPOSITORY https://github.com/ROCm/aotriton.git - GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27 - SOURCE_DIR ${__AOTRITON_SOURCE_DIR} - BINARY_DIR ${__AOTRITON_BUILD_DIR} - PREFIX ${__AOTRITON_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} - -DAOTRITON_COMPRESS_KERNEL=OFF - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NO_SHARED=ON - # CONFIGURE_COMMAND "" - # BUILD_COMMAND ${MAKE_COMMAND} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - set(AOTRITON_FOUND TRUE) add_library(__caffe2_aotriton INTERFACE) - add_dependencies(__caffe2_aotriton aotriton_external) + # Note it is INSTALL"ED" + if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) + set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") + message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") + else() + file(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/.ci/docker/aotriton_version.txt" __AOTRITON_CI_INFO) + list(GET __AOTRITON_CI_INFO 3 __AOTRITON_CI_COMMIT) + ExternalProject_Add(aotriton_external + GIT_REPOSITORY https://github.com/ROCm/aotriton.git + GIT_TAG ${__AOTRITON_CI_COMMIT} + SOURCE_DIR ${__AOTRITON_SOURCE_DIR} + BINARY_DIR ${__AOTRITON_BUILD_DIR} + PREFIX ${__AOTRITON_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + -DAOTRITON_COMPRESS_KERNEL=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_NO_PYTHON=ON + -DAOTRITON_NO_SHARED=ON + # CONFIGURE_COMMAND "" + BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system. 
+ BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" + USES_TERMINAL_DOWNLOAD TRUE + USES_TERMINAL_CONFIGURE TRUE + USES_TERMINAL_BUILD TRUE + USES_TERMINAL_INSTALL TRUE + # INSTALL_COMMAND ${MAKE_COMMAND} install + ) + add_dependencies(__caffe2_aotriton aotriton_external) + message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_SOURCE_DIR}") + endif() target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) + set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 71d73866b2af..433d96ebfd23 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -1,11 +1,3 @@ -if(UNIX) - # prevent Unknown CMake command "check_function_exists". - include(CheckFunctionExists) -endif() -include(CheckIncludeFile) -include(CheckCSourceCompiles) -include(CheckCSourceRuns) -include(CheckCCompilerFlag) include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index 9604723e2cd3..1497f951402f 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -1,4 +1,5 @@ INCLUDE(CheckCSourceRuns) +INCLUDE(CheckCSourceCompiles) INCLUDE(CheckCXXSourceRuns) SET(AVX_CODE " diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 9e002c939e5b..382e71b1049b 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -21,20 +21,33 @@ IF(NOT MKLDNN_FOUND) if(USE_XPU) # Build oneDNN GPU library if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(DNNL_HOST_COMPILER "g++") + # Linux # g++ is soft linked to /usr/bin/cxx, oneDNN would not treat it as an absolute path + set(DNNL_HOST_COMPILER "g++") + set(SYCL_CXX_DRIVER "icpx") + set(DNNL_LIB_NAME "libdnnl.a") else() - message(FATAL_ERROR "oneDNN library currently only supports GUN g++ compiler for XPU backend") + # Windows + set(DNNL_HOST_COMPILER "DEFAULT") + set(SYCL_CXX_DRIVER "icx") + set(DNNL_LIB_NAME "dnnl.lib") endif() set(DNNL_MAKE_COMMAND "cmake" "--build" ".") + include(ProcessorCount) + ProcessorCount(proc_cnt) + if ((DEFINED ENV{MAX_JOBS}) AND ("$ENV{MAX_JOBS}" LESS_EQUAL ${proc_cnt})) + list(APPEND DNNL_MAKE_COMMAND "-j" "$ENV{MAX_JOBS}") + if(CMAKE_GENERATOR MATCHES "Make|Ninja") + list(APPEND DNNL_MAKE_COMMAND "--" "-l" "$ENV{MAX_JOBS}") + endif() + endif() ExternalProject_Add(xpu_mkldnn_proj SOURCE_DIR ${MKLDNN_ROOT} PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx - -DCMAKE_CXX_COMPILER=icpx - -DCMAKE_CXX_COMPILER_ID=IntelLLVM + -DCMAKE_CXX_COMPILER=${SYCL_CXX_DRIVER} -DDNNL_GPU_RUNTIME=SYCL -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF @@ -44,20 +57,20 @@ IF(NOT MKLDNN_FOUND) -DDNNL_DPCPP_HOST_COMPILER=${DNNL_HOST_COMPILER} # Use global cxx compiler as host compiler -G ${CMAKE_GENERATOR} # Align Generator to Torch BUILD_COMMAND ${DNNL_MAKE_COMMAND} - BUILD_BYPRODUCTS "xpu_mkldnn_proj-prefix/src/xpu_mkldnn_proj-build/src/libdnnl.a" + BUILD_BYPRODUCTS "xpu_mkldnn_proj-prefix/src/xpu_mkldnn_proj-build/src/${DNNL_LIB_NAME}" INSTALL_COMMAND "" ) ExternalProject_Get_Property(xpu_mkldnn_proj BINARY_DIR) set(__XPU_MKLDNN_BUILD_DIR ${BINARY_DIR}) - set(XPU_MKLDNN_LIBRARIES ${__XPU_MKLDNN_BUILD_DIR}/src/libdnnl.a) + set(XPU_MKLDNN_LIBRARIES ${__XPU_MKLDNN_BUILD_DIR}/src/${DNNL_LIB_NAME}) set(XPU_MKLDNN_INCLUDE 
${__XPU_MKLDNN_BUILD_DIR}/include) # This target would be further linked to libtorch_xpu.so. # The libtorch_xpu.so would contain Conv&GEMM operators that depend on # oneDNN primitive implementations inside libdnnl.a. add_library(xpu_mkldnn INTERFACE) add_dependencies(xpu_mkldnn xpu_mkldnn_proj) - target_link_libraries(xpu_mkldnn INTERFACE ${__XPU_MKLDNN_BUILD_DIR}/src/libdnnl.a) + target_link_libraries(xpu_mkldnn INTERFACE ${__XPU_MKLDNN_BUILD_DIR}/src/${DNNL_LIB_NAME}) target_include_directories(xpu_mkldnn INTERFACE ${XPU_MKLDNN_INCLUDE}) endif() diff --git a/cmake/Modules/FindMatlabMex.cmake b/cmake/Modules/FindMatlabMex.cmake deleted file mode 100644 index 28ae65e7cbba..000000000000 --- a/cmake/Modules/FindMatlabMex.cmake +++ /dev/null @@ -1,48 +0,0 @@ -# This module looks for MatlabMex compiler -# Defines variables: -# Matlab_DIR - Matlab root dir -# Matlab_mex - path to mex compiler -# Matlab_mexext - path to mexext - -if(MSVC) - foreach(__ver "9.30" "7.14" "7.11" "7.10" "7.9" "7.8" "7.7") - get_filename_component(__matlab_root "[HKEY_LOCAL_MACHINE\\SOFTWARE\\MathWorks\\MATLAB\\${__ver};MATLABROOT]" ABSOLUTE) - if(__matlab_root) - break() - endif() - endforeach() -endif() - -if(APPLE) - foreach(__ver "R2014b" "R2014a" "R2013b" "R2013a" "R2012b" "R2012a" "R2011b" "R2011a" "R2010b" "R2010a") - if(EXISTS /Applications/MATLAB_${__ver}.app) - set(__matlab_root /Applications/MATLAB_${__ver}.app) - break() - endif() - endforeach() -endif() - -if(UNIX) - execute_process(COMMAND which matlab OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE __out RESULT_VARIABLE __res) - - if(__res MATCHES 0) # Suppress `readlink` warning if `which` returned nothing - execute_process(COMMAND which matlab COMMAND xargs readlink - COMMAND xargs dirname COMMAND xargs dirname COMMAND xargs echo -n - OUTPUT_VARIABLE __matlab_root OUTPUT_STRIP_TRAILING_WHITESPACE) - endif() -endif() - - -find_path(Matlab_DIR NAMES bin/mex bin/mexext PATHS ${__matlab_root} - DOC "Matlab directory" NO_DEFAULT_PATH) - -find_program(Matlab_mex NAMES mex mex.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) -find_program(Matlab_mexext NAMES mexext mexext.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(MatlabMex DEFAULT_MSG Matlab_mex Matlab_mexext) - -if(MATLABMEX_FOUND) - mark_as_advanced(Matlab_mex Matlab_mexext) -endif() diff --git a/cmake/Modules/FindNumPy.cmake b/cmake/Modules/FindNumPy.cmake deleted file mode 100644 index 2c43b95bdcf6..000000000000 --- a/cmake/Modules/FindNumPy.cmake +++ /dev/null @@ -1,57 +0,0 @@ -# - Find the NumPy libraries -# This module finds if NumPy is installed, and sets the following variables -# indicating where it is. -# -# TODO: Update to provide the libraries and paths for linking npymath lib. -# -# NUMPY_FOUND - was NumPy found -# NUMPY_VERSION - the version of NumPy found as a string -# NUMPY_VERSION_MAJOR - the major version number of NumPy -# NUMPY_VERSION_MINOR - the minor version number of NumPy -# NUMPY_VERSION_PATCH - the patch version number of NumPy -# NUMPY_VERSION_DECIMAL - e.g. 
version 1.6.1 is 10601 -# NUMPY_INCLUDE_DIR - path to the NumPy include files - -unset(NUMPY_VERSION) -unset(NUMPY_INCLUDE_DIR) - -if(Python_Interpreter_FOUND) - execute_process(COMMAND "${Python_EXECUTABLE}" "-c" - "import numpy as n; print(n.__version__); print(n.get_include());" - RESULT_VARIABLE __result - OUTPUT_VARIABLE __output - OUTPUT_STRIP_TRAILING_WHITESPACE) - - if(__result MATCHES 0) - string(REGEX REPLACE ";" "\\\\;" __values ${__output}) - string(REGEX REPLACE "\r?\n" ";" __values ${__values}) - list(GET __values 0 NUMPY_VERSION) - list(GET __values 1 NUMPY_INCLUDE_DIR) - - string(REGEX MATCH "^([0-9])+\\.([0-9])+\\.([0-9])+" __ver_check "${NUMPY_VERSION}") - if(NOT "${__ver_check}" STREQUAL "") - set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1}) - set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2}) - set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3}) - math(EXPR NUMPY_VERSION_DECIMAL - "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") - string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR}) - else() - unset(NUMPY_VERSION) - unset(NUMPY_INCLUDE_DIR) - message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n") - endif() - endif() -else() - message(STATUS "To find NumPy Python interpretator is required to be found.") -endif() - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION - VERSION_VAR NUMPY_VERSION) - -if(NUMPY_FOUND) - message(STATUS "NumPy ver. ${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})") -endif() - -caffe_clear_vars(__result __output __error_value __values __ver_check __error_value) diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index d9345bb2fe0d..4a4a6dfaa789 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -55,6 +55,23 @@ find_library( HINTS ${SYCL_LIBRARY_DIR} NO_DEFAULT_PATH ) +# On Windows, currently there's no sycl.lib. Only sycl7.lib with version suffix, +# where the current version of the SYCL runtime is 7. +# Until oneAPI adds support to sycl.lib without the version suffix, +# sycl_runtime_version needs to be hardcoded and uplifted when SYCL runtime version uplifts. 
+# TODO: remove this when sycl.lib is supported on Windows +if(WIN32) + set(sycl_runtime_version 7) + find_library( + SYCL_LIBRARY + NAMES "sycl${sycl_runtime_version}" + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH + ) + if(SYCL_LIBRARY STREQUAL "SYCL_LIBRARY-NOTFOUND") + message(FATAL_ERROR "Cannot find a SYCL library on Windows") + endif() +endif() find_library( OCL_LIBRARY diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 99b6521328d6..aeb367690d3f 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -14,6 +14,9 @@ function(caffe2_print_configuration_summary) message(STATUS " Found ccache : ${CCACHE_PROGRAM}") endif() message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") + message(STATUS " Shared LD flags : ${CMAKE_SHARED_LINKER_FLAGS}") + message(STATUS " Static LD flags : ${CMAKE_STATIC_LINKER_FLAGS}") + message(STATUS " Module LD flags : ${CMAKE_MODULE_LINKER_FLAGS}") message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) message(STATUS " Compile definitions : ${tmp}") @@ -71,7 +74,6 @@ function(caffe2_print_configuration_summary) message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") - message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") message(STATUS " CUDA version : ${CUDA_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") @@ -119,6 +121,7 @@ function(caffe2_print_configuration_summary) if(${USE_ROCM}) message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") + message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0f5da8e6cae2..c4adccf3b61b 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -479,9 +479,7 @@ function(torch_compile_options libname) # templated classes crossing library boundary get duplicated (but identical) # definitions. It's easier to just disable it. target_compile_options(${libname} PRIVATE - $<$: -fvisibility=hidden> - $<$: -fvisibility=hidden> - $<$: -fvisibility=hidden>) + $<$: -fvisibility=hidden>) endif() # Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression) diff --git a/defs.bzl b/defs.bzl index 6ea4b1219325..5e8923556af0 100644 --- a/defs.bzl +++ b/defs.bzl @@ -33,7 +33,6 @@ default_compiler_flags = [ "-DTH_INDEX_BASE=0", "-DMAGMA_V2", "-DNO_CUDNN_DESTROY_HANDLE", - "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api "-DUSE_FBGEMM", "-DUSE_PYTORCH_QNNPACK", # The dynamically loaded NVRTC trick doesn't work in fbcode, diff --git a/docker.Makefile b/docker.Makefile index a33c411907bc..7f131707e7ab 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -10,7 +10,7 @@ endif CUDA_VERSION_SHORT ?= 12.1 CUDA_VERSION ?= 12.1.1 -CUDNN_VERSION ?= 8 +CUDNN_VERSION ?= 9 BASE_RUNTIME = ubuntu:22.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-devel-ubuntu22.04 CMAKE_VARS ?= diff --git a/docs/source/backends.rst b/docs/source/backends.rst index ef3c720e8335..bd83e49f5f2d 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -92,6 +92,8 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.can_use_efficient_attention +.. 
autofunction:: torch.backends.cuda.can_use_cudnn_attention + .. autofunction:: torch.backends.cuda.sdp_kernel torch.backends.cudnn diff --git a/docs/source/conf.py b/docs/source/conf.py index fe548737b313..4f73c111cb23 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -606,45 +606,6 @@ # torch.distributed.optim.utils "as_functional_optim", "register_functional_optim", - # torch.distributed.pipeline.sync.checkpoint - "checkpoint", - "enable_checkpointing", - "enable_recomputing", - "is_checkpointing", - "is_recomputing", - "restore_rng_states", - "save_rng_states", - # torch.distributed.pipeline.sync.dependency - "fork", - "join", - # torch.distributed.pipeline.sync.microbatch - "check", - "gather", - "scatter", - # torch.distributed.pipeline.sync.phony - "get_phony", - # torch.distributed.pipeline.sync.skip.layout - "inspect_skip_layout", - # torch.distributed.pipeline.sync.skip.tracker - "current_skip_tracker", - "use_skip_tracker", - # torch.distributed.pipeline.sync.stream - "as_cuda", - "current_stream", - "default_stream", - "get_device", - "is_cuda", - "new_stream", - "record_stream", - "use_device", - "use_stream", - "wait_stream", - # torch.distributed.pipeline.sync.utils - "partition_model", - # torch.distributed.pipeline.sync.worker - "create_workers", - "spawn_workers", - "worker", # torch.distributed.rendezvous "register_rendezvous_handler", "rendezvous", @@ -2648,52 +2609,6 @@ "PostLocalSGDOptimizer", # torch.distributed.optim.zero_redundancy_optimizer "ZeroRedundancyOptimizer", - # torch.distributed.pipeline.sync.batchnorm - "DeferredBatchNorm", - # torch.distributed.pipeline.sync.checkpoint - "Checkpoint", - "Checkpointing", - "Context", - "Function", - "Recompute", - "ThreadLocal", - # torch.distributed.pipeline.sync.copy - "Context", - "Copy", - "Wait", - # torch.distributed.pipeline.sync.dependency - "Fork", - "Join", - # torch.distributed.pipeline.sync.microbatch - "Batch", - "NoChunk", - # torch.distributed.pipeline.sync.pipe - "BalanceError", - "Pipe", - "PipeSequential", - "WithDevice", - # torch.distributed.pipeline.sync.pipeline - "Pipeline", - # torch.distributed.pipeline.sync.skip.layout - "SkipLayout", - # torch.distributed.pipeline.sync.skip.namespace - "Namespace", - # torch.distributed.pipeline.sync.skip.portal - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange", - # torch.distributed.pipeline.sync.skip.skippable - "Skippable", - # torch.distributed.pipeline.sync.skip.tracker - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - # torch.distributed.pipeline.sync.stream - "CPUStreamType", - # torch.distributed.pipeline.sync.worker - "Task", # torch.distributed.rpc.api "AllGatherStates", "RRef", diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index cee1ec6af2e8..7b9bf536c145 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -144,6 +144,26 @@ Jiterator (beta) jiterator._create_jit_fn jiterator._create_multi_output_jit_fn +TunableOp +--------- + +Some operations could be implemented using more than one library or more than +one technique. For example, a GEMM could be implemented for CUDA or ROCm using +either the cublas/cublasLt libraries or hipblas/hipblasLt libraries, +respectively. How does one know which implementation is the fastest and should +be chosen? That's what TunableOp provides. Certain operators have been +implemented using multiple strategies as Tunable Operators. At runtime, all +strategies are profiled and the fastest is selected for all subsequent +operations. 
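As a rough illustration (a sketch only, assuming a ROCm or CUDA build where ``torch.cuda.tunable`` is available; exact signatures and defaults are listed in the API reference below), enabling TunableOp from Python could look like this:

.. code-block:: python

    import torch
    import torch.cuda.tunable as tunable

    # Turn the feature on and allow new solutions to be tuned.
    tunable.enable(True)
    tunable.tuning_enable(True)

    # Bound the tuning effort and choose where tuned results are persisted
    # (the filename here is just an example).
    tunable.set_max_tuning_iterations(10)
    tunable.set_filename("my_tunableop_results.csv")

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    # The first GEMM of a given shape profiles the available implementations;
    # subsequent calls reuse the fastest one found.
    c = a @ b

    # Flush and inspect what was selected.
    tunable.write_file()
    print(tunable.get_validators())
    print(tunable.get_results())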
+ +See the :doc:`documentation ` for information on how to use it. + +.. toctree:: + :hidden: + + cuda.tunable + + Stream Sanitizer (prototype) ---------------------------- diff --git a/docs/source/cuda.tunable.rst b/docs/source/cuda.tunable.rst new file mode 100644 index 000000000000..52482122ec75 --- /dev/null +++ b/docs/source/cuda.tunable.rst @@ -0,0 +1,32 @@ +.. currentmodule:: torch.cuda.tunable + +TunableOp +========= + +.. note:: + This is a prototype feature, which means it is at an early stage + for feedback and testing, and its components are subject to change. + +Overview +-------- + +.. automodule:: torch.cuda.tunable + +API Reference +------------- + +.. autofunction:: enable +.. autofunction:: is_enabled +.. autofunction:: tuning_enable +.. autofunction:: tuning_is_enabled +.. autofunction:: set_max_tuning_duration +.. autofunction:: get_max_tuning_duration +.. autofunction:: set_max_tuning_iterations +.. autofunction:: get_max_tuning_iterations +.. autofunction:: set_filename +.. autofunction:: get_filename +.. autofunction:: get_results +.. autofunction:: get_validators +.. autofunction:: write_file_on_exit +.. autofunction:: write_file +.. autofunction:: read_file diff --git a/docs/source/distributed.elastic.rst b/docs/source/distributed.elastic.rst index 24d33d1982df..0aabb560c9c8 100644 --- a/docs/source/distributed.elastic.rst +++ b/docs/source/distributed.elastic.rst @@ -29,6 +29,7 @@ Documentation elastic/metrics elastic/events elastic/subprocess_handler + elastic/control_plane .. toctree:: :maxdepth: 1 diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 2f4218a0d980..e1d66d223b2b 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -4,184 +4,425 @@ Pipeline Parallelism #################### -.. note:: ``torch.distributed.pipelining`` is a package migrated from the `PiPPy `_ project. It is currently in alpha state and under extensive development. For examples that work with our APIs, please refer to PiPPy's `examples `_ directory. +.. note:: + ``torch.distributed.pipelining`` is currently in alpha state and under + development. API changes may be possible. It was migrated from the `PiPPy + `_ project. + Why Pipeline Parallel? ********************** -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include *data parallelism*, *tensor/operation parallelism*, and *pipeline parallelism* (or *pipelining*). Pipelining is a technique in which the *code* of the model is partitioned and multiple *micro-batches* execute different parts of the model code concurrently. In many cases, pipeline parallelism can be an effective technique for scaling, in particular for large-scale jobs or bandwidth-limited interconnects. To learn more about pipeline parallelism in deep learning, see `this article `_. +Pipeline Parallelism is one of the **primitive** parallelism for deep learning. +It allows the **execution** of a model to be partitioned such that multiple +**micro-batches** can execute different parts of the model code concurrently. +Pipeline parallelism can be an effective technique for: + +* large-scale training +* bandwidth-limited clusters +* large model inference. + +The above scenarios share a commonality that the computation per device cannot +hide the communication of conventional parallelism, for example, the weight +all-gather of FSDP. + What is ``torch.distributed.pipelining``? 
*****************************************
-.. automodule:: torch.distributed.pipelining
+While promising for scaling, pipelining is often difficult to implement because
+it needs to **partition the execution** of a model in addition to model weights.
+The partitioning of execution often requires intrusive code changes to your
+model. Another aspect of complexity comes from **scheduling micro-batches in a
+distributed environment**, with **data flow dependency** considered.
+
+The ``pipelining`` package provides a toolkit that does said things
+**automatically**, which allows easy implementation of pipeline parallelism
+on **general** models.
+
+It consists of two parts: a
+**splitting frontend** and a **distributed runtime**.
+The splitting frontend takes your model code as-is, splits it up into "model
+partitions", and captures the data-flow relationship. The distributed runtime
+executes the pipeline stages on different devices in parallel, handling things
+like micro-batch splitting, scheduling, communication, and gradient propagation,
+etc.
+
+Overall, the ``pipelining`` package provides the following features:
+
+* Splitting of model code based on simple specification.
+* Rich support for pipeline schedules, including GPipe, 1F1B,
+  Interleaved 1F1B and Looped BFS, and providing the infrastructure for writing
+  customized schedules.
+* First-class support for cross-host pipeline parallelism, as this is where PP
+  is typically used (over slower interconnects).
+* Composability with other PyTorch parallel techniques such as data parallel
+  (DDP, FSDP) or tensor parallel. The `TorchTitan
+  `_ project demonstrates a "3D parallel"
+  application on the Llama model.
+
+
+Step 1: build ``PipelineStage``
+*******************************
+
+Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
+objects that wrap the part of the model running in that stage. The
+``PipelineStage`` is responsible for allocating communication buffers and
+creating send/recv ops to communicate with its peers. It manages intermediate
+buffers e.g. for the outputs of forward that have not been consumed yet, and it
+provides a utility for running the backward pass for the stage model.
+
+A ``PipelineStage`` needs to know the input and output shapes for the stage
+model, so that it can correctly allocate communication buffers. The shapes must
+be static, i.e. at runtime the shapes cannot change from step to step. A class
+``PipeliningShapeError`` will be raised if runtime shapes do not match the
+expected shapes. When composing with other parallelisms or applying mixed
+precision, these techniques must be taken into account so the ``PipelineStage``
+knows the correct shape (and dtype) for the output of the stage module at
+runtime.
+
+Users may construct a ``PipelineStage`` instance directly, by passing in an
+``nn.Module`` representing the portion of the model that should run on the
+stage. This may require changes to the original model code. See the example
+in :ref:`option_1_manual`.
+
+Alternatively, the splitting frontend can use graph partitioning to split your
+model into a series of ``nn.Module`` automatically. This technique requires that
+the model be traceable with ``torch.export``. Composability of the resulting
+``nn.Module`` with other parallelism techniques is experimental, and may require
+some workarounds. Usage of this frontend may be more appealing if the user
+cannot easily change the model code. See :ref:`option_2_tracer` for more
+information.
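For orientation, here is a minimal sketch of constructing a ``PipelineStage`` directly; the stage module, its dimensions, and the one-stage-per-rank layout are made up for illustration, and a fuller example follows in :ref:`option_1_manual`:

.. code-block:: python

    # Minimal sketch only; launch with torchrun so RANK / WORLD_SIZE are set.
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.pipelining import PipelineStage

    class StageBlock(torch.nn.Module):  # hypothetical stage submodule
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(512, 512)

        def forward(self, x):
            return torch.relu(self.lin(x))

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(rank=rank, world_size=world_size)
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")

    # One microbatch worth of input; its shape (and dtype) must stay fixed
    # across steps so that communication buffers can be allocated up front.
    example_input_microbatch = torch.randn(8, 512, device=device)

    stage = PipelineStage(
        StageBlock().to(device),
        stage_index=rank,        # here: one stage per rank
        num_stages=world_size,
        device=device,
        input_args=example_input_microbatch,
    )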
+ + +Step 2: use ``PipelineSchedule`` for execution +********************************************** + +We can now attach the ``PipelineStage`` to a pipeline schedule, and run the +schedule with input data. Here is a GPipe example: -.. currentmodule:: torch.distributed.pipelining +.. code-block:: python -While promising for scaling, pipelining is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. ``torch.distributed.pipelining`` aims to provide **a toolkit that does said things automatically to allow high-productivity scaling of models.** It consists of a **compiler** and a **runtime** stack for easy pipelining of PyTorch models. In particular, it provides the following features: + from torch.distributed.pipelining import ScheduleGPipe -* Splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple. -* Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. It will be also easy to customize your own schedule under this framework. -* First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). -* Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (overall, known as "3d parallelism"). + # Create a schedule + schedule = ScheduleGPipe(stage, n_microbatches) -Examples -******** + # Input data (whole batch) + x = torch.randn(batch_size, in_dim, device=device) -In the `PiPPy `_ repo where this package is migrated from, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the `HuggingFace examples directory `_. Popular examples include: `GPT2 `_, and `LLaMA `_. + # Run the pipeline with input `x` + # `x` will be divided into microbatches automatically + if rank == 0: + schedule.step(x) + else: + output = schedule.step() -Techniques Explained -******************** +Note that the above code needs to be launched for each worker, thus we use a +launcher service to launch multiple processes: -``torch.distributed.pipelining`` consists of two parts: a *compiler* and a *runtime*. The compiler takes your model code, splits it up, and transforms it into a ``Pipe``, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the ``PipelineStage`` in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. +.. code-block:: bash -Splitting a Model with ``pipeline`` -=================================== + torchrun --nproc_per_node=2 example.py -To see how we can split a model into a pipeline, let's first take an example trivial neural network: -.. code-block:: python +Options for Splitting a Model +***************************** - import torch +.. 
_option_1_manual: - class MyNetworkBlock(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.lin = torch.nn.Linear(in_dim, out_dim) +Option 1: splitting a model manually +==================================== - def forward(self, x): - x = self.lin(x) - x = torch.relu(x) - return x +To directly construct a ``PipelineStage``, the user is responsible for providing +a single ``nn.Module`` instance that owns the relevant ``nn.Parameters`` and +``nn.Buffers``, and defines a ``forward()`` method that executes the operations +relevant for that stage. For example, a condensed version of the Transformer +class defined in Torchtitan shows a pattern of building an easily partitionable +model. +.. code-block:: python - class MyNetwork(torch.nn.Module): - def __init__(self, in_dim, layer_dims): + class Transformer(nn.Module): + def __init__(self, model_args: ModelArgs): super().__init__() - prev_dim = in_dim - for i, dim in enumerate(layer_dims): - setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim)) - prev_dim = dim + self.tok_embeddings = nn.Embedding(...) + + # Using a ModuleDict lets us delete layers witout affecting names, + # ensuring checkpoints will correctly save and load. + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(...) - self.num_layers = len(layer_dims) - # 10 output classes - self.output_proj = torch.nn.Linear(layer_dims[-1], 10) + self.output = nn.Linear(...) - def forward(self, x): - for i in range(self.num_layers): - x = getattr(self, f'layer{i}')(x) + def forward(self, tokens: torch.Tensor): + # Handling layers being 'None' at runtime enables easy pipeline splitting + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - return self.output_proj(x) + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h).float() if self.output else h + return output + +A model defined in this manner can be easily configured per stage by first +initializing the whole model (using meta-device to avoid OOM errors), deleting +undesired layers for that stage, and then creating a PipelineStage that wraps +the model. For example: + +.. code-block:: python + + with torch.device("meta"): + assert num_stages == 2, "This is a simple 2-stage example" + + # we construct the entire model, then delete the parts we do not need for this stage + # in practice, this can be done using a helper function that automatically divides up layers across stages. + model = Transformer() + + if stage_index == 0: + # prepare the first stage model + del model.layers["1"] + model.norm = None + model.output = None + + elif stage_index == 1: + # prepare the second stage model + model.tok_embeddings = None + del model.layers["0"] + + from torch.distributed.pipelining import PipelineStage + stage = PipelineStage( + model, + stage_index, + num_stages, + device, + input_args=example_input_microbatch, + ) - in_dim = 512 - layer_dims = [512, 1024, 256] - mn = MyNetwork(in_dim, layer_dims).to(device) +The ``PipelineStage`` requires an example argument ``input_args`` representing +the runtime input to the stage, which would be one microbatch worth of input +data. This argument is passed through the forward method of the stage module to +determine the input and output shapes required for communication. -This network is written as free-form Python code; it has not been modified for any specific parallelism technique. 
+When composing with other Data or Model parallelism techniques, ``output_args`` +may also be required, if the output shape/dtype of the model chunk will be +affected. -Let us see our usage of the ``pipeline`` interface: + +.. _option_2_tracer: + +Option 2: splitting a model automatically +========================================= + +If you have a full model and do not want to spend time on modifying it into a +sequence of "model partitions", the ``pipeline`` API is here to help. +Here is a brief example: + +.. code-block:: python + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(10, 3) + self.layers = torch.nn.ModuleList( + Layer() for _ in range(2) + ) + self.lm = LMHead() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.emb(x) + for layer in self.layers: + x = layer(x) + x = self.lm(x) + return x + + +If we print the model, we can see multiple hierarchies, which makes it hard to split by hand:: + + Model( + (emb): Embedding(10, 3) + (layers): ModuleList( + (0-1): 2 x Layer( + (lin): Linear(in_features=3, out_features=3, bias=True) + ) + ) + (lm): LMHead( + (proj): Linear(in_features=3, out_features=3, bias=True) + ) + ) + +Let us see how the ``pipeline`` API works: .. code-block:: python - from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint + from torch.distributed.pipelining import pipeline, SplitPoint - annotate_split_points(mn, {'layer0': SplitPoint.END, - 'layer1': SplitPoint.END}) + # An example micro-batch input + x = torch.LongTensor([1, 2, 4, 5]) - batch_size = 32 - example_input = torch.randn(batch_size, in_dim, device=device) - chunks = 4 + pipe = pipeline( + module=mod, + mb_args=(x,), + split_spec={ + "layers.1": SplitPoint.BEGINNING, + } + ) - pipe = pipeline(mn, chunks, example_args=(example_input,)) - print(pipe) +The ``pipeline`` API splits your model given a ``split_spec``, where +``SplitPoint.BEGINNING`` stands for adding a split point +*before* execution of certain submodule in the ``forward`` function, and +similarly, ``SplitPoint.END`` for split point *after* such. -:: +If we ``print(pipe)``, we can see:: - ************************************* pipe ************************************* GraphModule( (submod_0): GraphModule( - (layer0): InterpreterModule( - (lin): InterpreterModule() + (emb): InterpreterModule() + (layers): Module( + (0): InterpreterModule( + (lin): InterpreterModule() + ) ) ) (submod_1): GraphModule( - (layer1): InterpreterModule( - (lin): InterpreterModule() + (layers): Module( + (1): InterpreterModule( + (lin): InterpreterModule() + ) ) - ) - (submod_2): GraphModule( - (layer2): InterpreterModule( - (lin): InterpreterModule() + (lm): InterpreterModule( + (proj): InterpreterModule() ) - (output_proj): InterpreterModule() ) ) - def forward(self, arg8_1): - submod_0 = self.submod_0(arg8_1); arg8_1 = None + def forward(self, x): + submod_0 = self.submod_0(x); x = None submod_1 = self.submod_1(submod_0); submod_0 = None - submod_2 = self.submod_2(submod_1); submod_1 = None - return (submod_2,) + return (submod_1,) -So what's going on here? First, ``pipeline`` turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into *pipeline stages*. Stages are represented as ``submod_N`` submodules, where ``N`` is a natural number. -We used ``annotate_split_points`` to specify that the code should be split and the end of ``layer0`` and ``layer1``. 
Our code has thus been split into *three* pipeline stages. Our library also provides ``SplitPoint.BEGINNING`` if a user wants to split before certain annotation point. +The "model partitions" are represented by submodules (``submod_0``, +``submod_1``), each of which is reconstructed with original model operations, weights +and hierarchies. In addition, a "root-level" ``forward`` function is +reconstructed to capture the data flow between those partitions. Such data flow +will be replayed by the pipeline runtime later, in a distributed fashion. -While the ``annotate_split_points`` API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: ``pipe_split()``. For details, you can read `this example `_. +The ``Pipe`` object provides a method for retrieving the "model partitions": -This covers the basic usage of the ``Pipe`` API. For more information, please see the documentation. +.. code-block:: python -Using ``PipelineSchedule`` for Execution -======================================== + stage_mod : nn.Module = pipe.get_stage_module(stage_idx) -After transforming the model into a ``Pipe`` representation, we can run its stages in a distributed *runtime*. This can be done in two steps: -* instantiate a ``PipelineStage`` from a stage module of ``Pipe``; -* run the ``PipelineStage`` according to a ``PipelineSchedule``. +The returned ``stage_mod`` is a ``nn.Module``, with which you can create an +optimizer, save or load checkpoints, or apply other parallelisms. -First off, let us instantiate a ``PipelineStage`` instance: +``Pipe`` also allows you to create a distributed stage runtime on a device given +a ``ProcessGroup``: .. code-block:: python - # We are using `torchrun` to run this example with multiple processes. - # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) + stage = pipe.build_stage(stage_idx, device, group) - # Initialize distributed environment - import torch.distributed as dist - dist.init_process_group(rank=rank, world_size=world_size) +Alternatively, if you would like to build the stage runtime later after some +modification to the ``stage_mod``, you can use a functional version of the +``build_stage`` API. For example: - # Pipeline stage is our main pipeline runtime. It takes in the pipe object, - # the rank of this process, and the device. - from torch.distributed.pipelining import PipelineStage - stage = PipelineStage(pipe, rank, device) +.. code-block:: python -We can now attach the ``PipelineStage`` to a pipeline schedule, GPipe for example, and run with data: + from torch.distributed.pipelining import build_stage + from torch.nn.parallel import DistributedDataParallel -.. code-block:: python + dp_mod = DistributedDataParallel(stage_mod) + info = pipe.info() + stage = build_stage(dp_mod, stage_idx, info, device, group) - from torch.distributed.pipelining import ScheduleGPipe - schedule = ScheduleGPipe(stage, chunks) +.. note:: + The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your + model into a single graph. If your model is not full-graph'able, you can use + our manual frontend below. - # Input data - x = torch.randn(batch_size, in_dim, device=device) - # Run the pipeline with input `x`. 
Divide the batch into 4 micro-batches - # and run them in parallel on the pipeline - if rank == 0: - schedule.step(x) - else: - output = schedule.step() +Hugging Face Examples +********************* -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use ``torchrun`` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named `example.py `_ and then run it with ``torchrun`` like so: +In the `PiPPy `_ repo where this package was +original created, we kept examples based on unmodified Hugging Face models. +See the `examples/huggingface +`_ directory. -.. code-block:: bash +Examples include: + +* `GPT2 `_ +* `Llama `_ + + +Technical Deep Dive +******************* + +How does the ``pipeline`` API split a model? +============================================ + +First, the ``pipeline`` API turns our model into a directed acyclic graph (DAG) +by tracing the model. It traces the model using ``torch.export`` -- a PyTorch 2 +full-graph capturing tool. + +Then, it groups together the **operations and parameters** needed by a stage +into a reconstructed submodule: ``submod_0``, ``submod_1``, ... + +Different from conventional submodule access methods like ``Module.children()``, +the ``pipeline`` API does not only cut the module structure of your model, but +also the **forward** function of your model. + +This is necessary because model structure like ``Module.children()`` merely +captures information during ``Module.__init__()``, and does not capture any +information about ``Module.forward()``. Said differently, ``Module.children()`` +lacks information about the following aspects key to pipelininig: + +* Execution order of child modules in ``forward`` +* Activation flows between child modules +* Whether there are any functional operators between child modules (for example, + ``relu`` or ``add`` operations will not be captured by ``Module.children()``). - torchrun --nproc_per_node=3 example.py +The ``pipeline`` API, on the contrary, makes sure that the ``forward`` behavior +is truly preserved. It also captures the activation flow between the partitions, +helping the distributed runtime to make correct send/receive calls without human +intervention. + +Another flexibility of the ``pipeline`` API is that split points can be at +arbitrary levels within your model hierarchy. In the split partitions, the original model +hierarchy related to that partition will be reconstructed at no cost to you. +At a result, fully-qualified names (FQNs) pointing to a submodule or parameter +would be still valid, and services that relies on FQNs (such as FSDP, TP or +checkpointing) can still run with your partitioned modules with almost zero code +change. + + +Implementing Your Own Schedule +****************************** + +You can implement your own pipeline schedule by extending one of the following two class: + +* ``PipelineScheduleSingle`` +* ``PipelineScheduleMulti`` + +``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. +``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. -Pipeline Transformation APIs +For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. +Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. + + +API Reference +************* + +.. 
automodule:: torch.distributed.pipelining + +Model Split APIs ============================ The following set of APIs transform your model into a pipeline representation. @@ -194,14 +435,8 @@ The following set of APIs transform your model into a pipeline representation. .. autoclass:: Pipe -.. autofunction:: annotate_split_points - .. autofunction:: pipe_split -.. autoclass:: ArgsChunkSpec - -.. autoclass:: KwargsChunkSpec - Microbatch Utilities ==================== @@ -218,20 +453,20 @@ Microbatch Utilities Pipeline Stages =============== -.. automodule:: torch.distributed.pipelining.PipelineStage +.. automodule:: torch.distributed.pipelining.stage -.. currentmodule:: torch.distributed.pipelining.PipelineStage +.. currentmodule:: torch.distributed.pipelining.stage .. autoclass:: PipelineStage -.. autoclass:: ManualPipelineStage +.. autofunction:: build_stage Pipeline Schedules ================== -.. automodule:: torch.distributed.pipelining.PipelineSchedule +.. automodule:: torch.distributed.pipelining.schedules -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule +.. currentmodule:: torch.distributed.pipelining.schedules .. autoclass:: ScheduleGPipe @@ -241,22 +476,8 @@ Pipeline Schedules .. autoclass:: ScheduleLoopedBFS -Implementing Your Own Schedule -============================== - -You can implement your own pipeline schedule by extending one of the following two class: - -* ``PipelineScheduleSingle`` -* ``PipelineScheduleMulti`` - -``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. -``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. - -For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. -Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. - -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule - .. autoclass:: PipelineScheduleSingle + :members: .. autoclass:: PipelineScheduleMulti + :members: diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 0b091d567031..f4c73b9381e5 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,9 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.pipeline -.. py:module:: torch.distributed.pipeline.sync -.. py:module:: torch.distributed.pipeline.sync.skip .. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks @@ -964,22 +961,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.optim.post_localSGD_optimizer .. py:module:: torch.distributed.optim.utils .. py:module:: torch.distributed.optim.zero_redundancy_optimizer -.. py:module:: torch.distributed.pipeline.sync.batchnorm -.. py:module:: torch.distributed.pipeline.sync.checkpoint -.. py:module:: torch.distributed.pipeline.sync.copy -.. py:module:: torch.distributed.pipeline.sync.dependency -.. py:module:: torch.distributed.pipeline.sync.microbatch -.. py:module:: torch.distributed.pipeline.sync.phony -.. py:module:: torch.distributed.pipeline.sync.pipe -.. py:module:: torch.distributed.pipeline.sync.pipeline -.. py:module:: torch.distributed.pipeline.sync.skip.layout -.. 
py:module:: torch.distributed.pipeline.sync.skip.namespace -.. py:module:: torch.distributed.pipeline.sync.skip.portal -.. py:module:: torch.distributed.pipeline.sync.skip.skippable -.. py:module:: torch.distributed.pipeline.sync.skip.tracker -.. py:module:: torch.distributed.pipeline.sync.stream -.. py:module:: torch.distributed.pipeline.sync.utils -.. py:module:: torch.distributed.pipeline.sync.worker .. py:module:: torch.distributed.remote_device .. py:module:: torch.distributed.rendezvous .. py:module:: torch.distributed.rpc.api diff --git a/docs/source/elastic/control_plane.rst b/docs/source/elastic/control_plane.rst new file mode 100644 index 000000000000..c37454cf1b0a --- /dev/null +++ b/docs/source/elastic/control_plane.rst @@ -0,0 +1,10 @@ +Control Plane +============= + +.. automodule:: torch.distributed.elastic.control_plane +.. currentmodule:: torch.distributed.elastic.control_plane + +This module contains optional helpers that add extra debug and control handlers +into your application. + +.. autofunction:: torch.distributed.elastic.control_plane.worker_main diff --git a/docs/source/elastic/events.rst b/docs/source/elastic/events.rst index 86d0be8dad52..c32136d00302 100644 --- a/docs/source/elastic/events.rst +++ b/docs/source/elastic/events.rst @@ -10,6 +10,8 @@ API Methods .. autofunction:: torch.distributed.elastic.events.record +.. autofunction:: torch.distributed.elastic.events.construct_and_record_rdzv_event + .. autofunction:: torch.distributed.elastic.events.get_logging_handler Event Objects diff --git a/docs/source/export.rst b/docs/source/export.rst index a4217e8081ba..29069d3228e4 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -632,23 +632,17 @@ number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support :ref:`torch.cond ` to express if-else like control flow (more coming soon!). -Missing Meta Kernels for Operators -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Missing Fake/Meta/Abstract Kernels for Operators +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When tracing, a META implementation (or "meta kernel") is required for all -operators. This is used to reason about the input/output shapes for this -operator. +When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is +required for all operators. This is used to reason about the input/output shapes +for this operator. -To register a meta kernel for a C++ Custom Operator, please refer to -`this documentation `__. - -The official API for registering custom meta kernels for custom ops implemented -in python is currently undergoing development. While the final API is being -refined, you can refer to the documentation -`here `_. +Please see :func:`torch.library.register_fake` for more details. In the unfortunate case where your model uses an ATen operator that is does not -have a meta kernel implementation yet, please file an issue. +have a FakeTensor kernel implementation yet, please file an issue. Read More @@ -689,6 +683,7 @@ API Reference .. automethod:: dynamic_shapes +.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes .. autoclass:: Constraint .. autoclass:: ExportedProgram diff --git a/docs/source/fx.rst b/docs/source/fx.rst index e9b7cd2d5723..0a0af6254a5d 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -1175,6 +1175,7 @@ API Reference .. py:module:: torch.fx.passes.fake_tensor_prop .. py:module:: torch.fx.passes.graph_drawer .. 
py:module:: torch.fx.passes.graph_manipulation +.. py:module:: torch.fx.passes.graph_transform_observer .. py:module:: torch.fx.passes.infra.partitioner .. py:module:: torch.fx.passes.infra.pass_base .. py:module:: torch.fx.passes.infra.pass_manager diff --git a/docs/source/index.rst b/docs/source/index.rst index ea704f20c3af..dcaadcbb63ed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,7 +103,6 @@ Features described in this documentation are classified by release status: optim complex_numbers ddp_comm_hooks - pipeline quantization rpc torch.random diff --git a/docs/source/library.rst b/docs/source/library.rst index 236da45f93c1..f632d93d1ec4 100644 --- a/docs/source/library.rst +++ b/docs/source/library.rst @@ -1,3 +1,5 @@ +.. _torch-library-docs: + torch.library =================================== .. py:module:: torch.library @@ -9,7 +11,8 @@ custom operators, and extending operators defined with PyTorch's C++ operator registration APIs (e.g. aten operators). For a detailed guide on effectively using these APIs, please see -`this gdoc `_ +Please see :ref:`custom-ops-landing-page` +for more details on how to effectively use these APIs. Testing custom ops ------------------ diff --git a/docs/source/mps.rst b/docs/source/mps.rst index bab0d3378ea8..86195242566f 100644 --- a/docs/source/mps.rst +++ b/docs/source/mps.rst @@ -17,6 +17,7 @@ torch.mps set_per_process_memory_fraction current_allocated_memory driver_allocated_memory + recommended_max_memory MPS Profiler ------------ diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index f2f5b5195dcb..b729c061b2a6 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined init is_available is_initialized + set_device set_stream stream synchronize diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index 81ebf64bc43a..f070f2204183 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -463,7 +463,7 @@ functions are used in the research community since complex numbers are not part ordered field and so having complex valued loss does not make much sense. It also turns out that no interesting real-valued objective fulfill the -Cauchy-Riemann equations. So the theory with homomorphic function cannot be +Cauchy-Riemann equations. So the theory with holomorphic function cannot be used for optimization and most people therefore use the Wirtinger calculus. Wirtinger Calculus comes into the picture ... @@ -602,7 +602,7 @@ Solving the above equations for :math:`\frac{\partial L}{\partial u}` and :math: .. math:: \begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ - \frac{\partial L}{\partial v} = -1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) + \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned} :label: [3] @@ -610,9 +610,9 @@ Substituting :eq:`[3]` in :eq:`[1]`, we get: .. 
math:: \begin{aligned} - \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} - 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ + \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ - &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ + &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned} diff --git a/docs/source/notes/custom_operators.rst b/docs/source/notes/custom_operators.rst new file mode 100644 index 000000000000..2cdf214351b0 --- /dev/null +++ b/docs/source/notes/custom_operators.rst @@ -0,0 +1,56 @@ +.. _custom-ops-landing-page: + +PyTorch Custom Operators Landing Page +===================================== + +PyTorch offers a large library of operators that work on Tensors (e.g. :func:`torch.add`, +:func:`torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch +and get it to work with subsystems like :func:`torch.compile`, autograd, and :func:`torch.vmap`. +In order to do so, you must register the custom operation with PyTorch via the Python +:ref:`torch-library-docs` or C++ TORCH_LIBRARY APIs. + +TL;DR +----- + +How do I author a custom op from Python? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. + [comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4 + +Please see the `Python Custom Operators tutorial `_ + + +How do I integrate custom C++ and/or CUDA code with PyTorch? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. + [comment] TODO(rzou): The following will be a link to a tutorial on the PyTorch tutorials site in 2.4 + +Please see the `Custom C++ and CUDA Operators tutorial `_ + + +For more details +^^^^^^^^^^^^^^^^ + +Please see `The Custom Operators Manual (gdoc) `_ +(we're working on moving the information to our docs site). We recommend that you +first read one of the tutorials above and then use the Custom Operators Manual as a reference; +it is not meant to be read head to toe. + +When should I create a Custom Operator? +--------------------------------------- +If your operation is expressible as a composition of built-in PyTorch operators +then please write it as a Python function and call it instead of creating a +custom operator. Use the operator registration APIs to create a custom op if you +are calling into some library that PyTorch doesn't understand (e.g. custom C/C++ code, +a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions). + +Why should I create a Custom Operator? 
+-------------------------------------- + +It is possible to use a C/C++/CUDA kernel by grabbing a Tensor's data pointer +and passing it to a pybind'ed kernel. However, this approach doesn't compose with +PyTorch subsystems like autograd, torch.compile, vmap, and more. In order +for an operation to compose with PyTorch subsystems, it must be registered +via the operator registration APIs. diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 80796375c3fe..bf69d0e012f6 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -4,6 +4,18 @@ Extending PyTorch In this note we'll cover ways of extending :mod:`torch.nn`, :mod:`torch.autograd`, :mod:`torch`, and writing custom C++ extensions. +Adding new operators +-------------------- + +PyTorch offers a large library of operators that work on Tensors (e.g. :func:`torch.add`, +:func:`torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch +and have it behave like PyTorch's built-in operators. In order to do so, you must +register the custom operation with PyTorch via the Python :ref:`torch-library-docs` or C++ TORCH_LIBRARY +APIs. + + +Please see :ref:`custom-ops-landing-page` for more details. + .. _extending-autograd: Extending :mod:`torch.autograd` @@ -968,13 +980,3 @@ Which prints the following, with extra comments:: Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{}) Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{}) - - -Writing custom C++ extensions ------------------------------ - -See this -`PyTorch tutorial `_ -for a detailed explanation and examples. - -Documentations are available at :doc:`../cpp_extension`. diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst deleted file mode 100644 index 94d730ee223d..000000000000 --- a/docs/source/pipeline.rst +++ /dev/null @@ -1,85 +0,0 @@ -.. _pipeline-parallelism: - -Pipeline Parallelism -==================== - -Pipeline parallelism was original introduced in the -`Gpipe `__ paper and is an efficient -technique to train large models on multiple GPUs. - -.. warning :: - torch.distributed.pipeline is deprecated, so is this document. For - up-to-date pipeline parallel implementation, please refer to the - `PiPPy `__ library under the PyTorch - organization (Pipeline Parallelism for PyTorch). - -Model Parallelism using multiple GPUs -------------------------------------- - -Typically for large models which don't fit on a single GPU, model parallelism -is employed where certain parts of the model are placed on different GPUs. -Although, if this is done naively for sequential models, the training process -suffers from GPU under utilization since only one GPU is active at one time as -shown in the figure below: - -.. figure:: _static/img/pipeline_parallelism/no_pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that only 1 GPU is utilized at a time - (`image source `__). - -Pipelined Execution -------------------- - -To alleviate this problem, pipeline parallelism splits the input minibatch into -multiple microbatches and pipelines the execution of these microbatches across -multiple GPUs. This is outlined in the figure below: - -.. 
figure:: _static/img/pipeline_parallelism/pipe.png - - The figure represents a model with 4 layers placed on 4 different GPUs - (vertical axis). The horizontal axis represents training this model through - time demonstrating that the GPUs are utilized much more efficiently. - However, there still exists a bubble (as demonstrated in the figure) where - certain GPUs are not utilized. - (`image source `__). - -Pipe APIs in PyTorch --------------------- -.. autoclass:: torch.distributed.pipeline.sync.Pipe - :members: forward - -Skip connections -^^^^^^^^^^^^^^^^ - -Certain models like `ResNeXt `__ -are not completely sequential and have skip connections between layers. -Naively implementing as part of pipeline parallelism would imply that -we need to copy outputs for certain layers through multiple GPUs till -we eventually reach the GPU where the layer for the skip connection resides. -To avoid this copy overhead, we provide APIs below to stash and pop Tensors -in different layers of the model. - -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash -.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop -.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables - -Tutorials ---------- - -The following tutorials give a good overview of how to use the -:class:`~torch.distributed.pipeline.sync.Pipe` API to train your models with the -rest of the components that PyTorch provides: - -- `Training Transformer models using Pipeline Parallelism `__ -- `Training Transformer models using Distributed Data Parallel and Pipeline Parallelism `__ - -Acknowledgements ----------------- - -The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and -`torchgpipe `__. We would like to -thank both teams for their contributions and guidance towards bringing pipeline -parallelism into PyTorch. diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 218c83d0a373..3f9a96ac7da6 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -212,6 +212,37 @@ Tensor class reference (see :ref:`tensor-creation-ops`). - To create a tensor with similar type but different size as another tensor, use ``tensor.new_*`` creation ops. + - There is a legacy constructor ``torch.Tensor`` whose use is discouraged. + Use :func:`torch.tensor` instead. + +.. method:: Tensor.__init__(self, data) + + This constructor is deprecated, we recommend using :func:`torch.tensor` instead. + What this constructor does depends on the type of ``data``. + + * If ``data`` is a Tensor, returns an alias to the original Tensor. Unlike + :func:`torch.tensor`, this tracks autograd and will propagate gradients to + the original Tensor. ``device`` kwarg is not supported for this ``data`` type. + + * If ``data`` is a sequence or nested sequence, create a tensor of the default + dtype (typically ``torch.float32``) whose data is the values in the + sequences, performing coercions if necessary. Notably, this differs from + :func:`torch.tensor` in that this constructor will always construct a float + tensor, even if the inputs are all integers. + + * If ``data`` is a :class:`torch.Size`, returns an empty tensor of that size. + + This constructor does not support explicitly specifying ``dtype`` or ``device`` of + the returned tensor. We recommend using :func:`torch.tensor` which provides this + functionality. + + Args: + data (array_like): The tensor to construct from. 
+ + Keyword args: + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + .. autoattribute:: Tensor.T .. autoattribute:: Tensor.H @@ -748,4 +779,5 @@ Tensor class reference Tensor.where Tensor.xlogy Tensor.xlogy_ + Tensor.xpu Tensor.zero_ diff --git a/docs/source/torch.compiler_aot_inductor.rst b/docs/source/torch.compiler_aot_inductor.rst index 0ebd03bbcecf..257f16f40cc0 100644 --- a/docs/source/torch.compiler_aot_inductor.rst +++ b/docs/source/torch.compiler_aot_inductor.rst @@ -37,7 +37,9 @@ For more details on ``torch.export``, you can refer to the :ref:`torch.export do If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, the following code will compile the model into a shared library for CUDA execution. - Otherwise, the compiled artifact will run on CPU. + Otherwise, the compiled artifact will run on CPU. For better performance during CPU inference, + it is suggested to enable freezing by setting `export TORCHINDUCTOR_FREEZING=1` + before running the Python script below. .. code-block:: python diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst index aeaf308ac090..a5883ce015be 100644 --- a/docs/source/torch.compiler_faq.rst +++ b/docs/source/torch.compiler_faq.rst @@ -37,7 +37,7 @@ backwards ops, due to how AOTAutograd compiled functions interact with dispatcher hooks. The basic strategy for optimizing DDP with Dynamo is outlined in -`distributed.py `__ +`distributed.py `__ where the main idea will be to graph break on `DDP bucket boundaries `__. @@ -186,7 +186,7 @@ The above are general principles for accelerating PyTorch code but different backends will each make different tradeoffs on what to optimize. For example Inductor first takes care of fusing whatever it can and only then generates `Triton `__ -kernels. It can also +kernels. Triton in addition offers speedups because of automatic memory coalescing, memory management and scheduling within each Streaming diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst index 624b351d6fa8..caec0760acc7 100644 --- a/docs/source/torch.compiler_get_started.rst +++ b/docs/source/torch.compiler_get_started.rst @@ -64,7 +64,7 @@ the following: xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex - tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp0 = tl.load(in_ptr0 + (x0), xmask, other=0.0) tmp1 = tl.cos(tmp0) tmp2 = tl.sin(tmp1) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) diff --git a/docs/source/torch_environment_variables.rst b/docs/source/torch_environment_variables.rst index f63760de87e9..04feed91de4a 100644 --- a/docs/source/torch_environment_variables.rst +++ b/docs/source/torch_environment_variables.rst @@ -24,3 +24,4 @@ If you find anything in this documentation that is missing, incorrect, or could debugging_environment_variables miscellaneous_environment_variables logging + torch_nccl_environment_variables diff --git a/docs/source/torch_nccl_environment_variables.rst b/docs/source/torch_nccl_environment_variables.rst new file mode 100644 index 000000000000..a2498027e7ff --- /dev/null +++ b/docs/source/torch_nccl_environment_variables.rst @@ -0,0 +1,35 @@ +.. 
_torch_nccl_environment_variables: +PyTorch ProcessGroupNCCL Environment Variables +============================================== +For more information on the environment variables, see `ProcessGroupNCCL Environment Variables `_. + +.. list-table:: + :header-rows: 1 + + * - Variable + - Description + * - ``TORCH_NCCL_HIGH_PRIORITY`` + - Controls whether to use a high-priority stream for the NCCL communicator. + * - ``TORCH_NCCL_BLOCKING_WAIT`` + - Controls whether ``wait()`` is blocking or non-blocking. + * - ``TORCH_NCCL_DUMP_ON_TIMEOUT`` + - Controls whether debugging info is dumped when a watchdog timeout or exception is detected. This variable must be set together with TORCH_NCCL_TRACE_BUFFER_SIZE larger than 0. + * - ``TORCH_NCCL_DESYNC_DEBUG`` + - Controls whether desync debugging is enabled. This is helpful for figuring out the culprit rank of a collective desync. + * - ``TORCH_NCCL_ENABLE_TIMING`` + - If set to ``1``, enables recording start events for all ProcessGroupNCCL collectives and computes accurate per-collective timing. + * - ``TORCH_NCCL_ENABLE_MONITORING`` + - If set to ``1``, enables a monitoring thread that aborts the process when the ProcessGroupNCCL watchdog thread gets stuck and no heartbeat is detected after TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL APIs that may hang. It is useful for preventing jobs from being stuck longer than necessary and tying up cluster resources. + * - ``TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC`` + - Controls the watchdog heartbeat timeout period after which the monitoring thread will abort the process. + * - ``TORCH_NCCL_TRACE_BUFFER_SIZE`` + - The maximum number of events stored in the flight recorder's ring buffer. One event could be, for example, the start or end of a collective. Set to 0 to disable the trace buffer and the debugging-info dump. + * - ``TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC`` + - Controls how much extra time to wait for the debugging info to be dumped before exiting and throwing a timeout exception. + * - ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE`` + - The file into which the debugging info is dumped. + * - ``TORCH_NCCL_DEBUG_INFO_PIPE_FILE`` + - The pipe file used to trigger a debugging dump manually; writing anything into the pipe triggers the dump. + * - ``TORCH_NCCL_NAN_CHECK`` + - Controls whether to enable NaN checks on inputs; an error is thrown if a NaN is detected. 
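The table above documents how ProcessGroupNCCL reads its configuration from the environment. As a rough, non-authoritative sketch of how a few of these knobs might be exercised, the snippet below sets them from Python before creating a single-rank NCCL process group. Only the variable names come from the table; the chosen values, the dump-file path prefix, and the single-process rendezvous are assumptions for illustration.

.. code-block:: python

    import os
    import torch.distributed as dist

    # Illustrative settings only -- these variables are read when the process
    # group is created, so they must be set before init_process_group().
    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"       # dump debug info on watchdog timeout
    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"  # must be > 0 for the dump to record events
    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/nccl_trace_rank_"  # hypothetical path prefix

    # Single-process rendezvous purely for illustration; real jobs are usually
    # launched with torchrun, which exports these for every rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=0, world_size=1)

    # ... run collectives; if the watchdog times out, the flight-recorder
    # contents are written according to the settings above ...

    dist.destroy_process_group()

In a real multi-rank job these variables are typically exported by the launcher or job scheduler so that every rank sees the same settings.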
diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 066f9517acef..7f5564c13664 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -1640,16 +1640,6 @@ static PyObject* _dims(PyObject *self, PY_END(nullptr) } -static int64_t dim_index(const std::vector>& dims, mpy::hdl dim) { - for (int64_t i = 0, N = dims.size(); i < N; ++i) { - if (dims[i].ptr() == dim.ptr()) { - return i; - } - } - return -1; -} - - struct DotPart { Slice dims; size_t total_size = 1; diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h index de82b5af95a4..f1eb87265372 100644 --- a/functorch/csrc/dim/minpybind.h +++ b/functorch/csrc/dim/minpybind.h @@ -385,10 +385,6 @@ bool is_int(handle h) { return PyLong_Check(h.ptr()); } -bool is_float(handle h) { - return PyFloat_Check(h.ptr()); -} - bool is_none(handle h) { return h.ptr() == Py_None; } diff --git a/mypy.ini b/mypy.ini index 48bd363ef6d1..c4fef0f5ba6f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,7 @@ warn_redundant_casts = True show_error_codes = True show_column_numbers = True check_untyped_defs = True +disallow_untyped_defs = True follow_imports = normal local_partial_types = True enable_error_code = possibly-undefined @@ -294,3 +295,10 @@ ignore_missing_imports = True [mypy-torch_xla.*] ignore_missing_imports = True + +# +# Third party dependencies that are optional. +# + +[mypy-redis] +ignore_missing_imports = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 133c86047606..24a917b80847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ ignore = [ # these ignores are from flake8-bugbear; please fix! "B007", "B008", "B017", "B018", # Useless expression - "B019", "B023", "B028", # No explicit `stacklevel` keyword argument found "E402", @@ -69,7 +68,6 @@ ignore = [ "PERF401", "PERF403", # these ignores are from PYI; please fix! - "PYI019", "PYI024", "PYI036", "PYI041", @@ -126,12 +124,15 @@ select = [ "PT025", "PT026", "PYI", + "Q003", # avoidable escaped quote + "Q004", # unnecessary escaped quote "RSE", "RUF008", # mutable dataclass default "RUF015", # access first ele in constant time "RUF016", # type error non-integer index "RUF017", "RUF018", # no assignment in assert + "TCH", "TRY002", # ban vanilla raise (todo fix NOQAs) "TRY302", "TRY401", # verbose-log-message @@ -175,6 +176,10 @@ select = [ # autogenerated #TODO figure out why file level noqa is ignored "torch/_inductor/fx_passes/serialized_patterns/**" = ["F401", "F501"] "torch/onnx/**" = [ + "TCH001", # beartype may need runtime types + "TCH002", + "TCH003", + "TCH004", "UP037", # ONNX does runtime type checking ] diff --git a/requirements.txt b/requirements.txt index 09259eb5c23c..cc1616a1d99c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ networkx jinja2 fsspec lintrunner +ninja # setuptools was removed from default python install setuptools ; python_version >= "3.12" packaging diff --git a/scripts/appveyor/install.bat b/scripts/appveyor/install.bat deleted file mode 100644 index cd87d6273160..000000000000 --- a/scripts/appveyor/install.bat +++ /dev/null @@ -1,10 +0,0 @@ -:: Installation scripts for appveyor. 
- -@echo on - -if "%USE_CUDA%" == "ON" call %~dp0%install_cuda.bat - -:: Miniconda path for appveyor -set PATH=C:\Miniconda-x64;C:\Miniconda-x64\Scripts;%PATH% -:: Install numpy -conda install -y numpy diff --git a/scripts/appveyor/install_cuda.bat b/scripts/appveyor/install_cuda.bat deleted file mode 100644 index c8c86b002e5b..000000000000 --- a/scripts/appveyor/install_cuda.bat +++ /dev/null @@ -1,22 +0,0 @@ -@echo on - -appveyor DownloadFile ^ - https://developer.nvidia.com/compute/cuda/8.0/prod/local_installers/cuda_8.0.44_windows-exe ^ - -FileName cuda_8.0.44_windows.exe -appveyor Downloadfile ^ - http://developer.download.nvidia.com/compute/redist/cudnn/v5.1/cudnn-8.0-windows10-x64-v5.1.zip ^ - -FileName cudnn-8.0-windows10-x64-v5.1.zip - -cuda_8.0.44_windows.exe -s compiler_8.0 cublas_8.0 cublas_dev_8.0 cudart_8.0 curand_8.0 curand_dev_8.0 nvrtc_8.0 nvrtc_dev_8.0 -set PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v8.0\libnvvp;%PATH% - -7z x cudnn-8.0-windows10-x64-v5.1.zip -copy cuda\include\cudnn.h ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\include\" -copy cuda\lib\x64\cudnn.lib ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\x64\" -copy cuda\bin\cudnn64_5.dll ^ - "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin\" - -:: Make sure that nvcc is working correctly. -nvcc -V || exit /b diff --git a/scripts/install_triton_wheel.sh b/scripts/install_triton_wheel.sh index 269b80d07599..793c9a604edf 100755 --- a/scripts/install_triton_wheel.sh +++ b/scripts/install_triton_wheel.sh @@ -1,11 +1,23 @@ #!/bin/bash # Updates Triton to the pinned version for this copy of PyTorch BRANCH=$(git rev-parse --abbrev-ref HEAD) -TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" -DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" -if [[ "$BRANCH" =~ .*release.* ]]; then - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION +if [[ -z "${USE_XPU}" ]]; then + # Default install from PyTorch source + + TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" + DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" + if [[ "$BRANCH" =~ .*release.* ]]; then + pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION + else + pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+$(head -c 10 .ci/docker/ci_commit_pins/triton.txt) + fi else - pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+$(head -c 10 .ci/docker/ci_commit_pins/triton.txt) + # Always install Triton for XPU from source + + TRITON_XPU_REPO="https://github.com/intel/intel-xpu-backend-for-triton" + TRITON_XPU_COMMIT_ID="$(cat .ci/docker/ci_commit_pins/triton-xpu.txt)" + + # force-reinstall to ensure the pinned version is installed + pip install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python" fi diff --git a/scripts/model_zoo/update-caffe2-models.py b/scripts/model_zoo/update-caffe2-models.py deleted file mode 100755 index 1053530d05c5..000000000000 --- a/scripts/model_zoo/update-caffe2-models.py +++ /dev/null @@ -1,175 +0,0 @@ -#! 
/usr/bin/env python3 - -import os -import subprocess -import sys -import tarfile -import tempfile - -from urllib.request import urlretrieve - -from caffe2.python.models.download import ( - deleteDirectory, - downloadFromURLToFile, - getURLFromName, -) - - -class SomeClass: - # largely copied from - # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py - def _download(self, model): - model_dir = self._caffe2_model_dir(model) - assert not os.path.exists(model_dir) - os.makedirs(model_dir) - for f in ["predict_net.pb", "init_net.pb", "value_info.json"]: - url = getURLFromName(model, f) - dest = os.path.join(model_dir, f) - try: - try: - downloadFromURLToFile(url, dest, show_progress=False) - except TypeError: - # show_progress not supported prior to - # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 - # (Sep 17, 2017) - downloadFromURLToFile(url, dest) - except Exception as e: - print(f"Abort: {e}") - print("Cleaning up...") - deleteDirectory(model_dir) - sys.exit(1) - - def _caffe2_model_dir(self, model): - caffe2_home = os.path.expanduser("~/.caffe2") - models_dir = os.path.join(caffe2_home, "models") - return os.path.join(models_dir, model) - - def _onnx_model_dir(self, model): - onnx_home = os.path.expanduser("~/.onnx") - models_dir = os.path.join(onnx_home, "models") - model_dir = os.path.join(models_dir, model) - return model_dir, os.path.dirname(model_dir) - - # largely copied from - # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py - def _prepare_model_data(self, model): - model_dir, models_dir = self._onnx_model_dir(model) - if os.path.exists(model_dir): - return - os.makedirs(model_dir) - url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz" - - # On Windows, NamedTemporaryFile cannot be opened for a - # second time - download_file = tempfile.NamedTemporaryFile(delete=False) - try: - download_file.close() - print(f"Start downloading model {model} from {url}") - urlretrieve(url, download_file.name) - print("Done") - with tarfile.open(download_file.name) as t: - t.extractall(models_dir) - except Exception as e: - print(f"Failed to prepare data for model {model}: {e}") - raise - finally: - os.remove(download_file.name) - - -models = [ - "bvlc_alexnet", - "densenet121", - "inception_v1", - "inception_v2", - "resnet50", - # TODO currently onnx can't translate squeezenet :( - # 'squeezenet', - "vgg16", - # TODO currently vgg19 doesn't work in the CI environment, - # possibly due to OOM - # 'vgg19' -] - - -def download_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: downloading", model) - caffe2_model_dir = sc._caffe2_model_dir(model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - if not os.path.exists(caffe2_model_dir): - sc._download(model) - if not os.path.exists(onnx_model_dir): - sc._prepare_model_data(model) - - -def generate_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: generating", model) - caffe2_model_dir = sc._caffe2_model_dir(model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - subprocess.check_call(["echo", model]) - with open(os.path.join(caffe2_model_dir, "value_info.json")) as f: - value_info = f.read() - subprocess.check_call( - [ - "convert-caffe2-to-onnx", - "--caffe2-net-name", - model, - "--caffe2-init-net", - os.path.join(caffe2_model_dir, "init_net.pb"), - "--value-info", - value_info, - "-o", - os.path.join(onnx_model_dir, "model.pb"), - os.path.join(caffe2_model_dir, 
"predict_net.pb"), - ] - ) - subprocess.check_call( - ["tar", "-czf", model + ".tar.gz", model], cwd=onnx_models_dir - ) - - -def upload_models(): - sc = SomeClass() - for model in models: - print("update-caffe2-models.py: uploading", model) - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - subprocess.check_call( - [ - "aws", - "s3", - "cp", - model + ".tar.gz", - f"s3://download.onnx/models/{model}.tar.gz", - "--acl", - "public-read", - ], - cwd=onnx_models_dir, - ) - - -def cleanup(): - sc = SomeClass() - for model in models: - onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model) - os.remove(os.path.join(os.path.dirname(onnx_model_dir), model + ".tar.gz")) - - -if __name__ == "__main__": - try: - subprocess.check_call(["aws", "sts", "get-caller-identity"]) - except: - print( - "update-caffe2-models.py: please run `aws configure` manually to set up credentials" - ) - sys.exit(1) - if sys.argv[1] == "download": - download_models() - if sys.argv[1] == "generate": - generate_models() - elif sys.argv[1] == "upload": - upload_models() - elif sys.argv[1] == "cleanup": - cleanup() diff --git a/scripts/model_zoo/update-models-from-caffe2.py b/scripts/model_zoo/update-models-from-caffe2.py deleted file mode 100644 index 3d4d4d5d1c0c..000000000000 --- a/scripts/model_zoo/update-models-from-caffe2.py +++ /dev/null @@ -1,372 +0,0 @@ -#! /usr/bin/env python3 - -import argparse -import glob -import json -import os -import shutil -import tarfile -import tempfile - -from urllib.request import urlretrieve - -import boto3 -import numpy as np -import onnx -import onnx.backend -from onnx import numpy_helper - -import caffe2.python.onnx.backend -import caffe2.python.onnx.frontend -import caffe2.python.workspace as c2_workspace -from caffe2.proto import caffe2_pb2 - -from caffe2.python.models.download import ( - deleteDirectory, - downloadFromURLToFile, - getURLFromName, -) - - -"""A script converting Caffe2 models to ONNX, and updating ONNX model zoos. - -Arguments: - -v, verbose - --local-dir, where we store the ONNX and Caffe2 models - --no-cache, ignore existing models in local-dir - --clean-test-data, delete all the existing test data when updating ONNX model zoo - --add-test-data, add add-test-data sets of test data for each ONNX model - --only-local, run locally (for testing purpose) - -Examples: - # store the data in /home/username/zoo-dir, delete existing test data, ignore local cache, - # and generate 3 sets of new test data - python update-caffe2-models.py --local-dir /home/username/zoo-dir --clean-test-data --no-cache --add-test-data 3 - -""" - -# TODO: Add GPU support - - -def upload_onnx_model(model_name, zoo_dir, backup=False, only_local=False): - if only_local: - print("No uploading in local only mode.") - return - model_dir = os.path.join(zoo_dir, model_name) - suffix = "-backup" if backup else "" - if backup: - print(f"Backing up the previous version of ONNX model {model_name}...") - rel_file_name = f"{model_name}{suffix}.tar.gz" - abs_file_name = os.path.join(zoo_dir, rel_file_name) - print(f"Compressing {model_name} model to {abs_file_name}") - with tarfile.open(abs_file_name, "w:gz") as f: - f.add(model_dir, arcname=model_name) - file_size = os.stat(abs_file_name).st_size - print( - f"Uploading {abs_file_name} ({float(file_size) / 1024 / 1024} MB) to s3 cloud..." 
- ) - client = boto3.client("s3", "us-east-1") - transfer = boto3.s3.transfer.S3Transfer(client) - transfer.upload_file( - abs_file_name, - "download.onnx", - f"models/latest/{rel_file_name}", - extra_args={"ACL": "public-read"}, - ) - - print(f"Successfully uploaded {rel_file_name} to s3!") - - -def download_onnx_model(model_name, zoo_dir, use_cache=True, only_local=False): - model_dir = os.path.join(zoo_dir, model_name) - if os.path.exists(model_dir): - if use_cache: - upload_onnx_model(model_name, zoo_dir, backup=True, only_local=only_local) - return - else: - shutil.rmtree(model_dir) - url = f"https://s3.amazonaws.com/download.onnx/models/latest/{model_name}.tar.gz" - - download_file = tempfile.NamedTemporaryFile(delete=False) - try: - download_file.close() - print( - f"Downloading ONNX model {model_name} from {url} and save in {download_file.name} ...\n" - ) - urlretrieve(url, download_file.name) - with tarfile.open(download_file.name) as t: - print(f"Extracting ONNX model {model_name} to {zoo_dir} ...\n") - t.extractall(zoo_dir) - except Exception as e: - print(f"Failed to download/backup data for ONNX model {model_name}: {e}") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - finally: - os.remove(download_file.name) - - if not only_local: - upload_onnx_model(model_name, zoo_dir, backup=True, only_local=only_local) - - -def download_caffe2_model(model_name, zoo_dir, use_cache=True): - model_dir = os.path.join(zoo_dir, model_name) - if os.path.exists(model_dir): - if use_cache: - return - else: - shutil.rmtree(model_dir) - os.makedirs(model_dir) - - for f in ["predict_net.pb", "init_net.pb", "value_info.json"]: - url = getURLFromName(model_name, f) - dest = os.path.join(model_dir, f) - try: - try: - downloadFromURLToFile(url, dest, show_progress=False) - except TypeError: - # show_progress not supported prior to - # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 - # (Sep 17, 2017) - downloadFromURLToFile(url, dest) - except Exception as e: - print(f"Abort: {e}") - print("Cleaning up...") - deleteDirectory(model_dir) - raise - - -def caffe2_to_onnx(caffe2_model_name, caffe2_model_dir): - caffe2_init_proto = caffe2_pb2.NetDef() - caffe2_predict_proto = caffe2_pb2.NetDef() - - with open(os.path.join(caffe2_model_dir, "init_net.pb"), "rb") as f: - caffe2_init_proto.ParseFromString(f.read()) - caffe2_init_proto.name = f"{caffe2_model_name}_init" - with open(os.path.join(caffe2_model_dir, "predict_net.pb"), "rb") as f: - caffe2_predict_proto.ParseFromString(f.read()) - caffe2_predict_proto.name = caffe2_model_name - with open(os.path.join(caffe2_model_dir, "value_info.json"), "rb") as f: - value_info = json.loads(f.read()) - - print( - f"Converting Caffe2 model {caffe2_model_name} in {caffe2_model_dir} to ONNX format" - ) - onnx_model = caffe2.python.onnx.frontend.caffe2_net_to_onnx_model( - init_net=caffe2_init_proto, - predict_net=caffe2_predict_proto, - value_info=value_info, - ) - - return onnx_model, caffe2_init_proto, caffe2_predict_proto - - -def tensortype_to_ndarray(tensor_type): - shape = [] - for dim in tensor_type.shape.dim: - shape.append(dim.dim_value) - if tensor_type.elem_type == onnx.TensorProto.FLOAT: - type = np.float32 - elif tensor_type.elem_type == onnx.TensorProto.INT: - type = np.int32 - else: - raise - array = np.random.rand(*shape).astype(type) - return array - - -def generate_test_input_data(onnx_model, scale): - real_inputs_names = list( - {input.name for input in onnx_model.graph.input} - - {init.name for init in onnx_model.graph.initializer} - ) - 
real_inputs = [] - for name in real_inputs_names: - for input in onnx_model.graph.input: - if name == input.name: - real_inputs.append(input) - - test_inputs = [] - for input in real_inputs: - ndarray = tensortype_to_ndarray(input.type.tensor_type) - test_inputs.append((input.name, ndarray * scale)) - - return test_inputs - - -def generate_test_output_data(caffe2_init_net, caffe2_predict_net, inputs): - p = c2_workspace.Predictor(caffe2_init_net, caffe2_predict_net) - inputs_map = {input[0]: input[1] for input in inputs} - - output = p.run(inputs_map) - c2_workspace.ResetWorkspace() - return output - - -def onnx_verify(onnx_model, inputs, ref_outputs): - prepared = caffe2.python.onnx.backend.prepare(onnx_model) - onnx_inputs = [] - for input in inputs: - if isinstance(input, tuple): - onnx_inputs.append(input[1]) - else: - onnx_inputs.append(input) - onnx_outputs = prepared.run(inputs=onnx_inputs) - np.testing.assert_almost_equal(onnx_outputs, ref_outputs, decimal=3) - - -model_mapping = { - "bvlc_alexnet": "bvlc_alexnet", - "bvlc_googlenet": "bvlc_googlenet", - "bvlc_reference_caffenet": "bvlc_reference_caffenet", - "bvlc_reference_rcnn_ilsvrc13": "bvlc_reference_rcnn_ilsvrc13", - "densenet121": "densenet121", - #'finetune_flickr_style': 'finetune_flickr_style', - "inception_v1": "inception_v1", - "inception_v2": "inception_v2", - "resnet50": "resnet50", - "shufflenet": "shufflenet", - "squeezenet": "squeezenet_old", - #'vgg16': 'vgg16', - "vgg19": "vgg19", - "zfnet512": "zfnet512", -} - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Update the ONNX models.") - parser.add_argument("-v", action="store_true", default=False, help="verbose") - parser.add_argument( - "--local-dir", - type=str, - default=os.path.expanduser("~"), - help="local dir to store Caffe2 and ONNX models", - ) - parser.add_argument( - "--no-cache", - action="store_true", - default=False, - help="whether use local ONNX models", - ) - parser.add_argument( - "--clean-test-data", - action="store_true", - default=False, - help="remove the old test data", - ) - parser.add_argument( - "--add-test-data", type=int, default=0, help="add new test data" - ) - parser.add_argument( - "--only-local", - action="store_true", - default=False, - help="no upload including backup", - ) - - args = parser.parse_args() - delete_test_data = args.clean_test_data - add_test_data = args.add_test_data - use_cache = not args.no_cache - only_local = args.only_local - - root_dir = args.local_dir - caffe2_zoo_dir = os.path.join(root_dir, ".caffe2", "models") - onnx_zoo_dir = os.path.join(root_dir, ".onnx", "models") - - for onnx_model_name in model_mapping: - c2_model_name = model_mapping[onnx_model_name] - - print( - f"####### Processing ONNX model {onnx_model_name} ({c2_model_name} in Caffe2) #######" - ) - download_caffe2_model(c2_model_name, caffe2_zoo_dir, use_cache=use_cache) - download_onnx_model( - onnx_model_name, onnx_zoo_dir, use_cache=use_cache, only_local=only_local - ) - - onnx_model_dir = os.path.join(onnx_zoo_dir, onnx_model_name) - - if delete_test_data: - print("Deleting all the existing test data...") - # NB: For now, we don't delete the npz files. 
- # for f in glob.glob(os.path.join(onnx_model_dir, '*.npz')): - # os.remove(f) - for f in glob.glob(os.path.join(onnx_model_dir, "test_data_set*")): - shutil.rmtree(f) - - onnx_model, c2_init_net, c2_predict_net = caffe2_to_onnx( - c2_model_name, os.path.join(caffe2_zoo_dir, c2_model_name) - ) - - print(f"Deleteing old ONNX {onnx_model_name} model...") - for f in glob.glob(os.path.join(onnx_model_dir, "model*".format())): - os.remove(f) - - print(f"Serializing generated ONNX {onnx_model_name} model ...") - with open(os.path.join(onnx_model_dir, "model.onnx"), "wb") as file: - file.write(onnx_model.SerializeToString()) - - print(f"Verifying model {onnx_model_name} with ONNX model checker...") - onnx.checker.check_model(onnx_model) - - total_existing_data_set = 0 - print(f"Verifying model {onnx_model_name} with existing test data...") - for f in glob.glob(os.path.join(onnx_model_dir, "*.npz")): - test_data = np.load(f, encoding="bytes") - inputs = list(test_data["inputs"]) - ref_outputs = list(test_data["outputs"]) - onnx_verify(onnx_model, inputs, ref_outputs) - total_existing_data_set += 1 - for f in glob.glob(os.path.join(onnx_model_dir, "test_data_set*")): - inputs = [] - inputs_num = len(glob.glob(os.path.join(f, "input_*.pb"))) - for i in range(inputs_num): - tensor = onnx.TensorProto() - with open(os.path.join(f, f"input_{i}.pb"), "rb") as pf: - tensor.ParseFromString(pf.read()) - inputs.append(numpy_helper.to_array(tensor)) - ref_outputs = [] - ref_outputs_num = len(glob.glob(os.path.join(f, "output_*.pb"))) - for i in range(ref_outputs_num): - tensor = onnx.TensorProto() - with open(os.path.join(f, f"output_{i}.pb"), "rb") as pf: - tensor.ParseFromString(pf.read()) - ref_outputs.append(numpy_helper.to_array(tensor)) - onnx_verify(onnx_model, inputs, ref_outputs) - total_existing_data_set += 1 - - starting_index = 0 - while os.path.exists( - os.path.join(onnx_model_dir, f"test_data_set_{starting_index}") - ): - starting_index += 1 - - if total_existing_data_set == 0 and add_test_data == 0: - add_test_data = 3 - total_existing_data_set = 3 - - print(f"Generating {add_test_data} sets of new test data...") - for i in range(starting_index, add_test_data + starting_index): - data_dir = os.path.join(onnx_model_dir, f"test_data_set_{i}") - os.makedirs(data_dir) - inputs = generate_test_input_data(onnx_model, 255) - ref_outputs = generate_test_output_data(c2_init_net, c2_predict_net, inputs) - onnx_verify(onnx_model, inputs, ref_outputs) - for index, input in enumerate(inputs): - tensor = numpy_helper.from_array(input[1]) - with open(os.path.join(data_dir, f"input_{index}.pb"), "wb") as file: - file.write(tensor.SerializeToString()) - for index, output in enumerate(ref_outputs): - tensor = numpy_helper.from_array(output) - with open(os.path.join(data_dir, f"output_{index}.pb"), "wb") as file: - file.write(tensor.SerializeToString()) - - del onnx_model - del c2_init_net - del c2_predict_net - - upload_onnx_model( - onnx_model_name, onnx_zoo_dir, backup=False, only_local=only_local - ) - - print("\n\n") diff --git a/setup.py b/setup.py index 65d81b4b01cb..07d80a7e1392 100644 --- a/setup.py +++ b/setup.py @@ -226,18 +226,6 @@ def _get_package_path(package_name): BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1" BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1" - -# set up appropriate env variables -if BUILD_LIBTORCH_WHL: - # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so - # functorch is not supported without python - 
os.environ["BUILD_FUNCTORCH"] = "OFF" - - -if BUILD_PYTHON_ONLY: - os.environ["BUILD_LIBTORCHLESS"] = "ON" - os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('libtorch')}/lib" - python_min_version = (3, 8, 0) python_min_version_str = ".".join(map(str, python_min_version)) if sys.version_info < python_min_version: @@ -265,9 +253,26 @@ def _get_package_path(package_name): from tools.build_pytorch_libs import build_caffe2 from tools.generate_torch_version import get_torch_version from tools.setup_helpers.cmake import CMake -from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS +from tools.setup_helpers.env import ( + build_type, + IS_DARWIN, + IS_LINUX, + IS_WINDOWS, + LIBTORCH_PKG_NAME, +) from tools.setup_helpers.generate_linker_script import gen_linker_script +# set up appropriate env variables +if BUILD_LIBTORCH_WHL: + # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so + # functorch is not supported without python + os.environ["BUILD_FUNCTORCH"] = "OFF" + + +if BUILD_PYTHON_ONLY: + os.environ["BUILD_LIBTORCHLESS"] = "ON" + os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path(LIBTORCH_PKG_NAME)}/lib" + ################################################################################ # Parameters parsed from environment ################################################################################ @@ -342,7 +347,7 @@ def report(*args): # Version, create_version_file, and package_name ################################################################################ -DEFAULT_PACKAGE_NAME = "libtorch" if BUILD_LIBTORCH_WHL else "torch" +DEFAULT_PACKAGE_NAME = LIBTORCH_PKG_NAME if BUILD_LIBTORCH_WHL else "torch" package_name = os.getenv("TORCH_PACKAGE_NAME", DEFAULT_PACKAGE_NAME) package_type = os.getenv("PACKAGE_TYPE", "wheel") @@ -1132,8 +1137,11 @@ def main(): 'mkl>=2021.1.1,<=2021.4.0; platform_system == "Windows"', ] + if sys.version_info >= (3, 12, 0): + install_requires.append("setuptools") + if BUILD_PYTHON_ONLY: - install_requires.append("libtorch") + install_requires.append(LIBTORCH_PKG_NAME) use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", "")) if ( @@ -1442,9 +1450,9 @@ def main(): if parts[0] == "torch": modified_packages.append(DEFAULT_PACKAGE_NAME + package[len("torch") :]) packages = modified_packages - package_dir = {"libtorch": "torch"} - torch_package_dir_name = "libtorch" - package_data = {"libtorch": torch_package_data} + package_dir = {LIBTORCH_PKG_NAME: "torch"} + torch_package_dir_name = LIBTORCH_PKG_NAME + package_data = {LIBTORCH_PKG_NAME: torch_package_data} extensions = [] else: torch_package_dir_name = "torch" diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index c3d3fe2f00ec..44de9e809615 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -211,30 +211,6 @@ "torch.distributed.optim.utils": [ "Type" ], - "torch.distributed.pipeline.sync.pipe": [ - "Pipeline" - ], - "torch.distributed.pipeline.sync.skip.layout": [ - "SkipLayout", - "inspect_skip_layout" - ], - "torch.distributed.pipeline.sync.skip.portal": [ - "Context", - "Portal", - "PortalBlue", - "PortalCopy", - "PortalOrange" - ], - "torch.distributed.pipeline.sync.skip.skippable": [ - "Skippable" - ], - "torch.distributed.pipeline.sync.skip.tracker": [ - "SkipTracker", - "SkipTrackerThroughPotals", - "ThreadLocal", - "current_skip_tracker", - "use_skip_tracker" - ], "torch.distributed.remote_device": [ "Optional", "Union" @@ -1321,12 +1297,10 @@ 
"_weight_norm_interface", "autocast", "broadcast_shapes", - "candidate", "compiled_with_cxx11_abi", "from_dlpack", "lobpcg", "lu", - "obj", "segment_reduce", "set_default_dtype", "set_grad_enabled", @@ -1697,10 +1671,6 @@ "get_args_parser", "run" ], - "torch.distributed.pipeline.sync": [ - "NoChunk", - "WithDevice" - ], "torch.distributed.rpc.rref_proxy": [ "Future", "partial", @@ -2634,32 +2604,12 @@ "TensorPipeRpcBackendOptions" ], "torch.distributed.pipelining": [ - "ArgsChunkSpec", - "KwargsChunkSpec", "Pipe", "PipelineStage", "SplitPoint", - "annotate_split_points", "pipe_split", "pipeline" ], - "torch.distributed.pipelining.PipelineSchedule": [ - "ABC", - "Any", - "Callable", - "Dict", - "List", - "Optional", - "Pipe", - "PipelineStageBase", - "Tuple", - "Union", - "abstractmethod", - "defaultdict", - "merge_chunks", - "record_function", - "split_args_kwargs_into_chunks" - ], "torch.distributed.pipelining.microbatch": [ "Any", "Dict", diff --git a/test/conftest.py b/test/conftest.py index 9ba728689285..5b84898df8a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,10 +7,9 @@ import xml.etree.ElementTree as ET from collections import defaultdict from types import MethodType -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, TYPE_CHECKING, Union import pytest -from _pytest._code.code import ReprFileLocation from _pytest.config import Config, filename_arg from _pytest.config.argparsing import Parser from _pytest.junitxml import _NodeReporter, bin_xml_escape, LogXML @@ -20,6 +19,9 @@ from _pytest.terminal import _get_raw_skip_reason from pytest_shard_custom import pytest_addoptions as shard_addoptions, PytestShardPlugin +if TYPE_CHECKING: + from _pytest._code.code import ReprFileLocation + # a lot of this file is copied from _pytest.junitxml and modified to get rerun info xml_key = StashKey["LogXMLReruns"]() diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index b0e296ad2309..ceeb607d52a7 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -51,9 +51,6 @@ endif() add_executable(test_api ${TORCH_API_TEST_SOURCES}) target_include_directories(test_api PRIVATE ${ATen_CPU_INCLUDE}) target_link_libraries(test_api PRIVATE torch gtest) -if(NOT MSVC) - target_compile_options_if_supported(test_api -Wno-unused-variable) -endif() if(USE_CUDA) target_compile_definitions(test_api PRIVATE "USE_CUDA") diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 1b61499c2a75..9d4d381742e1 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -806,7 +806,6 @@ TEST(SerializeTest, Optim_RMSprop) { for (const auto i : c10::irange(params1_2_.size())) { if (i != (params1_2_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); - auto key1_2_ = params1_2_[i].unsafeGetTensorImpl(); const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); RMSpropParamState& curr_state1_2_ = diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 2d88d3f7172d..f0510d9c81f2 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -129,9 +129,6 @@ endif(MSVC) target_link_libraries(test_jit PRIVATE ${JIT_TEST_DEPENDENCIES}) target_include_directories(test_jit PRIVATE ${ATen_CPU_INCLUDE}) -if(NOT MSVC) - target_compile_options(test_jit PRIVATE $<$:-Wno-unused-variable>) -endif() if(LINUX) #Update to target_link_options when CMake version can be upgraded diff --git a/test/cpp/lazy/test_lazy_ops_util.cpp 
b/test/cpp/lazy/test_lazy_ops_util.cpp index cc5287cd9b3d..c024780187c7 100644 --- a/test/cpp/lazy/test_lazy_ops_util.cpp +++ b/test/cpp/lazy/test_lazy_ops_util.cpp @@ -12,11 +12,6 @@ namespace torch { namespace lazy { namespace { -bool IsLtcTensor(const at::Tensor& tensor) { - return dynamic_cast( - tensor.unsafeGetTensorImpl()); -} - std::unordered_set* CreateIgnoredCounters() { std::unordered_set* icounters = new std::unordered_set(); diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index 012471d0e584..179270c4a4a1 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -42,9 +42,6 @@ add_executable(test_tensorexpr target_link_libraries(test_tensorexpr PRIVATE torch gtest) target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) -if(NOT MSVC) - target_compile_options(test_tensorexpr PRIVATE -Wno-unused-variable) -endif() add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) target_link_libraries(tutorial_tensorexpr PRIVATE torch) diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index d469a7dfa21b..f6ffc84f62c0 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -179,7 +179,7 @@ TEST(LLVM, CharToFloatCastTest) { } TEST(LLVM, BitCast) { - constexpr int16_t ref16 = 1337; + /* constexpr int16_t ref16 = 1337; */ constexpr int32_t ref32 = 1337; constexpr int64_t ref64 = 1337; constexpr float reff32 = 1337.0f; @@ -1395,7 +1395,6 @@ TEST(LLVM, EliminatedStmt) { TEST(LLVM, SimpleReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); @@ -1429,7 +1428,6 @@ TEST(LLVM, SimpleReduction) { TEST(LLVM, RFactorReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); @@ -1474,7 +1472,6 @@ TEST(LLVM, RFactorReduction) { TEST(LLVM, RFactorVectorizedReduction) { int M = 128; int N = 64; - const int kTotalSize = M * N; BufHandle a("a", {1, M, N}, kFloat); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 6a6a94c82e59..d65b5c544f6c 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1092,6 +1092,7 @@ TEST(Reductions, ReduceOverSplitRfactor) { // Check the IR to verify the rfactored reduce is eliminated. // TODO: The alloc free should be eliminated here since it is size 0. + /* const std::string& verification_pattern = R"IR( # CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] @@ -1102,6 +1103,7 @@ TEST(Reductions, ReduceOverSplitRfactor) { # CHECK: } # CHECK: } # CHECK: Free(tmp_buf);)IR"; + */ // TODO: rfactor output is not consistent yet, will fix (@nickg). 
// torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } diff --git a/test/cpp_extensions/extension.cpp b/test/cpp_extensions/extension.cpp index 1de9e0397111..0b609e82e0c5 100644 --- a/test/cpp_extensions/extension.cpp +++ b/test/cpp_extensions/extension.cpp @@ -2,6 +2,7 @@ // test include_dirs in setuptools.setup with relative path #include +#include torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) { return x.sigmoid() + y.sigmoid(); @@ -31,6 +32,10 @@ torch::Tensor random_tensor() { return torch::randn({1}); } +at::ScalarType get_math_type(at::ScalarType other) { + return at::toOpMathType(other); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)"); m.def( @@ -52,4 +57,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_symint", []() { return c10::SymInt(1); }); m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); }); m.def("get_tensor", []() { return random_tensor(); }); + m.def("get_math_type", &get_math_type); } diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index a7b97f8f7dd3..836013f7fb24 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -243,7 +243,6 @@ def world_size(self) -> int: return min(8, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_train_parity_single_group(self): """Tests train parity with DDP for a single FSDP group.""" self.run_subtests( @@ -275,7 +274,8 @@ def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]): self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) - def test_train_parity_multi_group_eager(self): + @test_compiled_fsdp(compile_compute_on_module=Transformer) + def test_train_parity_multi_group(self): """ Tests train parity against DDP when using multiple parameter groups for communication (for communication and computation overlap plus memory @@ -294,21 +294,6 @@ def test_train_parity_multi_group_eager(self): self._test_train_parity_multi_group, ) - @skip_if_lt_x_gpu(2) - def test_train_parity_multi_group_compile(self): - self.run_subtests( - { - "reshard_after_forward": [True, False], - "device_type": ["cuda"], - "offload_policy": [OffloadPolicy()], - "delay_after_forward": [False, True], - "delay_before_all_gather": [False], - "delay_before_reduce_scatter": [False], - "delay_before_optim": [False, True], - }, - self._test_train_parity_multi_group, - ) - @skip_if_lt_x_gpu(2) def test_train_parity_multi_group_cpu_offload_eager(self): """ @@ -353,7 +338,15 @@ def _test_train_parity_multi_group( assert device_type in ("cuda", "cpu"), f"{device_type}" torch.manual_seed(42) lin_dim = 32 - model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)]) + vocab_size = 1024 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + vocab_size=vocab_size, + max_seq_len=64, + dropout_p=0, + ) + model = Transformer(model_args) ref_model = copy.deepcopy(model) if device_type == "cuda": replicate(ref_model.cuda(), device_ids=[self.rank]) @@ -368,8 +361,9 @@ def _test_train_parity_multi_group( reshard_after_forward=reshard_after_forward, offload_policy=offload_policy, ) - for mlp in model: - fully_shard_fn(mlp) + for module in model.modules(): + if isinstance(module, TransformerBlock): + fully_shard_fn(module) fully_shard_fn(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) @@ -398,7 
+392,7 @@ def delayed_reduce_scatter(*args, **kwargs): ) with patch_all_gather_ctx, patch_reduce_scatter_ctx: for iter_idx in range(10): - inp = torch.randn((8, lin_dim), device=torch.device(device_type)) + inp = torch.randint(0, vocab_size, (3, 64), device=device_type) losses: List[torch.Tensor] = [] for _model, _optim in ((ref_model, ref_optim), (model, optim)): _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) @@ -412,7 +406,6 @@ def delayed_reduce_scatter(*args, **kwargs): self.assertEqual(losses[0], losses[1]) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_non_root_forward_backward(self): """ Tests running forward/backward through the root and then through a @@ -459,7 +452,6 @@ def test_non_root_forward_backward(self): self.assertEqual(ref_model(inp).sum(), model(inp).sum()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp() def test_multi_forward_module(self): """ Tests parity with DDP when running a module that participates multiple @@ -511,6 +503,7 @@ def world_size(self) -> int: return min(torch.cuda.device_count(), 2) @skip_if_lt_x_gpu(2) + @test_compiled_fsdp(compile_compute_on_module=Transformer) def test_train_parity_with_activation_checkpointing(self): """ Tests train parity against DDP when composing with activation @@ -528,6 +521,9 @@ def _test_train_parity_with_activation_checkpointing( self, reshard_after_forward: Union[bool, int], checkpoint_impl: str ): assert checkpoint_impl in ("composable", "utils", "wrapper") + testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard + if testing_compile and checkpoint_impl == "composable": + return torch.manual_seed(42) vocab_size = 1024 with torch.device(torch.device("cuda")): @@ -536,7 +532,7 @@ def _test_train_parity_with_activation_checkpointing( n_heads=4, vocab_size=vocab_size, max_seq_len=64, - dropout_p=0.1, + dropout_p=0, checkpoint_activations=(checkpoint_impl == "utils"), ) model = Transformer(model_args) @@ -579,16 +575,18 @@ def _test_train_parity_with_activation_checkpointing( torch.manual_seed(iter_idx + 1) # for dropout determinism losses.append(_model(inp).sum()) losses[-1].backward() - check_sharded_parity( - self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore - ) + if not testing_compile: + check_sharded_parity( + self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore + ) self.assertEqual(losses[0], losses[1]) for _optim in (ref_optim, optim): _optim.step() _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) - check_sharded_parity( - self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore - ) + if not testing_compile: + check_sharded_parity( + self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore + ) class TestFullyShardSharedParams(FSDPTest): @@ -597,22 +595,11 @@ def world_size(self) -> int: return min(4, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) - @test_compiled_fsdp(compile_compute_on_module=TransformerBlock) - def test_train_parity_with_shared_params_no_ac(self): - self.run_subtests( - { - "reshard_after_forward": [False, True], - "use_activation_checkpointing": [False], - }, - self._test_train_shared_params, - ) - - @skip_if_lt_x_gpu(2) - def test_train_parity_with_shared_params_ac(self): + def test_train_parity_with_shared_params(self): self.run_subtests( { "reshard_after_forward": [False, True], - "use_activation_checkpointing": [True], + "use_activation_checkpointing": [False, True], }, self._test_train_shared_params, ) @@ -1182,6 +1169,12 @@ def _test_2d_mlp_with_nd_mesh( _optim.step() self.assertEqual(losses[0], losses[1]) + for n, p in 
model.named_parameters(): + self.assertIsInstance(p, DTensor) + self.assertEqual(p.device_mesh.ndim, 2) + self.assertEqual(len(p.placements), 2) + self.assertEqual(p.device_mesh.mesh_dim_names, ("dp", "tp")) + class TestFullyShardHSDPTraining(FSDPTest): @property @@ -1255,12 +1248,12 @@ def _test_train_parity_hsdp( check_sharded_parity(self, ref_model, model) -class TestFullyShardCustomForwardMethod(FSDPTestMultiThread): +class TestFullyShardCustomForwardMethod(FSDPTest): @property def world_size(self) -> int: - return 2 + return min(torch.cuda.device_count(), 2) - @unittest.skipIf(not TEST_CUDA, "no cuda") + @skip_if_lt_x_gpu(2) def test_register_fsdp_forward_method(self): """Based on https://github.com/pytorch/pytorch/issues/109385""" @@ -1287,8 +1280,6 @@ def forward(self, imgs: torch.Tensor) -> torch.Tensor: torch.manual_seed(42) model = Model() - for param in model.parameters(): - dist.broadcast(param.detach(), src=0) ref_model = copy.deepcopy(model).cuda() fully_shard(model.vit) fully_shard(model.projector) diff --git a/test/distributed/_spmd/test_data_parallel.py b/test/distributed/_spmd/test_data_parallel.py index 4940320c0724..140ed54c037c 100644 --- a/test/distributed/_spmd/test_data_parallel.py +++ b/test/distributed/_spmd/test_data_parallel.py @@ -12,7 +12,7 @@ from torch.distributed._tensor import Replicate from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, diff --git a/test/distributed/_spmd/test_graph_utils.py b/test/distributed/_spmd/test_graph_utils.py index 2c90159237c7..2545678e0f15 100644 --- a/test/distributed/_spmd/test_graph_utils.py +++ b/test/distributed/_spmd/test_graph_utils.py @@ -2,7 +2,7 @@ import os from torch.distributed._spmd.graph_utils import dump_graphs_to_files -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py index b77a87a7f44d..77445aac7419 100644 --- a/test/distributed/_spmd/test_tracing.py +++ b/test/distributed/_spmd/test_tracing.py @@ -20,7 +20,7 @@ from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests # noqa: TCH001 from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms as base_with_comms, @@ -46,7 +46,7 @@ def _test_tracing_all_reduce_nd(self, mesh_tensor): local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 4143da2bd88c..5483b3171f30 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ 
b/test/distributed/_tensor/debug/test_comm_mode.py @@ -33,6 +33,13 @@ def setUp(self): self.device_type = "cuda" if torch.cuda.is_available() else "cpu" self.world_pg = dist.distributed_c10d._get_default_group() + def checksAssert(self, comm_mode, key, expected_value, expected_total_value): + comm_counts = comm_mode.get_comm_counts() + self.assertEqual(comm_mode.get_total_counts(), expected_total_value) + self.assertEqual(comm_counts[key], expected_value) + + return + def test_comm_mode(self): world_pg = self.world_pg @@ -115,71 +122,100 @@ def test_comm_mode_with_c10d(self): all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) comm_mode = CommDebugMode() + + # tests c10d all_reduce tracing with comm_mode: dist.all_reduce(inp) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.allreduce_], 1) + self.checksAssert(comm_mode, c10d_ops.allreduce_, 1, 1) + # tests c10d all_gather_into_tensor tracing with comm_mode: dist.all_gather_into_tensor(all_gather_out, inp) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1) + self.checksAssert(comm_mode, c10d_ops._allgather_base_, 1, 1) + # tests c10d reduce_scatter tracing with comm_mode: dist.reduce_scatter_tensor(inp, all_gather_out) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1) + self.checksAssert(comm_mode, c10d_ops._reduce_scatter_base_, 1, 1) + # tests c10d broadcast tracing with comm_mode: dist.broadcast(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.broadcast_], 1) + self.checksAssert(comm_mode, c10d_ops.broadcast_, 1, 1) # tests c10d gather tracing with comm_mode: dist.gather(inp, None, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.gather_], 1) + self.checksAssert(comm_mode, c10d_ops.gather_, 1, 1) # tests c10d reduce tracing with comm_mode: dist.reduce(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.reduce_], 1) + self.checksAssert(comm_mode, c10d_ops.reduce_, 1, 1) # tests c10d scatter tracing with comm_mode: dist.scatter(inp, None, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_counts[c10d_ops.scatter_], 1) + self.checksAssert(comm_mode, c10d_ops.scatter_, 1, 1) - @requires_nccl() - def test_comm_mode_with_c10d_allreduce_coalesced(self): - world_pg = self.world_pg + # tests c10d all_gather tracing + output_list = [] - inp = torch.rand(2, 8, 16).cuda() - all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) + with comm_mode: + dist.all_gather(output_list, inp, None) + + self.checksAssert(comm_mode, c10d_ops.allgather_, 1, 1) + + # tests c10d allgather_coalesced_ tracing + output_list = [] - comm_mode = CommDebugMode() with comm_mode: - dist.all_reduce_coalesced(inp) + dist.all_gather_coalesced(output_list, [inp], None) + + self.checksAssert(comm_mode, c10d_ops.allgather_coalesced_, 1, 1) + + # tests c10d allgather_into_tensor_coalesced_ tracing + with comm_mode, dist._coalescing_manager(): dist.all_gather_into_tensor(all_gather_out, inp) - dist.reduce_scatter_tensor(inp, all_gather_out) - dist.broadcast(inp, 0) - comm_counts = comm_mode.get_comm_counts() - self.assertEqual(comm_mode.get_total_counts(), 4) - self.assertEqual(comm_counts[c10d_ops.allreduce_coalesced_], 1) - self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1) - self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1) - 
self.assertEqual(comm_counts[c10d_ops.broadcast_], 1) + self.checksAssert(comm_mode, c10d_ops.allgather_into_tensor_coalesced_, 1, 1) + + # tests c10d allreduce_coalesced + with comm_mode: + dist.all_reduce_coalesced(inp) + + self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1) + + # tests c10d reduce_scatter_ + with comm_mode: + dist.reduce_scatter(all_gather_out, [inp]) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1) + + # tests c10d reduce_scatter_tensor_coalesced + with comm_mode as A, dist._coalescing_manager() as B: + dist.reduce_scatter_tensor(all_gather_out, inp) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1) + + # tests c10d alltoall_ + with comm_mode: + dist.all_to_all([inp], [inp]) + + self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1) + + # tests c10d alltoall_base_ + with comm_mode: + dist.all_to_all_single(inp, inp) + + self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1) if __name__ == "__main__": diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index 1035df2f5f7d..b483194d6c3a 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -5,6 +5,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import ( distribute_tensor, + DTensor, init_device_mesh, Replicate, Shard, @@ -18,23 +19,30 @@ ) -def equal_forward(device_mesh, X, Y): +funcol_py = torch.ops.c10d_functional + + +def equal_allgather_forward(device_mesh, X, Y): eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) return torch.all(eq_gather).item() -def mm_forward(device_mesh, W, X): - return torch.mm(W, X) +def mm_all_gather_forward(device_mesh, A, B): + local_mm_result = torch.mm(A, B) + return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait() + +def mm_forward(A, B): # no device mesh needed since we don't do collective + return torch.mm(A, B) -def mm_allreduce_forward(device_mesh, W, X): - partial_sum_tensor = torch.mm(W, X) - reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() - return reduced_tensor +def mm_allreduce_forward(device_mesh, A, B): + partial_sum_tensor = torch.mm(A, B) + return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() -def mul_forward(device_mesh, X, scalar): + +def mul_forward(X, scalar): # no device mesh needed since we don't do collective return torch.mul(X, scalar) @@ -58,6 +66,7 @@ def test_local_map_correctness(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, col_wise ) # col-wisely sharded W tensor @@ -70,12 +79,12 @@ def test_local_map_correctness(self): # DTensors' `_local_tensor`. 
local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # output redistribution to Replicate self.assertEqual(comm_mode.get_total_counts(), 1) @@ -88,6 +97,7 @@ def test_local_map_correctness(self): # check for `out_placements` @with_comms def test_local_map_out_placements(self): + # Test 1: wrap out into DTensor w/ `out_placements` device_mesh = init_device_mesh( device_type=self.device_type, mesh_shape=(self.world_size,) ) @@ -99,14 +109,40 @@ def test_local_map_out_placements(self): row_wise = [Shard(0)] X_dt = distribute_tensor(X, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise) - local_equal_forward = local_map(equal_forward, out_placements=None) + local_equal_allgather_forward = local_map( + equal_allgather_forward, + out_placements=None, + ) with comm_mode: - equal_dt = local_equal_forward(X_dt, Y_dt) # a bool + equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt) # a bool self.assertEqual(comm_mode.get_total_counts(), 1) self.assertTrue(not equal_dt) self.assertTrue(not (X.equal(Y))) + # Test 2: directly return out if no argument is DTensor + # matmul in DDP + replicate = [Replicate()] + X = torch.randn( + 4 // self.world_size, 4, device=self.device_type, requires_grad=False + ) + W = torch.randn(4, 4, device=self.device_type, requires_grad=False) + local_mm_all_gather_forward = local_map( + mm_all_gather_forward, + out_placements=row_wise, + in_placements=(None, row_wise, replicate), + ) + with comm_mode: + Y = local_mm_all_gather_forward(device_mesh, X, W) + + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1 + ) + X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait() + Y_replicate = torch.mm(X_replicate, W) + self.assertEqual(Y, Y_replicate) # Y is a torch.Tensor + # check for `in_placements` handling @with_comms def test_local_map_in_placements(self): @@ -173,6 +209,54 @@ def test_local_map_in_placements(self): self.assertTrue(placement.is_shard(dim=0)) self.assertEqual(Y_dt.full_tensor(), Y) + # Test 4: `None` placements for Tensor input argument + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + X_dt = distribute_tensor( + X, device_mesh, row_wise + ) # row-wisely sharded X tensor + W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(None, None), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, W), + ) + + # Test 5: Some placements for Tensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(replicate, row_wise), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, 
W), + ) + + # Test 6: expect error - `None` placements for DTensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=row_wise, + in_placements=(row_wise, None), + device_mesh=device_mesh, + ) + with self.assertRaisesRegex(AssertionError, "expects placements"): + Y_dt = local_mm_forward(X_dt, W_dt) + # check for `redistribute_inputs` handling @with_comms def test_local_map_redistribute(self): @@ -188,6 +272,7 @@ def test_local_map_redistribute(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, row_wise ) # row-wisely sharded W tensor which will be redistributed @@ -198,13 +283,13 @@ def test_local_map_redistribute(self): # Test 1: allow input redistribution local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=True, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # 2 for input redistribution and 1 for output self.assertEqual(comm_mode.get_total_counts(), 3) @@ -215,13 +300,13 @@ def test_local_map_redistribute(self): # Test 2: no input redistribution is allowed local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=False, ) with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) if __name__ == "__main__": diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index db5a26d43850..3979dd4ad546 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -17,12 +17,19 @@ ) from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_FUSED_ATTENTION, + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + TEST_CUDA, +) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, + skipIfRocm, + TEST_WITH_ROCM, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -41,6 +48,7 @@ def world_size(self) -> int: return 2 @skip_if_lt_x_gpu(2) + @skipIfRocm # Missing _c10d_functional_autograd::all_to_all_single @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) @@ -299,18 +307,29 @@ def test_ring_attention_custom_transformer(self) -> None: @skip_if_lt_x_gpu(2) @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Does not support flash nor efficient attention", ) + @unittest.skipIf( + TEST_CUDA and not TEST_WITH_ROCM and not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Does not support flash attention", + ) # On CUDA (not ROCM) platform, the UT 
is skipped if no FA support (even if ME may get supported) @with_comms @parametrize( "attention_fn", [ - _scaled_dot_product_ring_flash_attention, - _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_flash_attention + if PLATFORM_SUPPORTS_FLASH_ATTENTION + else None, + _scaled_dot_product_ring_efficient_attention + if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + else None, # _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default ], ) def test_ring_attention_compile(self, attention_fn: object) -> None: + if attention_fn is None: + self.skipTest("Unsupported on current platform") device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 70a67b8e0b93..17a6aebd8f93 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -14,7 +14,13 @@ init_device_mesh, ) from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import ( + DTensorSpec, + Partial, + Replicate, + Shard, + TensorMeta, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -55,27 +61,29 @@ def test_dtensor_constructor(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) - dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + torch.Size([self.world_size * 3, 3]), + local_tensor.stride(), + local_tensor.dtype, + ), + ) + dist_tensor = DTensor( local_tensor, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor.dtype, + spec, requires_grad=True, - stride=local_tensor.stride(), ) self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3))) with self.assertWarnsRegex(UserWarning, "To construct"): DTensor( local_tensor, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor.dtype, + spec, requires_grad=False, - stride=local_tensor.stride(), ) @with_comms @@ -174,7 +182,7 @@ def test_from_local(self): ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec) self.assertEqual(ddp_tensor.size(), local_tensor.size()) - partial_spec = [_Partial()] + partial_spec = [Partial()] partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec) self.assertEqual(partial_tensor.size(), local_tensor.size()) @@ -272,19 +280,23 @@ def test_from_local_negative_dim(self): def test_to_local(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) placements = (Shard(0),) - dist_tensor_shape = torch.Size([self.world_size * 3, 3]) local_tensor_with_grad = torch.randn( 3, 3, device=self.device_type, requires_grad=True ) - + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + dist_tensor_shape, + local_tensor_with_grad.stride(), + local_tensor_with_grad.dtype, + ), + ) sharded_tensor = DTensor( local_tensor_with_grad, - device_mesh, - placements, - shape=dist_tensor_shape, - dtype=local_tensor_with_grad.dtype, + spec, requires_grad=True, - stride=local_tensor_with_grad.stride(), ) self.assertEqual(sharded_tensor.size(), dist_tensor_shape) self.assertEqual(sharded_tensor.to_local(), 
local_tensor_with_grad) @@ -319,6 +331,11 @@ def test_to_local(self): except RuntimeError: self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size]) + # test the case under no-grad we directly return the local tensor + with torch.no_grad(): + local_no_grad = sharded_tensor.to_local() + assert local_no_grad is sharded_tensor._local_tensor + @with_comms def test_to_local_grad_hint(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -330,7 +347,7 @@ def test_to_local_grad_hint(self): with comm_mode: local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local( - grad_placements=[_Partial()] + grad_placements=[Partial()] ) local_out.backward(torch.ones_like(local_out)) @@ -362,7 +379,7 @@ def test_full_tensor_grad_hint(self): global_tensor = torch.ones(8, 3, requires_grad=True) sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements) - local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()]) + local_out = sharded_dtensor.full_tensor(grad_placements=[Partial()]) local_out.sum().backward() replica_grad = sharded_dtensor.grad.full_tensor() diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index f40cb4999858..0f097e07e92f 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -17,10 +17,11 @@ DeviceMesh, DTensor, init_device_mesh, + Partial, Replicate, Shard, ) -from torch.distributed._tensor.placement_types import _Partial +from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, @@ -121,7 +122,7 @@ def fn(x): compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn) - for x in [Shard(0), Replicate(), _Partial()]: + for x in [Shard(0), Replicate(), Partial()]: opt_fn = fn(x) compiled_out = compiled_fn(x) self.assertEqual(opt_fn, compiled_out) @@ -193,41 +194,45 @@ def fn(x): def test_dtensor_constructor_w_graph_break(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + x = torch.randn(64, 32, requires_grad=True) + spec = DTensorSpec( + mesh, + (Replicate(), Shard(0)), + tensor_meta=TensorMeta( + shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype + ), + ) # test passing in DTensor as inputs/outputs and run some tensor computation def fn(x): print("graph break!") return DTensor( x, - mesh, - (Replicate(), Shard(0)), - shape=[128, 32], - dtype=x.dtype, + spec, requires_grad=x.requires_grad, - stride=[32, 1], ) - x = torch.randn(64, 32, requires_grad=True) out = fn(x) out2 = torch.compile(fn, backend="eager")(x) def test_dtensor_constructor_w_dynamo_disable(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + x = torch.randn(32, requires_grad=True) + spec = DTensorSpec( + mesh, + (Replicate(),), + tensor_meta=TensorMeta(shape=torch.Size([32]), stride=(1,), dtype=x.dtype), + ) @torch._dynamo.disable(recursive=False) def fn(x): print("foo") return DTensor( x, - mesh, - (Replicate(),), - shape=torch.Size([32]), - dtype=x.dtype, + spec, requires_grad=x.requires_grad, - stride=(1,), ) - x = torch.randn(32, requires_grad=True) out = fn(x) out2 = torch.compile(fn, backend="eager")(x) self.assertEqual(out, out2) @@ -313,7 +318,7 @@ def fn(x): x_dt = DTensor.from_local( x, mesh, - [_Partial()], + [Partial()], run_check=False, shape=(10, 257, 160), stride=(41120, 160, 1), @@ -354,7 +359,7 @@ def 
fn(x): x_dt = DTensor.from_local( x, mesh, - [_Partial()], + [Partial()], run_check=False, shape=(10, 257, 160), stride=(41120, 160, 1), @@ -515,7 +520,7 @@ def fn(x): return x + x x = torch.randn(4, 4, requires_grad=True) - x_dt = DTensor.from_local(x, mesh, [_Partial()], run_check=False) + x_dt = DTensor.from_local(x, mesh, [Partial()], run_check=False) y = torch.randn(4, 4, requires_grad=True) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 22a56118b212..07f8bfedc615 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -102,6 +102,7 @@ def wrapped(fn): xfail("addr"), xfail("all"), xfail("allclose"), + xfail("alias_copy"), xfail("amax"), xfail("amin"), xfail("aminmax"), @@ -403,7 +404,6 @@ def wrapped(fn): xfail("rsub"), xfail("scalar_tensor"), xfail("scatter_add"), - xfail("scatter"), xfail("scatter_reduce", "amax"), xfail("scatter_reduce", "amin"), xfail("scatter_reduce", "mean"), diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index 6f8015bfd0a4..d2ea73ae8c87 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -371,7 +371,7 @@ def _replicate_fn(name, module, device_mesh): if elementwise_affine: # if input is sharded on any outer dimension, the gradient of weight - # and bias should be _Partial + # and bias should be Partial dim_map = x_dist._spec.dim_map outer_dims = range(norm_idx) needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py index fa3f9272c63e..7889ed46ca5e 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -10,7 +10,7 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import ( - _Partial, + Partial, Placement, Replicate, Shard, @@ -77,7 +77,7 @@ def test_addmm_auto_redistribute(self): # test if addmm output is a partial self.assertIsInstance(dist_res, DTensor) - self.assertIsInstance(dist_res.placements[0], _Partial) + self.assertIsInstance(dist_res.placements[0], Partial) # test if result is the same as tensor dist_local_res = dist_res.full_tensor() @@ -144,11 +144,11 @@ def test_t_partial(self): da = distribute_tensor(a, device_mesh, [Shard(1)]) db = distribute_tensor(b, device_mesh, [Shard(0)]) - # mm(da, db) should return a _Partial tensor. - # transposing it should keep it _Partial + # mm(da, db) should return a Partial tensor. 
+ # transposing it should keep it Partial dc = torch.mm(da, db).t() - self.assertTrue(isinstance(dc.placements[0], _Partial)) + self.assertTrue(isinstance(dc.placements[0], Partial)) # check that the local and distributed op results match self.assertEqual( diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py index 0cb469e1c405..5194d5bf7d89 100644 --- a/test/distributed/_tensor/test_op_strategy.py +++ b/test/distributed/_tensor/test_op_strategy.py @@ -11,8 +11,8 @@ gen_einsum_strategies, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Replicate, Shard, TensorMeta, @@ -139,7 +139,7 @@ def test_redistribute_cost_mesh_1d(self): mesh_1d = self.build_device_mesh() shard_placement = (Shard(0),) replica_placement = (Replicate(),) - partial_placement = (_Partial(),) + partial_placement = (Partial(),) global_tensor = torch.randn(10, 10) global_tensor_meta = self._extract_tensor_meta(global_tensor) @@ -174,7 +174,7 @@ def test_redistribute_cost_latency(self): mesh = self.build_device_mesh() shard0_placement = (Shard(0),) - partial_placement = (_Partial(),) + partial_placement = (Partial(),) shard1_placement = (Shard(1),) shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8)) @@ -220,7 +220,7 @@ def test_redistribute_cost_mesh_2d(self): ) shard_placement = (Shard(0), Shard(0)) replica_placement = (Replicate(), Replicate()) - partial_placement = (_Partial(), _Partial()) + partial_placement = (Partial(), Partial()) global_tensor = torch.randn(8, 8) global_tensor_meta = self._extract_tensor_meta(global_tensor) diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py index 4b25efdc9105..f0103bad2de6 100644 --- a/test/distributed/_tensor/test_pointwise_ops.py +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -11,7 +11,7 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.placement_types import ( - _Partial, + Partial, Placement, Replicate, Shard, @@ -141,15 +141,15 @@ def _run_sharded_elementwise_ops( def test_partial_add(self): device_mesh = self.build_device_mesh() - d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) - d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [_Partial()]) + d_1 = DTensor.from_local(torch.rand(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.rand(2, 2), device_mesh, [Partial()]) d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) def test_partial_mul(self): device_mesh = self.build_device_mesh() - d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) - d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [_Partial()]) + d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) d_3 = d_1 * d_2 self.assertTrue(d_3._spec.placements[0].is_replicate()) self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2)) @@ -256,7 +256,7 @@ def test_dropout_errors(self): with self.assertRaisesRegex(RuntimeError, "supported"): self._run_sharded_elementwise_ops( device_mesh=device_mesh, - placements=[_Partial("sum")], + placements=[Partial("sum")], input_size=(8, 5), op=torch.nn.functional.dropout, ) diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index c97682b606c7..1d2673a6a7bc 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ 
b/test/distributed/_tensor/test_redistribute.py @@ -6,7 +6,7 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.testing._internal.common_utils import run_tests @@ -105,7 +105,7 @@ def test_replicate_to_local_partial_grad(self): with comm_mode: out = replica_tensor.redistribute(placements=[Replicate()]).to_local( - grad_placements=[_Partial()] + grad_placements=[Partial()] ) out.backward(torch.ones_like(out)) @@ -168,7 +168,7 @@ def test_partial_to_replicate_forward_backward(self): # backward should work as expected device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True) - partial_spec = [_Partial()] + partial_spec = [Partial()] replica_spec = [Replicate()] comm_mode = CommDebugMode() @@ -199,11 +199,11 @@ def test_partial_to_replicate_forward_backward(self): def test_replicate_to_partial(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) - partial_spec = _Partial() + partial_spec = Partial() replica_spec = Replicate() # 1) test replicate -> partial forward replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec]) - with self.assertRaisesRegex(RuntimeError, "Can not redistribute to _Partial"): + with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"): partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) from torch.distributed._tensor._redistribute import Redistribute @@ -246,7 +246,7 @@ def test_replicate_to_partial(self): @with_comms def test_partial_to_shard(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - partial_spec = [_Partial()] + partial_spec = [Partial()] my_rank = device_mesh.get_rank() input_sizes_and_shard_dim = [ @@ -441,7 +441,7 @@ def test_multi_dim_mesh(self): possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)] all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities]))) all_inputs = list( - itertools.product(*(mesh_shape.ndim * [possibilities + [_Partial()]])) + itertools.product(*(mesh_shape.ndim * [possibilities + [Partial()]])) ) for inputs in all_inputs: diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index 2d8d726da865..e86a702855c6 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -4,7 +4,7 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -92,7 +92,7 @@ def test_inplace_op(self): # test inplace op self and other dtensor with other specs # and make sure out spec not change shard_spec = [Shard(0)] - partial_spec = [_Partial()] + partial_spec = [Partial()] dt_to_inplace_add = 
distribute_tensor(input_tensor, mesh, shard_spec) partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec) res = dt_to_inplace_add.add_(partial_grad) @@ -168,7 +168,7 @@ def test_ones_like(self): @with_comms def test_ones_like_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -181,7 +181,7 @@ def test_ones_like_partial_sum(self): @with_comms def test_fill_inplace_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -197,7 +197,7 @@ def test_fill_inplace_partial_sum(self): @with_comms def test_zeros_like_partial_sum(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - shard_spec = [_Partial()] + shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) @@ -236,8 +236,8 @@ def test_stack(self): mesh_2d = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) ) - partial_replicate_placement = [_Partial(), Replicate()] - partial_placement = [_Partial(), _Partial()] + partial_replicate_placement = [Partial(), Replicate()] + partial_placement = [Partial(), Partial()] partial_replicate_dt = DTensor.from_local( torch.randn(4, 8), mesh_2d, partial_replicate_placement @@ -390,6 +390,40 @@ def test_new_empty_strided(self): self.assertEqual(new_empty_strided_dt._local_tensor.size(), (12, 4)) self.assertEqual(new_empty_strided_dt._local_tensor.stride(), (4, 1)) + @with_comms + def test_scatter(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + comm_mode = CommDebugMode() + + # case 1 all replicate: input replicated, index/src replicated, output replicated + global_indexs = [ + torch.tensor([[0, 1, 2, 0]]), + torch.tensor([[0, 1, 2], [0, 1, 4]]), + ] + for scatter_dim in [0, 1]: + srcs = [torch.arange(1, 11).reshape((2, 5)), 4] + for global_src in srcs: + global_input = torch.zeros(3, 5, dtype=torch.int64) + global_index = global_indexs[scatter_dim] + + input_dt = distribute_tensor( + global_input.clone(), device_mesh, [Replicate()] + ) + index_dt = distribute_tensor(global_index, device_mesh, [Replicate()]) + if isinstance(global_src, torch.Tensor): + src_dt = distribute_tensor(global_src, device_mesh, [Replicate()]) + else: + src_dt = global_src + global_output = torch.scatter( + global_input, scatter_dim, global_index, global_src + ) + with comm_mode: + output_dt = torch.scatter(input_dt, scatter_dim, index_dt, src_dt) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual(output_dt.placements, [Replicate()]) + self.assertEqual(output_dt.to_local(), global_output) + @with_comms def test_gather(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 3d6608a491ec..467b5e092306 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -10,7 +10,12 @@ ) from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import Replicate, Shard +from 
torch.distributed._tensor.placement_types import ( + DTensorSpec, + Replicate, + Shard, + TensorMeta, +) from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.common_utils import run_tests @@ -185,14 +190,20 @@ def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self): chunks = list(torch.chunk(dtensor_tp.to_local(), 2, dim=0)) shard_rank = 0 if self.rank // 2 == 0 else 1 sharded_param = chunks[shard_rank] + spec_2d = DTensorSpec( + mesh=mesh_2d, + placements=(Shard(0), Shard(0)), + tensor_meta=TensorMeta( + global_tensor.size(), + global_tensor.stride(), + global_tensor.dtype, + ), + ) + dtensor_2d = DTensor( sharded_param, - mesh_2d, - [Shard(0), Shard(0)], - shape=global_tensor.size(), - dtype=global_tensor.dtype, + spec_2d, requires_grad=False, - stride=global_tensor.stride(), ) self.assertEqual( diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 94ec8602bd4f..58d1f20ad911 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -19,6 +19,7 @@ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, ) +from torch.distributed.checkpoint import state_dict as ptd_state_dict from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, @@ -29,7 +30,7 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn.parallel import DistributedDataParallel as DDP @@ -42,6 +43,7 @@ from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, + MultiProcessTestCase, with_comms, ) from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin @@ -424,52 +426,6 @@ def test_strict(self) -> None: with self.assertRaisesRegex(RuntimeError, "Missing key"): set_model_state_dict(model, model_state_dict=model_state_dict) - @with_comms - @skip_if_lt_x_gpu(1) - def test_partial(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) - - model_state_dict1 = get_model_state_dict(model) - model_state_dict1 = copy.deepcopy(model_state_dict1) - model_state_dict2 = get_model_state_dict(model, submodules={model.l}) - model_state_dict2 = copy.deepcopy(model_state_dict2) - model_state_dict3 = get_model_state_dict( - model, - submodules={model.l}, - options=StateDictOptions(keep_submodule_prefixes=False), - ) - model_state_dict3 = copy.deepcopy(model_state_dict3) - self.assertEqual(len(model_state_dict2), 2) - self.assertEqual(len(model_state_dict3), 2) - for key in model_state_dict3.keys(): - full_fqn = f"l.{key}" - value1 = model_state_dict1[full_fqn] - value2 = model_state_dict2[full_fqn] - value3 = model_state_dict3[key] - self.assertEqual(value1, value2) - self.assertEqual(value2, value3) - - zeros_state_dict = { - k: torch.zeros_like(v) for k, v in model_state_dict1.items() - } - model.load_state_dict(zeros_state_dict) - set_model_state_dict( - model, - model_state_dict=model_state_dict2, - options=StateDictOptions(strict=False), - ) - self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) - self.assertEqual(model.l.bias, 
model_state_dict1["l.bias"]) - - model.load_state_dict(zeros_state_dict) - set_model_state_dict( - model, - model_state_dict={model.l: model_state_dict3}, - options=StateDictOptions(strict=False), - ) - self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) - self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) - def _test_cpu_offload_full_state_dict( self, optimizer_class: Type[Optimizer] ) -> None: @@ -650,9 +606,11 @@ def check(equal): # Drop the states to simulate loading from rank0 if dist.get_rank() > 0: load_states = {} + load_states2 = {} load_optim_states = {} else: load_states = copy.deepcopy(states) + load_states2 = copy.deepcopy(states) load_optim_states = copy.deepcopy(optim_states) set_model_state_dict( @@ -670,7 +628,21 @@ def check(equal): broadcast_from_rank0=True, full_state_dict=True ), ) + check(equal=True) + # Verify the `strict` flag. + load_states = load_states2 + if load_states: + key = next(iter(load_states.keys())) + load_states.pop(key) + with self.assertRaisesRegex(RuntimeError, "Missing key"): + set_model_state_dict( + fsdp_model, + model_state_dict=load_states, + options=StateDictOptions( + broadcast_from_rank0=True, full_state_dict=True + ), + ) device_mesh = init_device_mesh("cuda", (self.world_size,)) self.run_subtests( @@ -696,6 +668,178 @@ def test_fsdp_root_not_initialized(self) -> None: get_model_state_dict(fsdp_model) get_optimizer_state_dict(fsdp_model, fsdp_optim) + @with_comms + @skip_if_lt_x_gpu(2) + def test_optim_state_dict_param_matching(self) -> None: + # This test verifies parameters between optim and optim_state_dict + # "initial_lr" is added to optim_state_dict, but not to the new optim + # We test whether "initial_lr" appear in optim after + # set_optimizer_state_dict. + device = "cuda" + torch.manual_seed(0) + model = nn.Sequential( + *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] + ) + for layer in model: + fully_shard(layer) + fully_shard(model) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + torch.optim.lr_scheduler.LambdaLR( + optim, lr_lambda=[lambda epoch: 0.95**epoch] + ) + opt_state_dict = ptd_state_dict.get_optimizer_state_dict( + model, + optim, + options=ptd_state_dict.StateDictOptions( + full_state_dict=True, cpu_offload=True + ), + ) + if dist.get_rank() == 0: + self.assertTrue("initial_lr" in opt_state_dict["param_groups"][0]) + + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + self.assertTrue("initial_lr" not in optim.param_groups[0]) + + ptd_state_dict.set_optimizer_state_dict( + model, + optim, + optim_state_dict=opt_state_dict, + options=ptd_state_dict.StateDictOptions( + broadcast_from_rank0=True, full_state_dict=True + ), + ) + if dist.get_rank() == 0: + self.assertTrue("initial_lr" in optim.param_groups[0]) + + @with_comms + @skip_if_lt_x_gpu(2) + def test_flattened_osd(self) -> None: + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) + fsdp_model = FSDP2(copy.deepcopy(model), mesh=device_mesh) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) + batch = torch.rand(8, 100, device="cuda") + fsdp_model(batch).sum().backward() + fsdp_optim.step() + fsdp_optim.zero_grad() + osd1 = get_optimizer_state_dict(fsdp_model, fsdp_optim) + osd2 = get_optimizer_state_dict( + fsdp_model, + fsdp_optim, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + fsdp_optim2 = torch.optim.AdamW(fsdp_model.parameters()) + set_optimizer_state_dict( + fsdp_model, optimizers=fsdp_optim2, 
optim_state_dict=osd2 + ) + self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) + set_optimizer_state_dict( + fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd1 + ) + self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) + + @with_comms + @skip_if_lt_x_gpu(1) + def test_deprecate_partial(self) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + + model_state_dict1 = get_model_state_dict(model) + model_state_dict1 = copy.deepcopy(model_state_dict1) + with self.assertWarnsRegex( + FutureWarning, + "Getting submodules only model/optim state_dict is deprecated", + ): + model_state_dict2 = get_model_state_dict(model, submodules={model.l}) + model_state_dict2 = copy.deepcopy(model_state_dict2) + with self.assertWarnsRegex( + FutureWarning, + "Getting submodules only model/optim state_dict is deprecated", + ): + model_state_dict3 = get_model_state_dict( + model, + submodules={model.l}, + options=StateDictOptions(keep_submodule_prefixes=False), + ) + model_state_dict3 = copy.deepcopy(model_state_dict3) + self.assertEqual(len(model_state_dict2), 2) + self.assertEqual(len(model_state_dict3), 2) + for key in model_state_dict3.keys(): + full_fqn = f"l.{key}" + value1 = model_state_dict1[full_fqn] + value2 = model_state_dict2[full_fqn] + value3 = model_state_dict3[key] + self.assertEqual(value1, value2) + self.assertEqual(value2, value3) + + zeros_state_dict = { + k: torch.zeros_like(v) for k, v in model_state_dict1.items() + } + model.load_state_dict(zeros_state_dict) + set_model_state_dict( + model, + model_state_dict=model_state_dict2, + options=StateDictOptions(strict=False), + ) + self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) + self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) + + model.load_state_dict(zeros_state_dict) + with self.assertWarnsRegex(FutureWarning, "Passing model_state_dict as a "): + set_model_state_dict( + model, + model_state_dict={model.l: model_state_dict3}, + options=StateDictOptions(strict=False), + ) + self.assertEqual(model.l.weight, model_state_dict1["l.weight"]) + self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) + + @with_comms + @skip_if_lt_x_gpu(1) + def test_deprecate_fsdp_api(self) -> None: + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) + fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) + with self.assertWarnsRegex( + FutureWarning, + r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", + ): + with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT): + fsdp_model.state_dict() + + with self.assertRaisesRegex(AssertionError, "FutureWarning not triggered"): + with self.assertWarnsRegex( + FutureWarning, + r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated", + ): + get_model_state_dict(model) + + +class TestNoComm(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @skip_if_lt_x_gpu(1) + def test_no_dist(self) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + optim = torch.optim.AdamW(model.parameters(), lr=1e-3) + + self.assertFalse(dist.is_initialized()) + msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + for v in msd.values(): + self.assertFalse(v.is_cuda) + self.assertEqual(model.state_dict(), msd) + set_model_state_dict(model, model.state_dict()) + osd = get_optimizer_state_dict( + model, + 
optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + set_optimizer_state_dict(model, optim, osd) + set_optimizer_state_dict(model, optim, optim.state_dict()) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py index 22ab029a612f..95e77a5662ee 100644 --- a/test/distributed/checkpoint/test_traverse.py +++ b/test/distributed/checkpoint/test_traverse.py @@ -1,13 +1,16 @@ # Owner(s): ["oncall: distributed"] from collections import OrderedDict +from typing import TYPE_CHECKING import torch import torch.distributed.checkpoint._traverse as _traverse -from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE from torch.testing._internal.common_utils import run_tests, TestCase +if TYPE_CHECKING: + from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + # TODO: add comments for TestTraverse class TestTraverse(TestCase): diff --git a/test/distributed/elastic/multiprocessing/redirects_test.py b/test/distributed/elastic/multiprocessing/redirects_test.py index 0d8c14310f87..2fa507a15a36 100644 --- a/test/distributed/elastic/multiprocessing/redirects_test.py +++ b/test/distributed/elastic/multiprocessing/redirects_test.py @@ -138,3 +138,7 @@ def c_print(i): libc.printf(bytes(f"c:{i}\n", "utf-8")) self._redirect_large_buffer(c_print) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py new file mode 100644 index 000000000000..775b062451b1 --- /dev/null +++ b/test/distributed/elastic/test_control_plane.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Owner(s): ["oncall: distributed"] + +import json +import os +import pickle +import socket +import tempfile +from contextlib import contextmanager + +from urllib3.connection import HTTPConnection +from urllib3.connectionpool import HTTPConnectionPool + +from torch.distributed.elastic.control_plane import ( + TORCH_WORKER_SERVER_SOCKET, + worker_main, +) +from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase + + +class UnixHTTPConnection(HTTPConnection): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def connect(self) -> None: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + + +class UnixHTTPConnectionPool(HTTPConnectionPool): + def __init__(self, socket_path: str) -> None: + super().__init__("localhost") + + self.socket_path = socket_path + + def _new_conn(self): + return UnixHTTPConnection(self.socket_path) + + +@contextmanager +def local_worker_server() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "socket.sock") + os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path + + with worker_main(): + pool = UnixHTTPConnectionPool(socket_path) + yield pool + + +class WorkerServerTest(TestCase): + def test_worker_server(self) -> None: + with local_worker_server() as pool: + resp = pool.request("GET", "/") + self.assertEqual(resp.status, 200) + self.assertEqual( + resp.data, + b"""

<h1>torch.distributed.WorkerServer</h1>
+<a href="/handler/">
Handler names +""", + ) + + resp = pool.request("POST", "/handler/ping") + self.assertEqual(resp.status, 200) + self.assertEqual(resp.data, b"pong") + + resp = pool.request("GET", "/handler/") + self.assertEqual(resp.status, 200) + self.assertIn("ping", json.loads(resp.data)) + + resp = pool.request("POST", "/handler/nonexistant") + self.assertEqual(resp.status, 404) + self.assertIn(b"Handler nonexistant not found:", resp.data) + + @requires_cuda + def test_dump_nccl_trace_pickle(self) -> None: + with local_worker_server() as pool: + resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") + self.assertEqual(resp.status, 200) + out = pickle.loads(resp.data) + + def test_tcp(self) -> None: + import requests + + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer("", 1234) + out = requests.get("http://localhost:1234/handler/") + self.assertEqual(out.status_code, 200) + + server.shutdown() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/test_run.py similarity index 89% rename from test/distributed/launcher/run_test.py rename to test/distributed/launcher/test_run.py index c816042e3e46..ba58aec43871 100644 --- a/test/distributed/launcher/run_test.py +++ b/test/distributed/launcher/test_run.py @@ -13,7 +13,6 @@ import subprocess import sys import tempfile -import unittest import uuid from contextlib import closing from unittest import mock @@ -23,12 +22,13 @@ from torch.distributed.elastic.agent.server.api import RunResult, WorkerState from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs from torch.distributed.elastic.multiprocessing.errors import ChildFailedError -from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer from torch.distributed.elastic.utils import get_socket_with_port from torch.distributed.elastic.utils.distributed import get_free_port from torch.testing._internal.common_utils import ( + run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_DEV_DBG_ASAN, + TestCase, ) @@ -63,19 +63,7 @@ class MockException(Exception): pass -class ElasticLaunchTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - # start a standalone, single process etcd server to use for all tests - cls._etcd_server = EtcdServer() - cls._etcd_server.start() - cls._etcd_endpoint = cls._etcd_server.get_endpoint() - - @classmethod - def tearDownClass(cls): - # stop the standalone etcd server - cls._etcd_server.stop() - +class ElasticLaunchTest(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() @@ -103,8 +91,6 @@ def _test_launch_user_script_python(self): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -156,8 +142,6 @@ def test_launch_user_script_bash(self): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -187,8 +171,6 @@ def test_launch_user_script_default_nproc(self): world_size = 1 args = [ f"--nnodes={nnodes}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -220,8 +202,6 @@ def test_launch_with_env_vars(self): os.environ["PET_NNODES"] = str(nnodes) os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node) - 
os.environ["PET_RDZV_BACKEND"] = "etcd" - os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint os.environ["PET_RDZV_ID"] = run_id os.environ["PET_MONITOR_INTERVAL"] = "1" os.environ["PET_START_METHOD"] = "spawn" @@ -250,8 +230,6 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): args = [ f"--nnodes={nnodes}", f"--nproc-per-node={nproc_type}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -272,7 +250,8 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) - def test_nproc_launch_auto_configurations(self): + @patch("torch.cuda.is_available", return_value=False) + def test_nproc_launch_auto_configurations(self, _mock1): self._test_nproc_launch_configuration("auto", os.cpu_count()) @skip_but_pass_in_sandcastle_if( @@ -310,8 +289,9 @@ def test_launch_elastic(self): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -343,8 +323,9 @@ def test_launch_elastic_worker_raise_exception(self, record_mock): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", f"--rdzv-id={run_id}", "--monitor-interval=1", "--max-restarts=0", @@ -376,8 +357,9 @@ def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run) args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv_conf=timeout=5", f"--rdzv-id={run_id}", "--monitor-interval=1", "--max-restarts=0", @@ -452,8 +434,9 @@ def test_launch_elastic_multiple_agents(self): args = [ f"--nnodes={min_nodes}:{max_nodes}", f"--nproc-per-node={nproc_per_node}", - "--rdzv-backend=etcd", - f"--rdzv-endpoint={self._etcd_endpoint}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{get_free_port()}", + "--rdzv_conf=timeout=5", f"--rdzv-id={run_id}", "--monitor-interval=1", "--start-method=spawn", @@ -608,21 +591,6 @@ def test_is_not_torchelastic_launched(self): is_torchelastic_launched = fp.readline() self.assertEqual("False", is_torchelastic_launched) - def test_init_method_tcp(self): - port = get_free_port() - with patch.object( - sys, - "argv", - [ - path("bin/test_script_init_method.py"), - f"--init-method=tcp://localhost:{port}", - "--rank=0", - "--world-size=1", - ], - ): - runpy.run_path(sys.argv[0], run_name="__main__") - # nothing to validate, just make sure it runs - @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) @@ -642,27 +610,6 @@ def test_init_method_tcp_with_torchelastic(self): ) # nothing to validate, just make sure it runs - def test_init_method_env(self): - port = get_free_port() - with patch.dict( - os.environ, - { - "RANK": "0", - "WORLD_SIZE": "1", - "MASTER_ADDR": "localhost", - "MASTER_PORT": str(port), - }, - ), 
patch.object( - sys, - "argv", - [ - path("bin/test_script_init_method.py"), - "--init-method=env://", - ], - ): - runpy.run_path(sys.argv[0], run_name="__main__") - # nothing to validate, just make sure it runs - @skip_but_pass_in_sandcastle_if( TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) @@ -681,3 +628,7 @@ def test_init_method_env_with_torchelastic(self): ] ) # nothing to validate, just make sure it runs + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/pipeline/sync/LICENSE b/test/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/test/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/pipeline/sync/__init__.py b/test/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 94cd5bcb415e..000000000000 --- a/test/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH. -# See also: https://docs.pytest.org/en/latest/goodpractices.html diff --git a/test/distributed/pipeline/sync/conftest.py b/test/distributed/pipeline/sync/conftest.py deleted file mode 100644 index 4f2479b27b29..000000000000 --- a/test/distributed/pipeline/sync/conftest.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import tempfile - -import pytest - -import torch -import torch.distributed as dist - - -@pytest.fixture(autouse=True) -def manual_seed_zero(): - torch.manual_seed(0) - - -@pytest.fixture(scope="session") -def cuda_sleep(): - # Warm-up CUDA. 
- torch.empty(1, device="cuda") - - # From test/test_cuda.py in PyTorch. - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - torch.cuda._sleep(1000000) - end.record() - end.synchronize() - cycles_per_ms = 1000000 / start.elapsed_time(end) - - def cuda_sleep(seconds): - torch.cuda._sleep(int(seconds * cycles_per_ms * 1000)) - - return cuda_sleep - - -def pytest_report_header(): - return f"torch: {torch.__version__}" - - -@pytest.fixture -def setup_rpc(scope="session"): - file = tempfile.NamedTemporaryFile() - dist.rpc.init_rpc( - name="worker0", - rank=0, - world_size=1, - rpc_backend_options=dist.rpc.TensorPipeRpcBackendOptions( - init_method=f"file://{file.name}", - ), - ) - yield - dist.rpc.shutdown() - - -def pytest_ignore_collect(path, config): - "Skip this directory if distributed modules are not enabled." - return not dist.is_available() diff --git a/test/distributed/pipeline/sync/skip/__init__.py b/test/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index ab03724cafbf..000000000000 --- a/test/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/test/distributed/pipeline/sync/skip/test_api.py b/test/distributed/pipeline/sync/skip/test_api.py deleted file mode 100644 index be38d6d83dac..000000000000 --- a/test/distributed/pipeline/sync/skip/test_api.py +++ /dev/null @@ -1,52 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import copy - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, stash -from torch.testing._internal.common_utils import run_tests - - -def test_namespace_difference(): - ns1 = Namespace() - ns2 = Namespace() - assert ns1 != ns2 - - -def test_namespace_copy(): - ns = Namespace() - assert copy.copy(ns) == ns - assert copy.copy(ns) is not ns - - -def test_skippable_repr(): - @skippable(stash=["hello"]) - class Hello(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def forward(self, x): - yield stash("hello", x) - return self.conv(x) # noqa: B901 - - m = Hello() - assert ( - repr(m) - == """ -@skippable(Hello( - (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1)) -)) -""".strip() - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_gpipe.py b/test/distributed/pipeline/sync/skip/test_gpipe.py deleted file mode 100644 index 4f433ab38941..000000000000 --- a/test/distributed/pipeline/sync/skip/test_gpipe.py +++ /dev/null @@ -1,126 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.portal import ( - PortalBlue, - PortalCopy, - PortalOrange, -) -from torch.distributed.pipeline.sync.utils import partition_model -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -@pytest.mark.parametrize( - "balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"] -) -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint, setup_rpc): - if torch.cuda.device_count() < len(balance): - pytest.skip("at least %d cuda devices required" % len(balance)) - - @skippable(stash=["1to3"]) - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - yield stash("1to3", input) - output = self.conv(input) - return output # noqa: B901 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - output = self.conv(input) - return output - - @skippable(pop=["1to3"]) - class Layer3(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 1) - - def forward(self, input): - skip_1to3 = yield pop("1to3") - output = self.conv(input) + skip_1to3 - return output - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - model = partition_model(model, balance) - model = Pipe(model, chunks=3, checkpoint=checkpoint) - - in_device = model.devices[0] - out_device = model.devices[-1] - - input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) - output = model(input) - loss = output.local_value().mean() - loss.backward() - - assert torch.allclose( - output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1 - ) - assert torch.allclose( - input.grad.norm(), torch.tensor(0.0004533053, device=in_device) - ) - - -def test_none_skip(setup_rpc): - @skippable(stash=["none"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("none", None) - return input # noqa: B901 - - @skippable(pop=["none"]) - class Pop(nn.Module): - def forward(self, input): - none = yield pop("none") - assert none is None - return input - - model = nn.Sequential(Stash(), Pop()) - model = Pipe(model, chunks=5) - - input = torch.rand(10, requires_grad=True) - output = model(input) - - def assert_grad_fn_is_not_portal(grad_fn, visited=None): - if visited is None: - visited = set() - if grad_fn in visited or grad_fn is None: - return - - assert not isinstance(grad_fn, PortalBlue._backward_cls) - assert not isinstance(grad_fn, PortalCopy._backward_cls) - assert not isinstance(grad_fn, PortalOrange._backward_cls) - - visited.add(grad_fn) - for next_grad_fn, _ in grad_fn.next_functions: - assert_grad_fn_is_not_portal(next_grad_fn, visited) - - assert_grad_fn_is_not_portal(output.local_value().grad_fn) - - output.local_value().sum().backward() - assert input.grad.mean().item() == 1 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py deleted file mode 100644 index 4d542285cd5a..000000000000 --- a/test/distributed/pipeline/sync/skip/test_inspect_skip_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: 
distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import inspect_skip_layout -from torch.testing._internal.common_utils import run_tests - - -class Pass(nn.Module): - def forward(self, input): - return input - - -@skippable(stash=["foo"]) -class StashFoo(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input # noqa: B901 - - -@skippable(pop=["foo"]) -class PopFoo(nn.Module): - def forward(self, input): - foo = yield stash("foo") - return input + foo - - -@skippable(stash=["bar"]) -class StashBar(nn.Module): - def forward(self, input): - yield stash("bar", input) - return input # noqa: B901 - - -@skippable(pop=["bar"]) -class PopBar(nn.Module): - def forward(self, input): - bar = yield pop("bar") - return input + bar - - -def test_no_skippables(): - p1 = nn.Sequential(Pass()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_inner_partition(): - p1 = nn.Sequential(StashFoo(), PopFoo()) - p2 = nn.Sequential(Pass()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], []] - - -def test_adjoining_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2]) - policy = [list(layout.copy_policy(i)) for i in range(2)] - - assert policy == [[], [(0, None, "foo")]] - - -def test_far_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(Pass()) - p3 = nn.Sequential(PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - assert policy == [[], [], [(0, None, "foo")]] - - -def test_pop_2_from_different_partitions(): - p1 = nn.Sequential(StashFoo()) - p2 = nn.Sequential(StashBar()) - p3 = nn.Sequential(PopBar(), PopFoo()) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] - - -def test_namespace(): - ns1 = Namespace() - ns2 = Namespace() - - p1 = nn.Sequential(StashFoo().isolate(ns1)) - p2 = nn.Sequential(StashFoo().isolate(ns2)) - p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) - - layout = inspect_skip_layout([p1, p2, p3]) - policy = [list(layout.copy_policy(i)) for i in range(3)] - - # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. - assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_leak.py b/test/distributed/pipeline/sync/skip/test_leak.py deleted file mode 100644 index f4d1043e0549..000000000000 --- a/test/distributed/pipeline/sync/skip/test_leak.py +++ /dev/null @@ -1,136 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import is_checkpointing, is_recomputing, Pipe -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import current_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@skippable(stash=["skip"]) -class Stash(nn.Module): - def forward(self, input): - yield stash("skip", input) - return input # noqa: B901 - - -@skippable(pop=["skip"]) -class Pop(nn.Module): - def forward(self, input): - skip = yield pop("skip") - return input + skip - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint, setup_rpc): - # Without checkpointing: - # +- Stash --+ +--- Pop ----+ - - - layers - # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function - # +----------+ +------------+ - # - # With checkpointing: - # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ - # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | - # +----------+ +------------+ +------------+ +----------+ - - def portal_tensor_life_is(tensor_life, skip_tracker=None): - if skip_tracker is None: - skip_tracker = current_skip_tracker() - - # Get the current portal. - portal = next(iter(skip_tracker.portals.values())) - - if tensor_life == 0: - return portal.tensor_life == 0 and portal.tensor is None - else: - return portal.tensor_life == tensor_life and portal.tensor is not None - - # Check the portal tensor after 'Stash'. - stash_ = Stash() - - @stash_.register_forward_hook - def check_portal_tensor_after_stash(*_): - if is_checkpointing(): - assert portal_tensor_life_is(2) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(1) - - pop_ = Pop() - - @pop_.register_forward_hook - def check_portal_tensor_after_pop(*_): - if is_checkpointing(): - assert portal_tensor_life_is(1) - elif is_recomputing(): - assert portal_tensor_life_is(0) - else: - assert portal_tensor_life_is(0) - - class NoPortalTensorAtBackward(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - ctx.skip_tracker = current_skip_tracker() - return input.detach() - - @staticmethod - def backward(ctx, grad): - assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) - return grad - - def forward(self, input): - return self.F.apply(input) - - model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input).local_value() - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - -@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): - def deny(*args, **kwargs): - raise AssertionError("tried to create Portal without Pipe") - - monkeypatch.setattr( - "torch.distributed.pipeline.sync.skip.portal.Portal.__init__", deny - ) - - model = nn.Sequential(Stash(), Pop()) - - input = torch.rand(10, requires_grad=True) - - if train: - model.train() - output = model(input) - output.norm().backward() - else: - model.eval() - with torch.no_grad(): - model(input) - - 
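The ASCII diagrams in `test_delete_portal_tensor` above describe the bookkeeping being checked: a portal hands out its tensor a fixed number of times (`tensor_life`) and drops the reference after the last consumer runs. A toy sketch of that counting idea only, not the real `Portal` class, to make the diagrams concrete:

class ExpiringSlot:
    """Toy stand-in for the portal's tensor_life counter (illustration only)."""

    def __init__(self, value, life):
        self.value, self.life = value, life

    def take(self):
        if self.life <= 0:
            raise RuntimeError("value already expired")
        self.life -= 1
        out = self.value
        if self.life == 0:
            self.value = None          # release the reference after the last consumer
        return out

slot = ExpiringSlot("activation", life=2)      # e.g. Pop and Pop' under checkpointing
slot.take(); slot.take()
assert slot.value is None and slot.life == 0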
-if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_portal.py b/test/distributed/pipeline/sync/skip/test_portal.py deleted file mode 100644 index 5ad180b6f9c8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_portal.py +++ /dev/null @@ -1,163 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.skip.portal import Portal -from torch.distributed.pipeline.sync.stream import default_stream -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_copy_returns_on_next_device(): - portal = Portal(torch.rand(1), tensor_life=1) - - prev_stream = default_stream(torch.device("cpu")) - next_stream = default_stream(torch.device("cuda")) - - phony = torch.zeros(0, requires_grad=True) - assert phony.device.type == "cpu" - - phony = portal.copy(prev_stream, next_stream, phony) - assert phony.device.type == "cuda" - - -def test_blue_orange(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1, requires_grad=True) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert torch.allclose(tensor2.grad, torch.tensor([1.0])) - - -def test_blue_orange_not_requires_grad(): - tensor1 = torch.rand(1, requires_grad=True) - tensor2 = torch.rand(1) - - # Same with: output = tensor1*2 + tensor2 - # - # +----------------------+ - # | | - # tensor2 -- PortalBlue -+ +- PortalOrange -+ - # | | | - # tensor1 ------------ Join -- Fork --- Mul --- Add -- output - # - main = tensor1 - portal = Portal(tensor2, tensor_life=2) - phony = portal.blue() - main = join(main, phony) - main, phony = fork(main) - sub = portal.orange(phony) - output = main * 2 + sub - - output.backward() - - assert torch.allclose(tensor1.grad, torch.tensor([2.0])) - assert tensor2.grad is None - - -def test_use_grad(): - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life=1) - - portal.put_grad(tensor) - assert portal.use_grad() is tensor - - # Gradient in a portal is ephemeral. - with pytest.raises(RuntimeError): - portal.use_grad() - - -class TestTensorLife: - @pytest.fixture - def new_portal(self): - portal = None - - def new_portal(tensor_life): - nonlocal portal - tensor = torch.rand(1, requires_grad=True) - portal = Portal(tensor, tensor_life) - return portal, tensor - - yield new_portal - - # A test using this fixture must exhaust the tensor in the portal. 
- with pytest.raises(RuntimeError): - portal.check_tensor_life() - assert portal.tensor is None - - def test_tensor_life_0(self, new_portal): - portal, tensor = new_portal(0) - assert portal.tensor is None - - def test_tensor_life_1(self, new_portal): - portal, tensor = new_portal(1) - assert portal.tensor is tensor - - portal.blue() - - def test_tensor_life_2(self, new_portal): - portal, tensor = new_portal(2) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_3(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - def test_tensor_life_4(self, new_portal): - portal, tensor = new_portal(4) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - portal.blue() - - def test_tensor_life_3_plus_1(self, new_portal): - portal, tensor = new_portal(3) - assert portal.tensor is tensor - - phony = portal.blue() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - assert portal.orange(phony).data_ptr() == tensor.data_ptr() - - another_tensor = torch.rand(1, requires_grad=True) - portal.put_tensor(another_tensor, tensor_life=1) - portal.blue() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_stash_pop.py b/test/distributed/pipeline/sync/skip/test_stash_pop.py deleted file mode 100644 index 5d273860f6a6..000000000000 --- a/test/distributed/pipeline/sync/skip/test_stash_pop.py +++ /dev/null @@ -1,144 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
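`TestTensorLife` above compares `data_ptr()` rather than object identity because `portal.orange()` hands back a new tensor object backed by the same storage. A two-line illustration of that distinction, using plain `detach()` rather than any portal machinery:

import torch

a = torch.rand(1, requires_grad=True)
view = a.detach()                         # a new tensor sharing the same storage
assert view.data_ptr() == a.data_ptr()    # same underlying memory
assert view is not a                      # but a different Python object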
-import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker -from torch.testing._internal.common_utils import run_tests - - -@pytest.fixture(autouse=True) -def skip_tracker(): - skip_tracker = SkipTracker() - with use_skip_tracker(skip_tracker): - yield skip_tracker - - -def test_stash(skip_tracker): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - assert len(skip_tracker.tensors) == 0 - - with use_skip_tracker(skip_tracker): - l1(torch.tensor(42)) - - assert len(skip_tracker.tensors) == 1 - - -def test_pop(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - output = l2(l1(torch.tensor(42))) - - assert output.item() == 42 - - -def test_declare_but_not_use(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - return input * 2 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - return input * 3 - - l1 = Stash() - l2 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(torch.tensor(42)) - - -def test_stash_not_declared(): - @skippable() - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - l1 = Stash() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_pop_not_declared(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable() - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - l1 = Stash() - l2 = Pop() - - latent = l1(torch.tensor(42)) - - with pytest.raises(RuntimeError): - l2(latent) - - -def test_pop_not_stashed(): - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - yield pop("foo") - - l1 = Pop() - - with pytest.raises(RuntimeError): - l1(torch.tensor(42)) - - -def test_stash_none(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", None) - return input * 2 # noqa: B901 - - l1 = Stash() - l1(torch.tensor(42)) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_tracker.py b/test/distributed/pipeline/sync/skip/test_tracker.py deleted file mode 100644 index 9c3a970f7574..000000000000 --- a/test/distributed/pipeline/sync/skip/test_tracker.py +++ /dev/null @@ -1,145 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
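`test_stash_pop.py` above exercises the stash/pop skip API that this PR removes: a `@skippable` module yields `stash(...)` to publish a tensor, and a later module retrieves it with `yield pop(...)`. A condensed usage sketch of that same removed API, mirroring `test_pop`:

import torch
from torch import nn
from torch.distributed.pipeline.sync.skip import pop, skippable, stash

@skippable(stash=["foo"])
class Stash(nn.Module):
    def forward(self, input):
        yield stash("foo", input)     # publish 'input' for a later @skippable module
        return input * 2              # noqa: B901

@skippable(pop=["foo"])
class Pop(nn.Module):
    def forward(self, input):
        foo = yield pop("foo")        # receive the stashed tensor
        return input + foo

# With the default (thread-local) skip tracker this runs as a plain pass-through:
out = Pop()(Stash()(torch.tensor(2.0)))
print(out)                            # tensor(6.) -> 2*2 from Stash, +2.0 from Pop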
-import threading -from queue import Queue - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - enable_checkpointing, - enable_recomputing, -) -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.skip import pop, skippable, stash -from torch.distributed.pipeline.sync.skip.layout import SkipLayout -from torch.distributed.pipeline.sync.skip.tracker import ( - current_skip_tracker, - SkipTracker, - SkipTrackerThroughPotals, -) -from torch.testing._internal.common_utils import run_tests - - -def test_default_skip_tracker(): - q = Queue() - - def f(): - q.put(current_skip_tracker()) - - t = threading.Thread(target=f) - t.start() - t.join() - - skip_tracker = q.get() - - assert type(skip_tracker) is SkipTracker - assert type(skip_tracker) is not SkipTrackerThroughPotals - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_default_skip_tracker_by_data_parallel(): - @skippable(stash=["foo"]) - class Stash(nn.Module): - def forward(self, input): - yield stash("foo", input) - return input * 2 # noqa: B901 - - @skippable(pop=["foo"]) - class Pop(nn.Module): - def forward(self, input): - foo = yield pop("foo") - return foo - - model = nn.Sequential(Stash(), Pop()) - model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) - - input = torch.rand(10, device=0) - output = model(input) - - assert torch.allclose(output, input) - - -def test_reuse_portal(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", a) - portal = skip_tracker.portals[(None, "test")] - - skip_tracker.save(batch, None, "test", b) - assert portal is skip_tracker.portals[(None, "test")] - - -def test_no_copy_no_portal(): - skip_layout = SkipLayout( - num_partitions=2, - skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}, - ) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - a = torch.tensor([2.0]) - b = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "copy", a) - skip_tracker.save(batch, None, "not_copy", b) - - assert (None, "copy") in skip_tracker.portals - assert (None, "copy") not in skip_tracker.tensors - assert (None, "not_copy") in skip_tracker.tensors - assert (None, "not_copy") not in skip_tracker.portals - - -def test_tensor_life_without_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 1 - - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -def test_tensor_life_with_checkpointing(): - skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) - skip_tracker = SkipTrackerThroughPotals(skip_layout) - - batch = Batch(torch.tensor([1.0])) - tensor = torch.tensor([2.0]) - - with enable_checkpointing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 2 - - with enable_checkpointing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, 
"test")].tensor_life == 1 - - with enable_recomputing(): - skip_tracker.load(batch, None, "test") - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - with enable_recomputing(): - skip_tracker.save(batch, None, "test", tensor) - assert skip_tracker.portals[(None, "test")].tensor_life == 0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/skip/test_verify_skippables.py b/test/distributed/pipeline/sync/skip/test_verify_skippables.py deleted file mode 100644 index 1d5941487da8..000000000000 --- a/test/distributed/pipeline/sync/skip/test_verify_skippables.py +++ /dev/null @@ -1,165 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -from torch import nn - -from torch.distributed.pipeline.sync.skip import Namespace, skippable, verify_skippables -from torch.testing._internal.common_utils import run_tests - - -def test_matching(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2())) - - -def test_stash_not_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "no module declared 'foo' as poppable but stashed" in str(e.value) - - -def test_pop_unknown(): - @skippable(pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) - - -def test_stash_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'1' redeclared 'foo' as stashable" in str(e.value) - - -def test_pop_again(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer3(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - assert "'2' redeclared 'foo' as poppable" in str(e.value) - - -def test_stash_pop_together_different_names(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"], stash=["bar"]) - class Layer2(nn.Module): - pass - - @skippable(pop=["bar"]) - class Layer3(nn.Module): - pass - - verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) - - -def test_stash_pop_together_same_name(): - @skippable(stash=["foo"], pop=["foo"]) - class Layer1(nn.Module): - pass - - with pytest.raises(TypeError) as e: - verify_skippables(nn.Sequential(Layer1())) - assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) - - -def test_double_stash_pop(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - with pytest.raises(TypeError) as e: - 
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) - assert "'2' redeclared 'foo' as stashable" in str(e.value) - assert "'3' redeclared 'foo' as poppable" in str(e.value) - - -def test_double_stash_pop_but_isolated(): - @skippable(stash=["foo"]) - class Layer1(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer2(nn.Module): - pass - - @skippable(stash=["foo"]) - class Layer3(nn.Module): - pass - - @skippable(pop=["foo"]) - class Layer4(nn.Module): - pass - - ns1 = Namespace() - ns2 = Namespace() - - verify_skippables( - nn.Sequential( - Layer1().isolate(ns1), - Layer2().isolate(ns1), - Layer3().isolate(ns2), - Layer4().isolate(ns2), - ) - ) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_balance.py b/test/distributed/pipeline/sync/test_balance.py deleted file mode 100644 index faf09f4581ae..000000000000 --- a/test/distributed/pipeline/sync/test_balance.py +++ /dev/null @@ -1,240 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import time - -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync._balance import ( - balance_by_size, - balance_by_time, - blockpartition, -) -from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -def test_blockpartition(): - assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [ - [1, 2, 3, 4], - [5, 6], - ] - - -def test_blockpartition_zeros(): - assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] - - -def test_blockpartition_non_positive_partitions(): - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=0) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=-1) - - -def test_blockpartition_short_sequence(): - with pytest.raises(ValueError): - blockpartition.solve([], partitions=1) - with pytest.raises(ValueError): - blockpartition.solve([42], partitions=2) - - -@pytest.mark.parametrize("device", devices) -@pytest.mark.skip(reason="Flaky due to time.sleep()") -def test_balance_by_time(device): - class Delay(nn.Module): - def __init__(self, seconds): - super().__init__() - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - return x - - model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) - sample = torch.rand(1) - balance = balance_by_time(2, model, sample, device=device) - assert balance == [4, 2] - - -def test_balance_by_time_loop_resets_input(): - # nn.Flatten was introduced at PyTorch 1.2.0. 
- class Flatten(nn.Module): - def forward(self, x): - return x.flatten(1) - - model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) - sample = torch.rand(10, 3, 8, 8) - balance = balance_by_time(2, model, sample, device="cpu") - assert balance == [1, 2] - - -@skip_if_no_cuda -def test_balance_by_size_latent(): - class Expand(nn.Module): - def __init__(self, times): - super().__init__() - self.times = times - - def forward(self, x): - for i in range(self.times): - x = x + torch.rand_like(x, requires_grad=True) - return x - - sample = torch.rand(10, 100, 100) - - model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) - balance = balance_by_size(2, model, sample) - assert balance == [4, 2] - - model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) - balance = balance_by_size(2, model, sample) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param(): - model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) - sample = torch.rand(7, 1) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) - sample = torch.rand(1, 7) - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [2, 4] - - -@skip_if_no_cuda -def test_balance_by_size_param_scale(): - class Tradeoff(nn.Module): - def __init__(self, param_size, latent_size): - super().__init__() - self.fc = nn.Linear(param_size, param_size) - self.latent_size = latent_size - - def forward(self, x): - for i in range(self.latent_size): - x = x + torch.rand_like(x, requires_grad=True) - return x - - model = nn.Sequential( - Tradeoff(param_size=1, latent_size=6), - Tradeoff(param_size=2, latent_size=5), - Tradeoff(param_size=3, latent_size=4), - Tradeoff(param_size=4, latent_size=3), - Tradeoff(param_size=5, latent_size=2), - Tradeoff(param_size=6, latent_size=1), - ) - - sample = torch.rand(1, requires_grad=True) - - balance = balance_by_size(2, model, sample, param_scale=0) - assert balance == [2, 4] - - balance = balance_by_size(2, model, sample, param_scale=100) - assert balance == [4, 2] - - -@pytest.mark.parametrize("device", devices) -def test_layerwise_sandbox(device): - model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - model.eval() - - for layer in layerwise_sandbox(model, torch.device(device)): - assert layer.training - assert all(p.device.type == device for p in layer.parameters()) - - assert all(not l.training for l in model) - assert all(p.device.type == "cpu" for p in model.parameters()) - - -@pytest.mark.parametrize("device", devices) -def test_sandbox_during_profiling(device): - model = nn.Sequential(nn.BatchNorm2d(3)) - - before = {k: v.clone() for k, v in model.state_dict().items()} - - sample = torch.rand(1, 3, 10, 10) - balance_by_time(1, model, sample, device=device) - - after = model.state_dict() - - assert before.keys() == after.keys() - for key, value in before.items(): - assert torch.allclose(after[key], value), key - - -def test_not_training(): - class AssertTraining(nn.Module): - def forward(self, x): - assert self.training - return x - - model = nn.Sequential(AssertTraining()) - - model.eval() - assert not model.training - - sample = torch.rand(1) - balance_by_time(1, model, sample, device="cpu") - - assert not model.training - - -def test_balance_by_time_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - 
- model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_time(1, model, sample, device="cpu") - - -@skip_if_no_cuda -def test_balance_by_size_tuple(): - class Twin(nn.Module): - def forward(self, x): - return x, x.detach() - - class Add(nn.Module): - def forward(self, a, b): - return a + b - - model = nn.Sequential(Twin(), Add()) - sample = torch.rand(1, requires_grad=True) - balance_by_size(1, model, sample) - - -def test_already_has_grad(): - model = nn.Sequential(nn.Conv2d(3, 3, 1)) - sample = torch.rand(1, 3, 32, 32) - model(sample).norm().backward() - - with pytest.raises(ValueError, match="some parameter already has gradient"): - balance_by_time(1, model, sample, device="cpu") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_bugs.py b/test/distributed/pipeline/sync/test_bugs.py deleted file mode 100644 index 928a78db6e32..000000000000 --- a/test/distributed/pipeline/sync/test_bugs.py +++ /dev/null @@ -1,146 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.nn.functional as F -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests - - -def test_python_autograd_function(setup_rpc): - # A Python autograd function might fail with this error: - # - # RuntimeError: Returning Variables sharing storage with other Variables - # that require grad is not supported in Python functions. Please submit a - # feature request if you hit this error. - # - # It doesn't look like an essential restriction. But it happens on the - # current PyTorch version. To avoid it, we should detach the tensor before - # returning by identity autograd functions, such as Wait, Fork, and Join. - # - class Identity(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): - return grad - - class M(nn.Module): - def forward(self, input): - return Identity.apply(input) - - model = nn.Sequential(M(), M()) - model = Pipe(model, checkpoint="always") - - x = torch.rand(42) - y = model(x) - assert torch.allclose(x, y.local_value()) - - -def test_exception_no_hang(setup_rpc): - # In v0.0.2, once a failed partition receives a normal message - # (non-closing) for the next micro-batch, a hang occurred. The reason was - # that a failed partition didn't call in_queue.task_done() on a normal - # message. So the former partition was blocked at out_queue.join() for the - # next of next micro-batch. - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep, setup_rpc): - # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. 
- # Under this behavior, if checkpointing was disabled, there's a possibility - # that gradient accumulations on other tensors are not synchronized - # properly to the copy stream. - class Sleep(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x.detach() - - @staticmethod - def backward(ctx, grad): - with torch.cuda.device(grad.device): - cuda_sleep(0.05) - return grad - - class Layer1(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b): - a = a * self.ones - return a * 1, b * 2, b * 3 - - class Layer2(nn.Module): - def __init__(self): - super().__init__() - self.ones = nn.Parameter(torch.ones(32, 3, 32, 32, requires_grad=True)) - - def forward(self, a, b, c): - a = a * self.ones - b = Sleep.apply(b) - return a + b + c - - model = nn.Sequential(Layer1().cuda(0), Layer2().cuda(1)) - model = Pipe(model, chunks=32, checkpoint="never") - - a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) - - y = model(a, b) - y.local_value().norm().backward() - - torch.cuda.synchronize(0) - torch.cuda.synchronize(1) - - assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) - - -def test_parallel_randoms(setup_rpc): - class Dropouts(nn.Module): - def forward(self, x): - for _ in range(100): - x = F.dropout(x, p=0.001) - return x - - model = nn.Sequential(Dropouts(), Dropouts()) - - x = torch.rand(10, 10, requires_grad=True) - model = Pipe(model, chunks=10, checkpoint="always") - y = model(x) - y = y.local_value() - y.norm().backward() - - assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_checkpoint.py b/test/distributed/pipeline/sync/test_checkpoint.py deleted file mode 100644 index 7be8ddefafe9..000000000000 --- a/test/distributed/pipeline/sync/test_checkpoint.py +++ /dev/null @@ -1,178 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from functools import partial - -import pytest - -import torch -import torch.cuda -from torch import nn - -from torch.distributed.pipeline.sync.checkpoint import ( - checkpoint, - Checkpointing, - is_checkpointing, - is_recomputing, -) -from torch.distributed.pipeline.sync.dependency import fork, join -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.testing._internal.common_utils import run_tests - -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") - - -@pytest.mark.parametrize("device", devices) -def test_serial_checkpoints(device): - # Copied from https://github.com/pytorch/pytorch/pull/18568. - timeline = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, name, x): - ctx.name = name - timeline.append(f"{name}:forward") - return x.detach() - - @staticmethod - def backward(ctx, grad_output): - name = ctx.name - timeline.append(f"{name}:backward") - return None, grad_output - - a = torch.rand(1, device=device, requires_grad=True) - b = torch.rand(1, device=device, requires_grad=True) - - # Increase the next function sequence number. 
- _ = a + 1 + 2 + 3 + 4 + 5 - - a = checkpoint(partial(Log.apply, "a"), a) - - a, phony = fork(a) - b = join(b, phony) - - b = checkpoint(partial(Log.apply, "b"), b) - - c = torch.cat((a, b)) - - out = c.sum() - - # +--> {a} --Checkpoint(Log)--> {a} - # {out} --Sum--> {c} --Cat ^-----------------------------+ - # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} - out.backward() - - assert timeline == [ - "a:forward", - "b:forward", - "b:forward", - "b:backward", - "a:forward", - "a:backward", - ] - # |----------------------| |-----------------------| |-----------------------| - # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) - - -def test_not_requires_grad(): - x = Batch(torch.rand(1, requires_grad=False)) - assert not x[0].requires_grad - - def f(x): - return x * 2 - - chk = Checkpointing(f, x) - x = chk.checkpoint() - assert x[0].requires_grad - - chk.recompute(x) - assert x[0].requires_grad - - x.tensor.backward() - - -def test_not_requires_grad_with_parameter(): - x = torch.rand(1, requires_grad=False) - a = torch.rand(1, requires_grad=True) - - def f(x): - return x * a - - y = checkpoint(f, x) - y.backward() - - assert a.grad is not None - - -@pytest.mark.parametrize("device", devices) -def test_random_in_checkpoint(device): - dropout = nn.Dropout(p=0.5) - - torch.manual_seed(0) - x = torch.randn(3, 3, device=device, requires_grad=True) - y = dropout(x) - y.norm().backward() - - torch.manual_seed(0) - chk_x = torch.randn(3, 3, device=device, requires_grad=True) - chk_y = checkpoint(dropout, chk_x) - chk_y.norm().backward() - - assert torch.allclose(x.grad, chk_x.grad) - - -def test_detect_checkpointing_recomputing(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output.backward() - - assert logs == [(True, False), (False, True)] - - -def test_detect_checkpointing_recomputing_without_checkpoint(): - logs = [] - - class Detect(nn.Module): - def forward(self, input): - logs.append((is_checkpointing(), is_recomputing())) - return input - - model = Detect() - input = torch.rand(1, requires_grad=True) - - output = model(input) - output.backward() - - assert logs == [(False, False)] - - -def test_non_grad_output(): - class ForkNonGrad(nn.Module): - def forward(self, input): - return (input * 2, torch.rand(1)) - - model = ForkNonGrad() - input = torch.rand(1, requires_grad=True) - - output = checkpoint(model, input) - output[0].backward() - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_copy.py b/test/distributed/pipeline/sync/test_copy.py deleted file mode 100644 index 302c3d25d53f..000000000000 --- a/test/distributed/pipeline/sync/test_copy.py +++ /dev/null @@ -1,85 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
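`test_checkpoint.py` above pins down the forward/recompute ordering of the pipeline's private `checkpoint` helper. The same free-activations-then-recompute idea is available through the public `torch.utils.checkpoint` API; a minimal sketch (the `use_reentrant=False` flag assumes a reasonably recent PyTorch):

import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(4, 4)
x = torch.rand(2, 4, requires_grad=True)

# Activations inside 'checkpoint' are not kept for backward; they are recomputed instead.
y = checkpoint(linear, x, use_reentrant=False)
y.sum().backward()
assert x.grad is not None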
-import pytest - -import torch - -from torch.distributed.pipeline.sync.copy import Copy, Wait -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - get_device, - is_cuda, - new_stream, - use_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): - device = get_device(prev_stream) - - with use_stream(prev_stream): - if is_cuda(prev_stream): - cuda_sleep(0.5) - x = torch.ones(100, device=device, requires_grad=True) - - (y,) = Copy.apply(prev_stream, next_stream, x) - (y,) = Wait.apply(prev_stream, next_stream, x) - - with use_stream(next_stream): - assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) - y.norm().backward() - with use_stream(prev_stream): - assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) - - -def test_copy_wait_cpu_cpu(): - prev_stream = CPUStream - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream) - - -@skip_if_no_cuda -def test_copy_wait_cpu_cuda(cuda_sleep): - prev_stream = CPUStream - next_stream = current_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cpu(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = CPUStream - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -@skip_if_no_cuda -def test_copy_wait_cuda_cuda(cuda_sleep): - prev_stream = current_stream(torch.device("cuda")) - next_stream = new_stream(torch.device("cuda")) - _test_copy_wait(prev_stream, next_stream, cuda_sleep) - - -def test_wait_multiple_tensors(): - a = torch.rand(1, requires_grad=True) - b = torch.rand(1, requires_grad=True) - - a, b = Wait.apply(CPUStream, CPUStream, a, b) - - assert a.grad_fn is b.grad_fn - assert a.grad_fn.__class__ is Wait._backward_cls - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_deferred_batch_norm.py b/test/distributed/pipeline/sync/test_deferred_batch_norm.py deleted file mode 100644 index c3807c57d612..000000000000 --- a/test/distributed/pipeline/sync/test_deferred_batch_norm.py +++ /dev/null @@ -1,200 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -from copy import deepcopy -from itertools import chain - -import pytest - -import torch -from torch import nn, optim - -from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm -from torch.testing._internal.common_utils import run_tests - -CHUNKS = 4 - - -def tilt_dist(input): - # Tilt variance by channel. - rgb = input.transpose(0, 1) - rgb[0] *= 1 - rgb[1] *= 10 - rgb[2] *= 100 - - # Tilt mean by single batch. 
- for i, single in enumerate(input): - single += 2**i - - return input - - -def chunked_forward(model, input, chunks=CHUNKS): - output_chunks = [] - - for chunk in input.chunk(chunks): - output_chunks.append(model(chunk)) - - return torch.cat(output_chunks) - - -@pytest.mark.parametrize("chunks", [1, 4]) -@pytest.mark.parametrize("input_requires_grad", [True, False]) -def test_transparency(chunks, input_requires_grad): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) - - input1 = torch.rand(16, 3, 224, 224) - input1 = tilt_dist(input1) - input2 = input1.clone() - input1.requires_grad = input_requires_grad - input2.requires_grad = input_requires_grad - - output1 = chunked_forward(bn, input1, chunks=chunks) - output2 = chunked_forward(dbn, input2, chunks=chunks) - - assert torch.allclose(output1, output2, atol=1e-4) - - output1.mean().backward() - output2.mean().backward() - - assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) - - if input_requires_grad: - assert input1.grad is not None - assert input2.grad is not None - assert torch.allclose(input1.grad, input2.grad, atol=1e-4) - - -@pytest.mark.parametrize("momentum", [0.1, None]) -def test_running_stats(momentum): - bn = nn.BatchNorm2d(3, momentum=momentum) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) - assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) - - -def test_convert_deferred_batch_norm(): - bn = nn.BatchNorm2d(3, track_running_stats=False) - bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) - assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False - - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) - assert dbn is dbn_again - - dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) - assert dbn is not dbn_again # because of different chunks - - -def test_eval(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - bn(input) - chunked_forward(dbn, input) - - bn.eval() - dbn.eval() - - assert torch.allclose(bn(input), dbn(input), atol=1e-4) - - -def test_optimize(): - bn = nn.BatchNorm2d(3) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) - - for i in range(5): - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - # train - y = bn(input) - a = y.sum() - a.backward() - - y = chunked_forward(dbn, input) - b = y.sum() - b.backward() - - opt.step() - - # eval - bn.eval() - dbn.eval() - - with torch.no_grad(): - assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i)) - - -def test_conv_bn(): - bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) - dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - - opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) - - # 1st step - a = bn(input) - b = chunked_forward(dbn, input) - - # Outputs are different. (per-mini-batch vs. 
per-micro-batch) - assert not torch.allclose(a, b) - - a.sum().backward() - b.sum().backward() - opt.step() - opt.zero_grad() - - # Conv layers are also trained differently because of their different outputs. - assert not torch.allclose(bn[0].weight, dbn[0].weight) - - # But BNs track identical running stats. - assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - # 2nd step - a = bn(input) - b = chunked_forward(dbn, input) - a.sum().backward() - b.sum().backward() - - # BNs can't track identical running stats due to the different conv layers. - assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) - assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) - - -def test_input_requiring_grad(): - dbn = DeferredBatchNorm(3, chunks=CHUNKS) - - input = torch.rand(16, 3, 224, 224) - input = tilt_dist(input) - input.requires_grad = True - - chunked_forward(dbn, input) - - assert not dbn.sum.requires_grad - assert dbn.sum.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_dependency.py b/test/distributed/pipeline/sync/test_dependency.py deleted file mode 100644 index e966d6541bf5..000000000000 --- a/test/distributed/pipeline/sync/test_dependency.py +++ /dev/null @@ -1,152 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import weakref - -import pytest - -import torch - -from torch.distributed.pipeline.sync.dependency import Fork, fork, Join, join -from torch.testing._internal.common_utils import run_tests - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def test_fork_join(): - logs = [] - - class Log(torch.autograd.Function): - @staticmethod - def forward(ctx, number, tensor): - ctx.number = number - return tensor.detach() - - @staticmethod - def backward(ctx, grad): - logs.append(ctx.number) - return None, grad - - a = torch.rand(1, device="cpu", requires_grad=True) - b = torch.rand(1, device="cuda", requires_grad=True) - - a = Log.apply(1, a) - - a, phony = fork(a) - b = join(a, phony) - - b = Log.apply(2, b) - b = b.to("cpu") - - (a + b).backward() - - assert logs == [2, 1] - - -def test_fork_join_enable_grad(): - x = torch.rand(1, requires_grad=True) - - with torch.enable_grad(): - x2, p = fork(x) - - assert p.requires_grad - assert x2 is not x - x = x2 - - assert x.requires_grad - assert p.requires_grad - assert x.grad_fn.__class__ is Fork._backward_cls - assert p.grad_fn.__class__ is Fork._backward_cls - - with torch.enable_grad(): - x2 = join(x, p) - - assert x2 is not x - x = x2 - - assert x.requires_grad - assert x.grad_fn.__class__ is Join._backward_cls - - -def test_fork_join_no_grad(monkeypatch): - def do_not_apply(*args): - raise AssertionError("Function.apply called") - - monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) - - x = torch.rand(1, requires_grad=True) - - with torch.no_grad(): - x2, p = fork(x) - - assert not p.requires_grad - assert x2 is x - x = x2 - - with torch.no_grad(): - x2 = join(x, p) - - assert x2 is x - x = x2 - - -def test_fork_leak(): - leak = None - - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - return input - - @staticmethod - def backward(ctx, grad): 
- nonlocal leak - leak = weakref.ref(ctx) - return grad - - x = torch.rand(1, requires_grad=True) - x = F.apply(x) - x, phony = fork(x) - x = join(x, phony) - - x.backward() - del x, phony - - assert leak() is None - - -def test_join_when_fork_not_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - assert not a.requires_grad - a, p = fork(a) - assert not a.requires_grad - assert not p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert not b.requires_grad - - -def test_join_when_fork_requires_grad(): - x = torch.rand(2, 1) - a, b = x.chunk(2) - - a.requires_grad_() - assert a.requires_grad - a, p = fork(a) - assert a.requires_grad - assert p.requires_grad - - assert not b.requires_grad - b = join(b, p) - assert b.requires_grad - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_inplace.py b/test/distributed/pipeline/sync/test_inplace.py deleted file mode 100644 index 33f31b2a52bb..000000000000 --- a/test/distributed/pipeline/sync/test_inplace.py +++ /dev/null @@ -1,79 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_inplace_on_requires_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(setup_rpc): - # In-place operation on a tensor not requiring grad doesn't cause a - # RuntimeError. Currently, we cannot detect this case. - model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - x = torch.rand(1) - y = model(x).local_value() - del model - - message = r"a leaf Variable that requires grad .* used in an in-place operation." - with pytest.raises(RuntimeError, match=message): - y.backward() - - -@pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(setup_rpc): - class M(nn.Module): - def forward(self, foo_bar): - # 'foo' requires grad but 'bar' does not. In-place operation on - # 'bar' won't cause a RuntimeError. - foo, bar = foo_bar - - # add_(1) is not idempotent, in contrast to relu_(). If it is - # executed multiple times, it will accumulates each difference onto - # 'bar'. - bar.add_(1) - - # 'bar' is still captured by checkpointing. 'foo' will get - # incorrect grad. - return foo * bar - - model = nn.Sequential(M()) - model = Pipe(model, [1], devices=["cpu"], checkpoint="always") - - foo = torch.tensor([1.0], requires_grad=True) - bar = torch.tensor([1.0]) - - output = model((foo, bar)).local_value() - del model - output.backward() - - # The gradient of 'foo' should be 2, but it is 3 actually because - # bar.add_(1) was executed twice due to checkpointing. 
- assert foo.grad.item() == 2.0 - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_microbatch.py b/test/distributed/pipeline/sync/test_microbatch.py deleted file mode 100644 index b5e44aa73a8d..000000000000 --- a/test/distributed/pipeline/sync/test_microbatch.py +++ /dev/null @@ -1,148 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch -import torch.cuda - -from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter -from torch.testing._internal.common_utils import run_tests - - -def test_batch_atomic(): - x = torch.tensor(42) - b = Batch(x) - - assert b.atomic - - assert b.tensor is x - with pytest.raises(AttributeError): - b.tensors - - assert list(b) == [x] - assert len(b) == 1 - assert b[0] is x - - -def test_batch_non_atomic(): - x, y = torch.tensor(42), torch.tensor(21) - b = Batch((x, y)) - - assert not b.atomic - - with pytest.raises(AttributeError): - b.tensor - - assert list(b) == [x, y] - assert len(b) == 2 - assert b[0] is x - assert b[1] is y - - -def test_batch_call(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - def f(x): - return x - - def g(x, y): - return x, y - - assert a.call(f).atomic - assert not b.call(g).atomic - - -def test_batch_setitem_by_index(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[0] = torch.tensor(0) - b[0] = torch.tensor(0) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 2 - assert b[0].item() == 0 - assert b[1].item() == 21 - - -def test_batch_setitem_by_slice(): - a = Batch(torch.tensor(42)) - b = Batch((torch.tensor(42), torch.tensor(21))) - - a[:] = (torch.tensor(0),) - b[:] = (torch.tensor(0),) - - assert a.atomic - assert a[0].item() == 0 - - assert not b.atomic - assert len(b) == 1 - assert b[0].item() == 0 - - -def test_check(): - check(torch.device("cpu"), torch.tensor(42)) - check(torch.device("cpu"), torch.tensor(4), torch.tensor(2)) - - with pytest.raises(TypeError): - check(torch.device("cpu"), 42) - - with pytest.raises(TypeError): - check(torch.device("cpu"), "str") - - with pytest.raises(TypeError): - check(torch.device("cpu"), (torch.tensor(4), 2)) - - -def test_gather_tensors(): - a = torch.zeros(1, 1) - b = torch.zeros(1, 1) - - ab = gather([Batch(a), Batch(b)]) - - assert ab.size() == (2, 1) - - -def test_gather_tuples(): - a = (torch.zeros(1, 1), torch.zeros(2, 2)) - b = (torch.zeros(1, 1), torch.zeros(2, 2)) - - ab = gather([Batch(a), Batch(b)]) - - assert isinstance(ab, tuple) - assert ab[0].size() == (2, 1) - assert ab[1].size() == (4, 2) - - -def test_scatter_tensor(): - ab = torch.zeros(2, 1) - - a, b = scatter(ab, chunks=2) - - assert a.tensor.size() == (1, 1) - assert b.tensor.size() == (1, 1) - - -def test_scatter_multiple_tensors(): - ab = (torch.zeros(2, 1), torch.zeros(4, 2)) - - a, b = scatter(*ab, chunks=2) - - assert next(iter(a)).size() == (1, 1) - assert next(iter(b)).size() == (1, 1) - assert list(a)[1].size() == (2, 2) - assert list(b)[1].size() == (2, 2) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_phony.py b/test/distributed/pipeline/sync/test_phony.py deleted file mode 100644 index 6aeb873b30b2..000000000000 --- 
a/test/distributed/pipeline/sync/test_phony.py +++ /dev/null @@ -1,57 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from torch.distributed.pipeline.sync.phony import get_phony -from torch.testing._internal.common_utils import run_tests - - -def test_phony_size(): - p = get_phony(torch.device("cpu"), requires_grad=False) - assert p.size() == (0,) - - -def test_phony_requires_grad(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=False) - assert p1.requires_grad - assert not p2.requires_grad - - -def test_cached_phony(): - p1 = get_phony(torch.device("cpu"), requires_grad=True) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - assert p1 is p2 - - p3 = get_phony(torch.device("cpu"), requires_grad=False) - p4 = get_phony(torch.device("cpu"), requires_grad=False) - assert p3 is p4 - - assert p1 is not p3 - - -def test_phony_in_autograd_function(): - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() - - x = torch.rand(1, requires_grad=True) - - p1 = Phonify.apply(x) - p2 = get_phony(torch.device("cpu"), requires_grad=True) - - assert p1 is not p2 - assert p1.grad_fn is not None - assert p2.grad_fn is None - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py deleted file mode 100644 index e493b1d5a03e..000000000000 --- a/test/distributed/pipeline/sync/test_pipe.py +++ /dev/null @@ -1,858 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
-import random -import time -from collections import OrderedDict -from copy import deepcopy - -import pytest - -import torch -from torch import nn, Tensor - -from torch.distributed.pipeline.sync import NoChunk, Pipe, WithDevice -from torch.distributed.pipeline.sync.pipe import PipeSequential -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import run_tests, TEST_CUDA - -skip_if_no_cuda = pytest.mark.skipif(not TEST_CUDA, reason="cuda required") - - -def test_pipe_without_rpc(): - model = nn.Sequential(nn.Linear(1, 1)) - with pytest.raises(RuntimeError, match="Please initialize RPC framework"): - pipe = Pipe(model, chunks=1) - - -def test_parameters(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=1) - assert list(pipe.parameters()) != [] - - -def test_public_attrs(setup_rpc): - class MyString: - def __init__(self, value): - self.value = value - - def __str__(self): - return self.value - - model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always")) - - assert pipe.devices == [torch.device("cpu")] - assert pipe.chunks == 42 - assert isinstance(pipe.chunks, int) - assert pipe.checkpoint == "always" - assert isinstance(pipe.checkpoint, str) - - -def test_sequential_like(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert len(model) == 2 - assert list(model) == [a, b] - - assert model[0] is a - assert model[1] is b - with pytest.raises(IndexError): - _ = model[2] - - assert model[-1] is b - assert model[-2] is a - - -def test_chunks_less_than_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises(ValueError): - Pipe(model, chunks=0) - - with pytest.raises(ValueError): - Pipe(model, chunks=-1) - - -def test_batch_size_indivisible(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(7, 1)) - - # Indivisible batch size is legal. - assert not record - - -def test_batch_size_small(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=4) - - with pytest.warns(None) as record: - model(torch.rand(2, 1)) - - # Batch size smaller than chunks is legal. 
- assert not record - - -def test_checkpoint_mode(setup_rpc): - def count_grad_fn(grad_fn, name, visited=None): - if visited is None: - visited = set() - if grad_fn in visited: - return 0 - visited.add(grad_fn) - - if grad_fn is None: - return 0 - if grad_fn.__class__.__name__ == name: - return 1 - - counter = 0 - for next_grad_fn, _ in grad_fn.next_functions: - counter += count_grad_fn(next_grad_fn, name, visited=visited) - return counter - - model = nn.Sequential(nn.Linear(1, 1)) - input = torch.rand(2, 1) - - always = Pipe(model, chunks=2, checkpoint="always") - except_last = Pipe(model, chunks=2, checkpoint="except_last") - never = Pipe(model, chunks=2, checkpoint="never") - - always_output = always(input) - except_last_output = except_last(input) - never_output = never(input) - - assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 - assert ( - count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") - == 1 - ) - assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 - - -def test_checkpoint_mode_invalid(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - with pytest.raises( - ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'" - ): - Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT") - - -def test_checkpoint_mode_when_chunks_1(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - - # All checkpoint modes are fine. - Pipe(model, chunks=1, checkpoint="except_last") - Pipe(model, chunks=1, checkpoint="always") - Pipe(model, chunks=1, checkpoint="never") - - -def test_checkpoint_eval(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - def find_grad_fn(grad_fn, name): - if grad_fn is None: - return False - if grad_fn.__class__.__name__ == name: - return True - for next_grad_fn, _ in grad_fn.next_functions: - if find_grad_fn(next_grad_fn, name): - return True - return False - - model.train() - train_output = model(input) - assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") - - model.eval() - eval_output = model(input) - assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") - - -def test_checkpoint_non_float_input(setup_rpc): - class ForkNonFloat(nn.Module): - def forward(self, input): - return (input * 2, torch.tensor([False])) - - class JoinNonFloat(nn.Module): - def forward(self, input, non_float): - return input * 2 - - model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) - model = Pipe(model, chunks=1, checkpoint="always") - - input = torch.rand(1, requires_grad=True) - output = model(input) - output.backward() - - -def test_no_grad(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, chunks=2) - input = torch.rand(2, 1) - - latent = None - - def hook(module, input, output): - _ = module - _ = input - - nonlocal latent - latent = output - - partition = model.partitions[0] - partition.register_forward_hook(hook) - - with torch.no_grad(): - model(input) - - assert latent.grad_fn is None - - -def test_exception(setup_rpc): - class ExpectedException(Exception): - pass - - class Raise(nn.Module): - def forward(self, *_): - raise ExpectedException - - model = nn.Sequential(Raise()) - model = Pipe(model, chunks=1) - - with pytest.raises(ExpectedException): - 
model(torch.rand(1)) - - -def test_exception_early_stop_asap(setup_rpc): - """Even the first partitions have finished to process, the partition before - the failed partition should be killed as soon as possible. - """ - - class ExpectedException(Exception): - pass - - class Pass(nn.Module): - def forward(self, x): - return x - - counter = 0 - - class Counter(nn.Module): - def forward(self, x): - time.sleep(0.1) - - nonlocal counter - counter += 1 - - return x - - class Raise(nn.Module): - def forward(self, x): - raise ExpectedException - - model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, chunks=3) - - with pytest.raises(ExpectedException): - model(torch.rand(3)) - - # If the early stop doesn't work, it would be 3 instead. - assert counter == 2 - - -def test_nested_input(setup_rpc): - class NestedInput(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, inp): - return inp - - model = nn.Sequential(NestedInput()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - # TypeError: expected Tensor, but got tuple - with pytest.raises(TypeError): - model((a, (a, b))).local_value() - - # TypeError: expected Tensor, but got list - with pytest.raises(TypeError): - model((a, [a, b])).local_value() - - -def test_input_pair(setup_rpc): - class Two(nn.Module): - def __init__(self): - super().__init__() - self.fc_a = nn.Linear(1, 1) - self.fc_b = nn.Linear(1, 1) - - def forward(self, a, b): - return (self.fc_a(a), self.fc_b(b)) - - model = nn.Sequential(Two()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - b = torch.rand(10, 1, requires_grad=True) - - a_out, b_out = model(a, b).local_value() - loss = (a_out + b_out).mean() - loss.backward() - - assert a.grad is not None - assert b.grad is not None - - -def test_multi_sequence_input(setup_rpc): - class MultiSeq(nn.Module): - def forward(self, tup1, tup2): - return tup1, tup2 - - model = Pipe(nn.Sequential(MultiSeq())) - with pytest.raises(TypeError): - model([torch.rand(10), torch.rand(10)], [torch.rand(10), torch.rand(10)]) - - -def test_input_singleton(setup_rpc): - class One(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(1, 1) - - def forward(self, a): - return (self.fc(a),) - - model = nn.Sequential(One()) - model = Pipe(model, chunks=2) - - a = torch.rand(10, 1, requires_grad=True) - - (a_out,) = model(a).local_value() - loss = a_out.mean() - loss.backward() - - assert all(p.grad is not None for p in model.parameters()) - assert a.grad is not None - - -def test_input_varargs(setup_rpc): - model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model) - - a = torch.rand(1) - b = torch.rand(1) - - # TypeError: forward() takes 2 positional arguments but 3 were given - with pytest.raises(TypeError): - model(a, b) - - -def test_non_tensor(setup_rpc): - class NonTensor(nn.Module): - def forward(self, _): - return "hello" - - model = nn.Sequential(NonTensor()) - model = Pipe(model) - x = torch.rand(1) - - with pytest.raises(TypeError): - model(x) - - with pytest.raises(TypeError): - model("hello") - - -def test_non_tensor_sequence(setup_rpc): - class NonTensorTuple(nn.Module): - def forward(self, x): - return (x, "hello") - - class NonTensorArgs(nn.Module): - def forward(self, x: str, y: bool): - return x, y - - model = nn.Sequential(NonTensorTuple()) - model = Pipe(model) - x = torch.rand(1) - - with 
pytest.raises(TypeError): - model((x, "hello")) - - with pytest.raises(TypeError): - model([x, "hello"]) - - model = nn.Sequential(NonTensorArgs()) - model = Pipe(model) - - with pytest.raises(TypeError): - # Need atleast one Tensor. - model("hello", True) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_valid_non_tensor(checkpoint, setup_rpc): - class NonTensor1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool, d: Tensor): - res = b + a if c else b * a - if d is not None: - res += d - return res, c, a, b, "hello", d - - class NonTensor2(nn.Module): - def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor): - res = a * c if b else a + c - res += d - return c, res, a, d + f if f is not None else d, b, e, f - - model = Pipe( - nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint - ) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - d = torch.rand(10, 10) - res = model(a, b, c, d).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a + d) * a) + b, res[1]) - assert torch.allclose(b + a + d, res[2]) - else: - assert torch.allclose(((b * a) + d + a) + b, res[1]) - assert torch.allclose(b * a + d, res[2]) - assert torch.allclose(b + d, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert torch.allclose(d, res[6]) - - # Test one of the tensors can be None - res = model(a, b, c, None).local_value() - assert 7 == len(res) - assert [a] * 5 == res[0] - if c: - assert torch.allclose(((b + a) * a) + b, res[1]) - assert torch.allclose(b + a, res[2]) - else: - assert torch.allclose(((b * a) + a) + b, res[1]) - assert torch.allclose(b * a, res[2]) - assert torch.allclose(b, res[3]) - assert [c] * 5 == res[4] - assert ["hello"] * 5 == res[5] - assert [None] * 5 == res[6] - - # Need atleast one tensor. - with pytest.raises(TypeError): - model(a, None, c, None) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_tensor_output(checkpoint, setup_rpc): - class Model1(nn.Module): - def forward(self, a: int, b: Tensor, c: bool): - return a, c, "hello" - - class Model2(nn.Module): - def forward(self, a: int, b: bool, c: str): - return a, c, b - - model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5) - a = random.randint(0, 10) - b = torch.rand(10, 10) - c = random.randint(0, 1) == 0 - - # Need atleast one tensor across partitions too. - with pytest.raises(TypeError): - res = model(a, b, c).local_value() - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_uneven_batch_size(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(6, 10) - res = model(a, b, c).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 3 == res[1] # 3 chunks - assert torch.allclose(c, res[2]) - - # Two tensors producing uneven chunks would fail. 
- model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(3, 10) - b = random.randint(0, 10) - c = torch.rand(4, 10) - - with pytest.raises(RuntimeError, match="Found different number of chunks"): - model(a, b, c) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_no_chunk(checkpoint, setup_rpc): - class Model(nn.Module): - def forward(self, a: Tensor, b: int, c: Tensor): - return a, b, c - - model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5) - a = torch.rand(10, 10) - b = random.randint(0, 10) - c = torch.rand(10, 10) - res = model(a, b, NoChunk(c)).local_value() - assert torch.allclose(a, res[0]) - assert [b] * 5 == res[1] - # c gets replicated due to NoChunk and the same tensor gets concatenated 5 - # times in the output. - assert torch.allclose(torch.cat((c, c, c, c, c)), res[2]) - - # Test invalid type for NoChunk - with pytest.raises(TypeError, match="NoChunk only supported for tensors"): - NoChunk(b) - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=2, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) - assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) - - -@pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint, setup_rpc): - bn = nn.BatchNorm2d(3) - pipe_bn = deepcopy(bn) - pipe = Pipe( - nn.Sequential(pipe_bn), - chunks=1, - checkpoint=checkpoint, - deferred_batch_norm=True, - ) - - x = torch.rand(4, 3, 10, 10) - pipe(x).local_value().mean().backward() - bn(x).mean().backward() - - assert pipe[0].weight.grad is not None - assert pipe[0].bias.grad is not None - - assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) - assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) - - -def test_devices(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - c = nn.Linear(1, 1) - - # There are extra two devices. - model = nn.Sequential(a, b, c) - model = Pipe(model) - - cpu = torch.device("cpu") - # Extra devices must be discarded. - assert model.devices == [cpu, cpu, cpu] - - -def test_partitions(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], nn.Sequential) - assert isinstance(model.partitions[1], nn.Sequential) - - assert "partitions.0.0.weight" in model.state_dict() - - -@skip_if_no_cuda -def test_merged_partitions(setup_rpc): - a = nn.Linear(1, 1).to(0) - b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0) - c = nn.Linear(1, 1) - d = nn.Linear(1, 2) - - model = nn.Sequential(a, b, c, d) - model = Pipe(model) - - assert isinstance(model.partitions, nn.ModuleList) - assert isinstance(model.partitions[0], PipeSequential) - assert isinstance(model.partitions[1], PipeSequential) - assert list(model.partitions[0]) == [a, b[0], b[1]] - assert list(model.partitions[1]) == [c] - assert list(model.partitions[2]) == [d] - - -def test_deny_moving(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(a, b) - model = Pipe(model) - - # Moving is denied. 
- with pytest.raises(TypeError): - model.cuda() - - with pytest.raises(TypeError): - model.cpu() - - with pytest.raises(TypeError): - model.to(torch.device("cuda")) - - with pytest.raises(TypeError): - model.to(0) - - with pytest.raises(TypeError): - model.to("cuda") - - with pytest.raises(TypeError): - model.to(device=0) - - with pytest.raises(TypeError): - model.to(torch.rand(1)) - - with pytest.raises(TypeError): - model.to(tensor=torch.rand(1)) - - # Casting is allowed. - model.half() - model.to(torch.double) - model.to(dtype=torch.float) - - -def test_empty_module(setup_rpc): - # Empty sequential module is not illegal. - model = nn.Sequential() - model = Pipe(model) - - assert model(torch.tensor(42)).local_value() == torch.tensor(42) - - # But only tensor or tensors is legal in Pipe. - with pytest.raises(TypeError): - model(42) - - -def test_named_children(setup_rpc): - a = nn.Linear(1, 1) - b = nn.Linear(1, 1) - - model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model) - - names = {n for n, _ in model.named_modules()} - assert "partitions.0.0" in names - assert "partitions.1.0" in names - - # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires - # several methods in its namespace. - with pytest.raises(AttributeError): - model.a - - -def test_verify_module_non_sequential(setup_rpc): - with pytest.raises( - TypeError, match="module must be nn.Sequential to be partitioned" - ): - Pipe(nn.Module()) - - -def test_verify_module_duplicate_children(setup_rpc): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises( - ValueError, match="module with duplicate children is not supported" - ): - Pipe(model) - - -@skip_if_no_cuda -def test_verify_module_params_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, param1, param2): - super().__init__() - self.param1 = param1 - self.param2 = param2 - - conv1 = nn.Conv2d(3, 3, 1) - conv2 = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv1, conv2.cuda())) - - with pytest.raises( - ValueError, - match=r"should have all parameters on a single device, please use .to\(\)" - " to place the module on a single device", - ): - Pipe(model) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_verify_nested_modules(setup_rpc): - model = nn.Sequential( - nn.Sequential(nn.Linear(32, 16).cuda(0), nn.Linear(16, 8).cuda(0)), - nn.Sequential(nn.Linear(8, 4).cuda(1), nn.Linear(4, 2).cuda(1)), - ) - - pipe = Pipe(model) - out = pipe(torch.rand(10, 32).cuda(0)) - assert out.local_value().device == torch.device("cuda:1") - assert out.local_value().size() == torch.Size([10, 2]) - - -def test_verify_module_duplicate_parameters_on_same_device(setup_rpc): - class Surrogate(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(Surrogate(conv), Surrogate(conv)) - - Pipe(model) - - -def test_forward_lockstep(setup_rpc): - timeline = [] - - class DelayedLog(nn.Module): - def __init__(self, j, seconds): - super().__init__() - self.i = 0 - self.j = j - self.seconds = seconds - - def forward(self, x): - time.sleep(self.seconds) - - timeline.append((self.i, self.j)) - self.i += 1 - - return x - - model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) - model = Pipe(model, chunks=3) - model(torch.rand(3, 1)) - - # Expected timeline: (Logs are recorded at !) - # - # Partition #0: 0! 1! 2! - # Partition #1: 000! 111! 222! 
- # - assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] - - -@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -@skip_if_no_cuda -def test_multiple_inputs(checkpoint, setup_rpc): - class Module1(nn.Module): - def forward(self, a, b, c): - return a + b + c, a * b * c - - class Module2(nn.Module): - def forward(self, a, b): - return a + b - - model = Pipe( - nn.Sequential(Module1().cuda(0), Module2().cuda(0)), - chunks=2, - checkpoint=checkpoint, - ) - t = torch.rand(10) - res = model(t, t, t).local_value() - assert torch.equal(res, (t + t + t) + (t * t * t)) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_inputs_wrong_device(setup_rpc): - class Module1(nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(5)) - - def forward(self, a, b): - return a + b + self.param, b - - # Start inputs on wrong device and ensure Pipe moves them correctly. - a = torch.rand(10).cuda(1) - b = torch.rand(10).cuda(1) - model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2) - with pytest.raises( - ValueError, - match="All inputs should be on the same device as the first partition", - ): - model(a, b) - - -@pytest.mark.skipif(not TEST_MULTIGPU, reason="Need atleast two GPUs") -def test_with_device_wrapper(setup_rpc): - fc1 = nn.Linear(16, 8).cuda(0) - fc2 = nn.Linear(8, 4).cuda(1) - dropout = nn.Dropout() - - model = nn.Sequential(fc1, fc2, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(dropout, "cuda:1")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:1") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0"), torch.device("cuda:1")] == model.devices - - model = nn.Sequential(fc1, WithDevice(fc2, "cuda:0")) - model = Pipe(model, chunks=8) - assert ( - torch.device("cuda:0") == model(torch.rand(16, 16).cuda(0)).local_value().device - ) - assert [torch.device("cuda:0")] == model.devices - assert torch.device("cuda:0") == fc2.weight.device - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_pipeline.py b/test/distributed/pipeline/sync/test_pipeline.py deleted file mode 100644 index 9548cb959db1..000000000000 --- a/test/distributed/pipeline/sync/test_pipeline.py +++ /dev/null @@ -1,36 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
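The lockstep timeline asserted above is the usual GPipe wavefront: at clock tick k, partition j processes microbatch k - j, one anti-diagonal of the (microbatch, partition) grid per tick. A minimal standalone sketch of such a schedule generator (illustrative only, not necessarily the package's implementation, though it reproduces the _clock_cycles expectations in the deleted test_pipeline.py below) could be:

def clock_cycles(m: int, n: int):
    # m: number of microbatches, n: number of partitions.
    # Tick k schedules (microbatch k - j, partition j) for every valid j,
    # i.e. one anti-diagonal of the m-by-n grid per clock tick.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(0, k - m + 1), min(1 + k, n))]

Flattening clock_cycles(3, 2) gives (0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1), exactly the timeline asserted above, and clock_cycles(3, 3) matches the 3x3 expectation checked in the test below.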
-from torch.distributed.pipeline.sync.pipeline import _clock_cycles -from torch.testing._internal.common_utils import run_tests - - -def test_clock_cycles(): - assert list(_clock_cycles(1, 1)) == [[(0, 0)]] - assert list(_clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] - assert list(_clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] - - assert list(_clock_cycles(3, 3)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1), (0, 2)], - [(2, 1), (1, 2)], - [(2, 2)], - ] - - assert list(_clock_cycles(4, 2)) == [ - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1)], - [(3, 0), (2, 1)], - [(3, 1)], - ] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_stream.py b/test/distributed/pipeline/sync/test_stream.py deleted file mode 100644 index f9702c8e4152..000000000000 --- a/test/distributed/pipeline/sync/test_stream.py +++ /dev/null @@ -1,198 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -import torch - -from torch.distributed.pipeline.sync.stream import ( - CPUStream, - current_stream, - default_stream, - get_device, - is_cuda, - new_stream, - record_stream, - use_device, - use_stream, - wait_stream, -) -from torch.testing._internal.common_utils import run_tests - -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="cuda required" -) - - -class TestNewStream: - def test_new_stream_cpu(self): - stream = new_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_new_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream != torch.cuda.default_stream() - - -class TestCurrentStream: - def test_current_stream_cpu(self): - stream = current_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_current_stream_cuda(self): - stream = current_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.current_stream() - - -class TestDefaultStream: - def test_default_stream_cpu(self): - stream = default_stream(torch.device("cpu")) - assert stream is CPUStream - - @skip_if_no_cuda - def test_default_stream_cuda(self): - stream = default_stream(torch.device("cuda")) - assert isinstance(stream, torch.cuda.Stream) - assert stream == torch.cuda.default_stream() - - -class TestUseDevice: - def test_use_device_cpu(self): - with use_device(torch.device("cpu")): - pass - - @skip_if_no_cuda - def test_use_device_cuda(self): - with use_device(torch.device("cuda")): - pass - - -class TestUseStream: - def test_use_stream_cpu(self): - with use_stream(CPUStream): - pass - - @skip_if_no_cuda - def test_use_stream_cuda(self): - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - assert current_stream(torch.device("cuda")) == stream - - -class TestGetDevice: - def test_get_device_cpu(self): - assert get_device(CPUStream).type == "cpu" - - @skip_if_no_cuda - def test_get_device_cuda(self): - stream = current_stream(torch.device("cuda")) - assert get_device(stream).type == "cuda" - - -class TestWaitStream: - def _test_wait_stream(self, source, target, cuda_sleep=None): - with use_stream(target): - if is_cuda(target): - cuda_sleep(0.5) - x = torch.ones(100, 100, device=get_device(target)) - - 
wait_stream(source, target) - - with use_stream(source): - assert x.sum().item() == 10000 - - def test_wait_stream_cpu_cpu(self): - source = CPUStream - target = CPUStream - self._test_wait_stream(source, target) - - @skip_if_no_cuda - def test_wait_stream_cpu_cuda(self, cuda_sleep): - source = CPUStream - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cpu(self, cuda_sleep): - source = new_stream(torch.device("cuda")) - target = CPUStream - self._test_wait_stream(source, target, cuda_sleep) - - @skip_if_no_cuda - def test_wait_stream_cuda_cuda(self, cuda_sleep): - source = current_stream(torch.device("cuda")) - target = new_stream(torch.device("cuda")) - self._test_wait_stream(source, target, cuda_sleep) - - -class TestRecordStream: - def test_record_stream_cpu(self): - # It should silently ignore CPU tensors. - x = torch.rand(1, device=torch.device("cpu")) - record_stream(x, CPUStream) - - @skip_if_no_cuda - def test_record_stream_cuda(self, cuda_sleep): - # This test detects unexpected block reallocation. For reliable test, - # the stream to allocate tensors is isolated. The allocator will not - # reuse free blocks which were allocated from another stream. - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(1, device=torch.device("cuda")) - - stream = new_stream(torch.device("cuda")) - record_stream(x, stream) - with use_stream(stream): - cuda_sleep(0.5) - - # 'x' is deleted at Python's perspective. But the block of 'x' is still - # required for 'stream'. 'y' shouldn't be allocated to the block. - data_ptr = x.data_ptr() - del x - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - y = torch.rand(1, device=torch.device("cuda")) - assert y.data_ptr() != data_ptr - - # Pause Python until 'stream' finishes tasks queued. Now the block of - # 'x' is free to be reallocated. - wait_stream(CPUStream, stream) - with torch.cuda.stream(stream_alloc): - z = torch.rand(1, device=torch.device("cuda")) - assert z.data_ptr() == data_ptr - - @skip_if_no_cuda - def test_record_stream_shifted_view(self, cuda_sleep): - # Issue: https://github.com/pytorch/pytorch/issues/27366 - stream_alloc = new_stream(torch.device("cuda")) - with torch.cuda.stream(stream_alloc): - x = torch.rand(2, device=torch.device("cuda")) - - y = x[1:] - assert y.data_ptr() > x.data_ptr() - - stream = new_stream(torch.device("cuda")) - with use_stream(stream): - cuda_sleep(0.5) - record_stream(y, stream) - - data_ptr = x.data_ptr() - del x, y - - stream_alloc.synchronize() - with torch.cuda.stream(stream_alloc): - z = torch.rand(2, device=torch.device("cuda")) - assert z.data_ptr() != data_ptr - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_transparency.py b/test/distributed/pipeline/sync/test_transparency.py deleted file mode 100644 index a87a04150fdc..000000000000 --- a/test/distributed/pipeline/sync/test_transparency.py +++ /dev/null @@ -1,55 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
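The record_stream tests above rely on a caching-allocator detail: a CUDA tensor's memory block belongs to the stream it was allocated on, so if another stream is still consuming it when the last Python reference dies, the block can be reallocated too early unless that extra stream is recorded on the tensor. A minimal sketch of the usual pattern (assumes a CUDA device; sizes and names are arbitrary):

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    x = torch.empty(1 << 20, device="cuda")    # allocated on the current (default) stream
    with torch.cuda.stream(copy_stream):
        y = x * 2                              # x is consumed on copy_stream
    x.record_stream(copy_stream)               # keep x's block reserved until copy_stream is done with it
    torch.cuda.current_stream().wait_stream(copy_stream)  # y is now safe to use on the default stream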
-import torch -from torch import nn - -from torch.distributed.pipeline.sync import Pipe -from torch.testing._internal.common_utils import run_tests - - -def test_simple_linears(setup_rpc): - def sum_grad(parameters): - return sum(p.grad.sum() for p in parameters if p.grad is not None) - - def zero_grad(parameters): - for p in parameters: - p.grad = None - - inputs = torch.rand(8, 1) - model = nn.Sequential( - nn.Linear(1, 2), - nn.Linear(2, 4), - nn.Linear(4, 2), - nn.Linear(2, 1), - ) - - # Without Pipe - outputs = model(inputs) - loss = outputs.mean() - loss.backward() - - grad_without_pipe = sum_grad(model.parameters()) - - zero_grad(model.parameters()) - - # With Pipe - model = Pipe(model, chunks=4) - - outputs = model(inputs).local_value() - loss = outputs.mean() - loss.backward() - - grad_with_pipe = sum_grad(model.parameters()) - - # Both grads should be identical. - assert torch.allclose(grad_with_pipe, grad_without_pipe) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipeline/sync/test_worker.py b/test/distributed/pipeline/sync/test_worker.py deleted file mode 100644 index f82af2ea0067..000000000000 --- a/test/distributed/pipeline/sync/test_worker.py +++ /dev/null @@ -1,118 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -import threading - -import pytest - -import torch - -from torch.distributed.pipeline.sync.microbatch import Batch -from torch.distributed.pipeline.sync.stream import CPUStream -from torch.distributed.pipeline.sync.worker import spawn_workers, Task -from torch.testing._internal.common_utils import run_tests - - -class fake_device: - """A test double for :class:`torch.device`. Every fake device is different - with each other. 
- """ - - type = "fake" - index = None - - -def test_compute_multithreading(): - """Task.compute should be executed on multiple threads.""" - thread_ids = set() - - def log_thread_id(): - thread_id = threading.current_thread().ident - thread_ids.add(thread_id) - return Batch(()) - - with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): - for i in range(2): - t = Task(CPUStream, compute=log_thread_id, finalize=None) - in_queues[i].put(t) - for i in range(2): - out_queues[i].get() - - assert len(thread_ids) == 2 - - -def test_compute_success(): - """Task.compute returns (True, (task, batch)) on success.""" - - def _42(): - return Batch(torch.tensor(42)) - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=_42, finalize=None) - in_queues[0].put(t) - ok, (task, batch) = out_queues[0].get() - - assert ok - assert task is t - assert isinstance(batch, Batch) - assert batch[0].item() == 42 - - -def test_compute_exception(): - """Task.compute returns (False, exc_info) on failure.""" - - def zero_div(): - 0 / 0 - - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - t = Task(CPUStream, compute=zero_div, finalize=None) - in_queues[0].put(t) - ok, exc_info = out_queues[0].get() - - assert not ok - assert isinstance(exc_info, tuple) - assert issubclass(exc_info[0], ZeroDivisionError) - - -@pytest.mark.parametrize("grad_mode", [True, False]) -def test_grad_mode(grad_mode): - def detect_grad_enabled(): - x = torch.rand(1, requires_grad=torch.is_grad_enabled()) - return Batch(x) - - with torch.set_grad_enabled(grad_mode): - with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): - task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) - in_queues[0].put(task) - - ok, (_, batch) = out_queues[0].get() - - assert ok - assert batch[0].requires_grad == grad_mode - - -def test_worker_per_device(): - cpu = torch.device("cpu") - cpu0 = torch.device("cpu", index=0) - fake1 = fake_device() - fake2 = fake_device() - - with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): - assert len(in_queues) == len(out_queues) == 5 - - # 0: cpu, 1: cpu, 2: cpu0 - assert in_queues[0] is in_queues[1] is in_queues[2] - assert out_queues[0] is out_queues[1] is out_queues[2] - - # 3: fake1, 4: fake2 - assert in_queues[3] is not in_queues[4] - assert out_queues[3] is not out_queues[4] - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index babc1cfa1096..5f0c9baf3b1e 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -51,6 +51,27 @@ def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)): return x +class ModelWithParamAlias(torch.nn.Module): + default_dhid = 512 + default_batch_size = 256 + + def __init__(self, d_hid: int = default_dhid): + super().__init__() + self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y): + x = torch.mm(x, self.mm_param0) + x = x + y + x = self.lin0(x) + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin1(x) + x = torch.relu(x) + return x + + # MLP Layer class MLPModule(torch.nn.Module): def __init__(self, d_hid: int): diff --git a/test/distributed/pipelining/test_chunkspec.py b/test/distributed/pipelining/test_chunkspec.py deleted file mode 100644 
index 1b104e59ec77..000000000000 --- a/test/distributed/pipelining/test_chunkspec.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: distributed"] -import torch -from torch.distributed.pipelining import ( - ArgsChunkSpec, - KwargsChunkSpec, - pipe_split, - pipeline, -) -from torch.testing._internal.common_utils import run_tests, TestCase - - -d_hid = 512 -batch_size = 256 - -torch.manual_seed(0) - - -class ModelWithKwargs(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin1 = torch.nn.Linear(d_hid, d_hid) - self.lin2 = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x, y, z=torch.zeros(batch_size, d_hid)): - x = torch.mm(x, self.mm_param0) - x = x + y - x = torch.relu(x) - x = x + z - pipe_split() - x = torch.mm(x, self.mm_param1) - x = self.lin1(x) - pipe_split() - x = torch.relu(x) - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin2(x) - x = torch.relu(x) - return x - - -class ChunkSpecTests(TestCase): - def test_chunk_spec(self): - mod = ModelWithKwargs() - - x = torch.randn(batch_size, d_hid) - y = torch.randn(batch_size, d_hid) - z = torch.randn(batch_size, d_hid) - - chunks = 4 - - with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}): - pipe = pipeline( - mod, - chunks, - example_args=(x, y), - example_kwargs={"z": z}, - ) - - assert pipe.num_stages == 4 - - ref = mod(x, y, z) - out = pipe(x, y, z)[0] - torch.testing.assert_close(out, ref) - print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/pipelining/test_composability.py b/test/distributed/pipelining/test_composability.py index bbf3f1929fbc..a2a37a6e0740 100644 --- a/test/distributed/pipelining/test_composability.py +++ b/test/distributed/pipelining/test_composability.py @@ -16,8 +16,8 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.pipelining import ManualPipelineStage -from torch.distributed.pipelining.PipelineSchedule import ( +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( PipelineScheduleSingle, Schedule1F1B, ScheduleGPipe, @@ -127,14 +127,13 @@ def apply_dp(partial_model, dp_type): def build_stage(stage_idx, num_stages): partial_model, offset = get_stage_module(stage_idx, num_stages) dp_model = apply_dp(partial_model, dp_type) - stage = ManualPipelineStage( + stage = PipelineStage( dp_model, stage_idx, num_stages, self.device, group=pp_group, input_args=input_mb[0], - num_microbatches=num_microbatches, ) return stage, offset diff --git a/test/distributed/pipelining/test_microbatch.py b/test/distributed/pipelining/test_microbatch.py index c526c6ff7b91..9f67c2c37ea4 100644 --- a/test/distributed/pipelining/test_microbatch.py +++ b/test/distributed/pipelining/test_microbatch.py @@ -1,6 +1,9 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates # Owner(s): ["oncall: distributed"] +from model_registry import ModelWithKwargs + import torch +from torch.distributed.pipelining import pipeline from torch.distributed.pipelining.microbatch import ( merge_chunks, split_args_kwargs_into_chunks, @@ -10,6 +13,7 @@ d_hid = 512 +torch.manual_seed(0) class MicrobatchTests(TestCase): @@ -49,9 +53,39 @@ def test_split_and_merge(self): }, ) torch.testing.assert_close(merged_kwargs, kwargs) - print("Microbatch test passed") + def test_chunk_spec(self): + mod = ModelWithKwargs() + batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE + + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + num_chunks = 4 + + args_chunk_spec = TensorChunkSpec.from_tuple((0,)) + kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0}) + + args_split, kwargs_split = split_args_kwargs_into_chunks( + (x,), + {"y": y}, + num_chunks, + args_chunk_spec, + kwargs_chunk_spec, + ) + + pipe = pipeline( + mod, + mb_args=args_split[0], + mb_kwargs=kwargs_split[0], + ) + + ref = mod(x, y) + out = pipe(x, y)[0] + torch.testing.assert_close(out, ref) + print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index a9e283d3cedf..d4d158bc9d5f 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] -from model_registry import MLPModule +from model_registry import MLPModule, ModelWithParamAlias import torch from torch.distributed.pipelining import pipe_split, pipeline @@ -13,7 +13,7 @@ d_hid = 512 -batch_size = 256 +microbatch_size = 16 torch.manual_seed(0) @@ -64,20 +64,34 @@ def forward(self, x, y): return x - y +EXPECTED_N_STAGES = { + ExampleCode: 4, + MultiMLP: 4, + ModelWithParamAlias: 2, +} + +# Currently, we don't enforce full set equality on the FQNs between the original +# and pipelined models, because in the multi-use param case, PP will deduplicate +# the FQNs from the state_dict. 
+# TODO +CHECK_FQN_SET_EQUALITY = False + + class PipeTests(TestCase): - @parametrize("ModelClass", [ExampleCode, MultiMLP]) + @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias]) def test_model_split(self, ModelClass): mod = ModelClass() - x = torch.randn(batch_size, d_hid) - y = torch.randn(batch_size, d_hid) + x = torch.randn(microbatch_size, d_hid) + y = torch.randn(microbatch_size, d_hid) pipe = pipeline( mod, - num_chunks=4, - example_args=(x, y), + mb_args=(x, y), ) - assert pipe.num_stages == 4, f"nstages = {pipe.num_stages}, expect 4" + assert ( + pipe.num_stages == EXPECTED_N_STAGES[ModelClass] + ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" ref_out = mod(x, y) out = pipe(x, y)[0] @@ -90,14 +104,17 @@ def test_model_split(self, ModelClass): new_names = set() for idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(idx) - new_names.update(stage_mod.state_dict().keys()) - - assert ( - old_names == new_names - ), f""" - old names {old_names} - new names {new_names} - """ + stage_fqns = set(stage_mod.state_dict().keys()) + assert stage_fqns.issubset(old_names) + new_names.update(stage_fqns) + + if CHECK_FQN_SET_EQUALITY: + assert ( + old_names == new_names + ), f""" + old names {old_names} + new names {new_names} + """ print("Qualname check passed") diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 232f69d8bcef..e67459d5b44b 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -1,16 +1,18 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import copy +import logging import os import sys import tempfile +import unittest +from typing import Dict, List, Optional, Tuple from model_registry import ModelWithKwargs, MultiMLP import torch import torch.distributed as dist from torch.distributed.pipelining import ( - ManualPipelineStage, pipeline, PipelineStage, Schedule1F1B, @@ -18,6 +20,8 @@ ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) +from torch.distributed.pipelining.schedules import _Action, _ComputationType +from torch.distributed.pipelining.stage import _PipelineStageBase from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, @@ -29,6 +33,7 @@ skip_but_pass_in_sandcastle_if, ) +logger = logging.getLogger(__name__) d_hid = 512 batch_size = 256 @@ -36,6 +41,24 @@ torch.manual_seed(0) +class MockPipelineStage(_PipelineStageBase): + def __init__(self, *args, **kwargs): + # Mock the necessary attributes + self.num_stages = kwargs.get("num_stages", 1) + self.group_size = kwargs.get("group_size", 1) + self.group_rank = kwargs.get("group_rank", 0) + self.group = kwargs.get("group", None) + + def _create_grad_recv_info(self, *args, **kwargs): + return None + + def _prepare_forward_infra(self, n_microbatches): + pass + + def _prepare_backward_infra(self, n_microbatches): + pass + + class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: @@ -52,6 +75,46 @@ def setUpClass(cls): dev_id = cls.rank % torch.cuda.device_count() cls.device = torch.device(f"cuda:{dev_id}") + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + def test_multi_iter(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + x = 
torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + x_mb = x.chunk(chunks)[0] + + # Create a pipeline + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Attach to a schedule + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) + + # Run + for _ in range(20): + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] + out = schedule.step(target=target, losses=losses) + else: + schedule.step() + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @@ -65,17 +128,18 @@ def test_kwargs_with_tracer(self, ScheduleClass): loss_fn = torch.nn.MSELoss(reduction="sum") chunks = 4 + x_mb = x.chunk(chunks)[0] + y_mb = y.chunk(chunks)[0] + pipe = pipeline( mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, + mb_args=(x_mb,), + mb_kwargs={"y": y_mb}, ) - stage = PipelineStage( - pipe, + stage = pipe.build_stage( self.rank, - device=self.device, + self.device, ) # Attach to a schedule @@ -126,18 +190,17 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): # Create a pipeline chunks = 4 + x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, - chunks, - example_args=(x,), + mb_args=(x_mb,), split_spec=split_spec, ) - stage = PipelineStage( - pipe, + stage = pipe.build_stage( self.rank, - device=self.device, + self.device, ) # Attach to a schedule @@ -205,12 +268,11 @@ def test_grad_with_manual(self, ScheduleClass): stage_module = full_mod.get_submodule(submod_name) chunks = 4 # Create a pipeline stage to wrap that submodule - stage = ManualPipelineStage( + stage = PipelineStage( stage_module, self.rank, self.world_size, self.device, - chunks, input_args=x.chunk(chunks)[0], ) @@ -289,12 +351,11 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): chunks = 8 input_args = x.chunk(chunks)[0] stages = [ - ManualPipelineStage( + PipelineStage( stage_module, stage_idx, n_stages, self.device, - chunks, input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) @@ -344,7 +405,213 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): instantiate_parametrized_tests(ScheduleTest) + +def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]): + import itertools + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row 
with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = ( + "=========== ALL_RANK_ACTIONS ===========\n" + + header_row + + "\n" + + "\n".join(formatted_rows) + + "\n" + ) + return formatted_table + + +class TestSchedulePlan(unittest.TestCase): + def _validate_pipeline_order( + self, + pipeline_order: Dict[int, List[Optional[_Action]]], + num_microbatches: int, + num_stages: int, + ): + """ + pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...] + + Validating that the pipeline order follows the rules: + 1. Forward action for a microbatch must be before the Backward action for that microbatch + 2. Recv for a microbatch must be before the send for that microbatch + 3. Microbatch index is handled in sequential order for each stage + 4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it + 5. Same microbatch cannot be handled in the same time step across ranks + """ + # microbatch_index: (current computation type, current stage) + error_msg = [] + microbatch_process_info: Dict[int, Tuple(_ComputationType, int)] = {} + max_timestep = max(len(rank_list) for rank_list in pipeline_order.values()) + for timestep in range(max_timestep): + error_msg = [] + current_timestep_actions = [] + for rank in range(len(pipeline_order)): + action = ( + pipeline_order[rank][timestep] + if timestep < len(pipeline_order[rank]) + else None + ) + if action is not None: + current_timestep_actions.append(action) + + # TODO: enable this + # if len(current_timestep_actions) == 0: + # error_msg.append( + # "All actions were None, there is an unnecessary gap in the schedule" + # ) + + # Ensure that no microbatch is operated on twice in current_timestep_actions + unique_microbatch_indices = { + action[1] for action in current_timestep_actions + } + if len(unique_microbatch_indices) != len(current_timestep_actions): + error_msg.append( + "Duplicate microbatch index found in current_timestep_actions" + ) + + # Add additional checks for other rules here... 
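For intuition about rules 1-5 in the docstring above, here is a hand-written order for 2 ranks, one stage per rank, and 2 microbatches, GPipe-style. Plain (computation, microbatch, stage) tuples and "F"/"B" stand in for _Action and _ComputationType purely for illustration; the real validator expects the actual enum values.

# Each inner list index is one timestep; no microbatch appears twice in the
# same timestep across ranks, and each microbatch runs forward through
# stages 0 -> 1 and then backward through stages 1 -> 0.
example_order = {
    0: [("F", 0, 0), ("F", 1, 0), None, None, ("B", 0, 0), ("B", 1, 0)],
    1: [None, ("F", 0, 1), ("F", 1, 1), ("B", 0, 1), ("B", 1, 1), None],
}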
+ for action in current_timestep_actions: + computation_type, mb_index, stage_index = action + + if mb_index >= num_microbatches: + error_msg.append(f"Microbatch index {mb_index} out of range") + + # first microbatch + if mb_index not in microbatch_process_info: + if computation_type != _ComputationType.FORWARD or stage_index != 0: + error_msg.append(f"Incorrect start for microbatch {mb_index}") + microbatch_process_info[mb_index] = (computation_type, stage_index) + else: + # if the microbatch is included, check that the current stage is right after prev + prev_computation, prev_stage = microbatch_process_info[mb_index] + if prev_computation == _ComputationType.FORWARD: + if prev_stage == num_stages - 1: + expected_stage = num_stages - 1 + expected_computation = _ComputationType.BACKWARD + else: + expected_stage = prev_stage + 1 + expected_computation = _ComputationType.FORWARD + elif prev_computation == _ComputationType.BACKWARD: + if prev_stage == 0: + error_msg.append( + f"[{mb_index=}] already finished backward computation" + ) + expected_stage = None + expected_computation = None + else: + expected_stage = prev_stage - 1 + expected_computation = _ComputationType.BACKWARD + else: + raise ValueError( + f"Computation type {prev_computation} not supported" + ) + + if expected_computation is not None: + if expected_computation != computation_type: + error_msg.append( + f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}" + ) + + if expected_stage != stage_index: + error_msg.append( + f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}" + ) + + microbatch_process_info[mb_index] = ( + expected_computation, + expected_stage, + ) + + if len(error_msg) != 0: + self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg)) + + @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS]) + def test_pipeline_order(self, ScheduleClass): + # Define a list of test cases with varying num_local_stages, num_microbatches, and group_size + # These should succeed since num_microbatches % group_size == 0 + test_cases = [ + # small number of stages + (2, 2, 2), + (2, 4, 4), + (2, 8, 2), + (2, 8, 4), + (2, 8, 8), + (4, 4, 4), + (4, 8, 4), + (4, 8, 8), + # large microbatches + (4, 16, 4), + (4, 32, 4), + (4, 64, 4), + # large groups + (4, 16, 16), + (4, 32, 32), + (4, 128, 64), + # odd num pipeline stages + (3, 2, 2), + (3, 8, 2), + (3, 12, 4), + # odd group_sizes + (4, 6, 3), + (4, 10, 5), + ] + for num_local_stages, num_microbatches, group_size in test_cases: + with self.subTest( + num_local_stages=num_local_stages, + num_microbatches=num_microbatches, + group_size=group_size, + ): + print(f"{num_local_stages=} {num_microbatches=} {group_size=}") + num_stages = num_local_stages * group_size + stages = [ + MockPipelineStage(group_size=group_size, num_stages=num_stages) + for i in range(num_local_stages) + ] + + schedule = ScheduleClass(stages, num_microbatches) + # print(format_pipeline_order(schedule.pipeline_order)) + self._validate_pipeline_order( + schedule.pipeline_order, num_microbatches, num_stages + ) + + +instantiate_parametrized_tests(TestSchedulePlan) + if __name__ == "__main__": + # Run only the TestSchedulePlan tests (single process) + loader = unittest.TestLoader() + suite = loader.loadTestsFromTestCase(TestSchedulePlan) + runner = unittest.TextTestRunner() + runner.run(suite) + # Check if GPU and NCCL are available if not ( dist.is_available() diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 
97a147cb357a..fac2be495ce0 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist from torch.distributed.pipelining import ( - ManualPipelineStage, + build_stage, pipeline, PipelineStage, ScheduleGPipe, @@ -82,19 +82,18 @@ def test_tracer(self, ModelClass): mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) + x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, - chunks, - example_args=(x,), + mb_args=(x_mb,), split_spec=split_spec, ) - stage = PipelineStage( - pipe, + stage = pipe.build_stage( self.rank, - device=self.device, + self.device, ) # Attach to a schedule @@ -150,17 +149,23 @@ def test_tracer_kwargs(self, ModelClass): x = torch.randn(batch_size, d_hid, device=self.device) y = torch.randn(batch_size, d_hid, device=self.device) + x_mb = x.chunk(chunks)[0] + y_mb = y.chunk(chunks)[0] + pipe = pipeline( mod, - chunks, - example_args=(x,), - example_kwargs={"y": y}, + mb_args=(x_mb,), + mb_kwargs={"y": y_mb}, ) - stage = PipelineStage( - pipe, + stage_mod = pipe.get_stage_module(self.rank) + + # Test build_stage + stage = build_stage( + stage_mod, self.rank, - device=self.device, + pipe.info(), + self.device, ) # Attach to a schedule @@ -211,12 +216,11 @@ def test_manual(self): x = torch.randn(batch_size, d_hid, device=self.device) - stage = ManualPipelineStage( + stage = PipelineStage( stage_mod, self.rank, self.world_size, self.device, - chunks, input_args=x.chunk(chunks)[0], ) diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 9742c77b606a..070a62d11638 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -7,7 +7,7 @@ d_hid = 16 n_layers = 8 -batch_size = 4 +microbatch_size = 4 class MLPModule(torch.nn.Module): @@ -36,8 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerTests(TestCase): def test_ir(self): transformer = TransformerLike() - print("Original model:\n", transformer) - x = torch.randn(batch_size, d_hid) + x = torch.randn(microbatch_size, d_hid) # Split into 2 stages num_stages = 2 @@ -45,7 +44,6 @@ def test_ir(self): pipe = pipeline( transformer, - 1, (x,), split_spec=split_spec, ) @@ -59,19 +57,18 @@ def get_layers(module): layers = [] for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) - print(f"\nStage {stage_idx}: \n", stage_mod) layers += get_layers(stage_mod) # Check layer completeness orig_layers = get_layers(transformer) assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}" - print("Layers matched! 
", layers) + print("Layers matched!") # Check equivalence ref = transformer(x) out = pipe(x)[0] torch.testing.assert_close(out, ref) - print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") + print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") if __name__ == "__main__": diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 37eaf599e4d8..ef2e48d8ee9f 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -48,7 +48,6 @@ def test_unflatten(self): pipe = pipeline( mod, - 1, (x,), {"constant": constant}, ) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py new file mode 100644 index 000000000000..56ae8a14dcde --- /dev/null +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -0,0 +1,147 @@ +# Owner(s): ["module: c10d"] +import unittest + +import torch +import torch.distributed as dist +from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code +from torch.distributed._cuda_p2p import test_with_non_cuda_p2p_group +from torch.distributed._functional_collectives import ( + all_gather_tensor, + reduce_scatter_tensor, +) +from torch.distributed._tensor import DeviceMesh +from torch.distributed._tensor.placement_types import Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule +from torch.testing._internal.distributed.fake_pg import FakeStore +from torch.utils._triton import has_triton + + +@instantiate_parametrized_tests +class MicroPipelineTPTest(TestCase): + def setUp(self): + torch._inductor.config._micro_pipeline_tp = True + + self.rank = 0 + self.world_size = 2 + torch.cuda.set_device("cuda:0") + + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + + def tearDown(self): + dist.destroy_process_group() + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @parametrize("A_dims", [2, 3]) + @parametrize("gather_dim", [0, 1, 2]) + @fresh_inductor_cache() + def test_fuse_all_gather_matmul(self, A_dims, gather_dim): + if gather_dim >= A_dims: + return + + group = dist.group.WORLD + + def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + A = all_gather_tensor(A_shard, gather_dim=gather_dim, group=group) + return A @ B + + if A_dims == 2: + A_shard_shape = [64, 32] + elif A_dims == 3: + A_shard_shape = [2, 64, 32] + else: + raise AssertionError(f"Invalid A_dims: {A_dims}") + + A_shard_shape[gather_dim] //= self.world_size + A_shard = torch.rand(*A_shard_shape, device="cuda") + B = torch.rand(32, 16, device="cuda") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, A_shard, B) + + if gather_dim == A_dims - 1: + assert "fused_all_gather_matmul" not in code + assert "all_gather_into_tensor" in code + else: + # Decomposing the matmul on the K dimension is not supported + assert "fused_all_gather_matmul" in code + assert "all_gather_into_tensor" not in code + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + 
@parametrize("A_dims", [2, 3]) + @parametrize("scatter_dim", [0, 1, 2]) + @fresh_inductor_cache() + def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim): + if scatter_dim >= A_dims: + return + + group = dist.group.WORLD + + def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return reduce_scatter_tensor(A @ B, "avg", scatter_dim, group) + + if A_dims == 2: + A = torch.rand(64, 32, device="cuda") + elif A_dims == 3: + A = torch.rand(2, 64, 32, device="cuda") + else: + raise AssertionError(f"Invalid A_dims: {A_dims}") + B = torch.rand(32, 16, device="cuda") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, A, B) + + assert "fused_matmul_reduce_scatter" in code + assert "reduce_scatter_tensor" not in code + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @parametrize("shard_dim", [0, 1]) + @fresh_inductor_cache() + def test_dtensor_seq_par(self, shard_dim: int): + model = MLPModule(device="cuda", bias=False) + device_mesh = DeviceMesh( + "cuda", + torch.arange(0, self.world_size), + ) + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(shard_dim)), + "net2": RowwiseParallel(output_layouts=Shard(shard_dim)), + } + model = parallelize_module(model, device_mesh, parallelize_plan) + if shard_dim == 0: + inp = torch.rand(8, 10, device="cuda") + elif shard_dim == 1: + inp = torch.rand(2, 8, 10, device="cuda") + else: + raise AssertionError("Invalid shard_dim") + + with test_with_non_cuda_p2p_group(): + compiled = torch.compile(model) + code = run_and_get_triton_code(compiled, inp) + + assert "fused_all_gather_matmul" in code + assert "all_gather_into_tensor" not in code + assert "fused_matmul_reduce_scatter" in code + assert "reduce_scatter_tensor" not in code + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 4fea855a85b9..50ec40291cd7 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -599,6 +599,9 @@ def test_comm_split_optimization(self): @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @skip_but_pass_in_sandcastle_if( + torch.cuda.nccl.version()[-1] == "x", "NCCL test not for NCCLX" + ) def test_comm_split_subgroup(self): # Test `ncclCommSplit` for smaller subgroups of the world when # we've passed a specific device_id to init_process_group. 
@@ -3514,7 +3517,8 @@ class NCCLTraceTest(NCCLTraceTestBase): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - def test_short(self, timing_enabled): + @parametrize("include_collectives", [True, False]) + def test_short(self, timing_enabled, include_collectives): if self.rank == self.MAIN_PROCESS_RANK: return pg = self._create_process_group_nccl() @@ -3529,10 +3533,16 @@ def test_short(self, timing_enabled): # gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api time.sleep(1) - - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + if include_collectives: + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + else: + t = pickle.loads( + torch._C._distributed_c10d._dump_nccl_trace( + includeCollectives=False, includeStackTraces=None, onlyActive=None + ) + ) ver = t["version"] - self.assertEqual(ver, "2.1") + self.assertEqual(ver, "2.2") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -3541,35 +3551,40 @@ def test_short(self, timing_enabled): self.assertIn("ranks", default_pg_info) global_ranks = pg_config["0"]["ranks"] self.assertEqual(len(json.loads(global_ranks)), self.world_size) - t = t["entries"] - self.assertEqual(len(t), 2) - last = t[-1] - self.assertEqual(last["process_group"], ("0", "default_pg")) - self.assertEqual(last["state"], "completed") - s = last["time_discovered_started_ns"] - f = last["time_discovered_completed_ns"] - self.assertEqual(last["record_id"], 1) - self.assertIsNotNone(f) - if timing_enabled: - self.assertIsNotNone(s) - self.assertTrue(s <= f) - self.assertIn("test_c10d_nccl.py", str(last["frames"])) - self.assertEqual(last["input_sizes"], ((3, 4),)) - self.assertEqual(last["input_dtypes"], ["Float"]) - self.assertEqual(last["output_sizes"], ((3, 4),)) - self.assertEqual(last["output_dtypes"], ["Float"]) - self.assertEqual(last["collective_seq_id"], 2) - now = datetime.now() - event_created_time = datetime.fromtimestamp( - last["time_created_ns"] / 1000000000 - ) - before_test = now - timedelta(minutes=1) - self.assertTrue(before_test < event_created_time < now) - if timing_enabled: - # very loose bounds, measured 0.036 ms on devgpu - self.assertTrue(0 < last["duration_ms"] < 100) + if include_collectives: + self.assertEqual(len(t["entries"]), 2) + t = t["entries"] + self.assertEqual(len(t), 2) + last = t[-1] + self.assertEqual(last["process_group"], ("0", "default_pg")) + self.assertEqual(last["state"], "completed") + s = last["time_discovered_started_ns"] + f = last["time_discovered_completed_ns"] + self.assertEqual(last["record_id"], 1) + self.assertIsNotNone(f) + if timing_enabled: + self.assertIsNotNone(s) + self.assertTrue(s <= f) + self.assertIn("test_c10d_nccl.py", str(last["frames"])) + self.assertEqual(last["input_sizes"], ((3, 4),)) + self.assertEqual(last["input_dtypes"], ["Float"]) + self.assertEqual(last["output_sizes"], ((3, 4),)) + self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["collective_seq_id"], 2) + self.assertEqual(last["timeout_ms"], 600000) + now = datetime.now() + event_created_time = datetime.fromtimestamp( + last["time_created_ns"] / 1000000000 + ) + before_test = now - timedelta(minutes=1) + self.assertTrue(before_test < event_created_time < now) + if timing_enabled: + # very loose bounds, measured 0.036 ms on devgpu + self.assertTrue(0 < last["duration_ms"] < 100) + else: + 
self.assertTrue("duration_ms" not in last) else: - self.assertTrue("duration_ms" not in last) + self.assertTrue("entries" not in t) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -3641,12 +3656,14 @@ def test_long(self): self.assertEqual(last["input_dtypes"], ["Float"]) self.assertEqual(last["output_sizes"], ((3, 4),)) self.assertEqual(last["output_dtypes"], ["Float"]) + self.assertEqual(last["timeout_ms"], 600000) self.assertEqual(last["collective_seq_id"] - first["collective_seq_id"], 9) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("timing_enabled", [True, False]) - def test_trace_while_active(self, timing_enabled): + @parametrize("only_active", [True, False]) + def test_trace_while_active(self, timing_enabled, only_active): if self.rank == self.MAIN_PROCESS_RANK: for c in self.children_pipes: self.assertEqual(c.recv(), "next") @@ -3667,17 +3684,26 @@ def test_trace_while_active(self, timing_enabled): if self.rank != 0: pg.allreduce(a).wait() e.synchronize() - t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + t = pickle.loads( + torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active) + ) t = t["entries"] - self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") - if self.rank == 0: - self.assertEqual(t[-1]["collective_seq_id"], 1) - self.assertEqual(t[-1]["state"], "completed") - else: - self.assertEqual(t[-1]["collective_seq_id"], 2) - self.assertEqual( - t[-1]["state"], self.started_or_scheduled(timing_enabled) - ) + if only_active: + if self.rank == 0: + self.assertEqual(len(t), 0) + else: + self.assertEqual(len(t), 1) + if not only_active: + if self.rank == 0: + self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") + self.assertEqual(t[-1]["collective_seq_id"], 1) + self.assertEqual(t[-1]["state"], "completed") + else: + self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") + self.assertEqual(t[-1]["collective_seq_id"], 2) + self.assertEqual( + t[-1]["state"], self.started_or_scheduled(timing_enabled) + ) self.parent.send("next") self.assertEqual("next", self.parent.recv()) @@ -3845,6 +3871,7 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertTrue(0.001 < duration < 10000, duration) else: self.assertTrue("duration_ms" not in t["entries"][coalesced_op]) + self.assertEqual(t["entries"][coalesced_op]["timeout_ms"], 600000) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py index fb0067f2dd2e..594c028ae9d4 100644 --- a/test/distributed/test_control_collectives.py +++ b/test/distributed/test_control_collectives.py @@ -8,6 +8,17 @@ from torch.testing._internal.common_utils import run_tests, TestCase +# simple example of user code that takes the base class ControlCollectives +# and executes multiple different collectives +def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int: + timeout = timedelta(seconds=10) + # first a barrier + collectives.barrier("1", timeout, True) + # then an all_sum + out = collectives.all_sum("2", rank, timeout) + return out + + class TestCollectives(TestCase): def test_barrier(self) -> None: store = dist.HashStore() @@ -180,6 +191,20 @@ def test_unique(self) -> None: with self.assertRaisesRegex(Exception, "Key foo has already been used"): collectives.all_sum("foo", 2) + def 
test_simple_user_func(self) -> None: + store = dist.HashStore() + world_size = 4 + + def f(rank: int) -> None: + # users need to create the child collectives + # but simple_user_func does not need to be changed for different child collectives + store_collectives = dist._StoreCollectives(store, rank, world_size) + out = simple_user_func(store_collectives, rank) + self.assertEqual(out, sum(range(world_size))) + + with ThreadPool(world_size) as pool: + pool.map(f, range(world_size)) + if __name__ == "__main__": assert ( diff --git a/test/distributed/test_cuda_p2p.py b/test/distributed/test_cuda_p2p.py index 14ff4bd3d0eb..1e743896bc7b 100644 --- a/test/distributed/test_cuda_p2p.py +++ b/test/distributed/test_cuda_p2p.py @@ -6,9 +6,12 @@ import torch.distributed as dist from torch.distributed._cuda_p2p import ( + _fused_all_gather_matmul_fallback, + _fused_matmul_reduce_scatter_fallback, get_cuda_p2p_backend, get_p2p_buffer_size, is_cuda_p2p_group, + p2p_usage_counter, ) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -16,6 +19,8 @@ skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, @@ -43,6 +48,7 @@ def requires_cuda_p2p_access(): ) +@instantiate_parametrized_tests @requires_nccl() @requires_cuda_p2p_access() class ProcessGroupCudaP2PTest(MultiProcessTestCase): @@ -137,6 +143,79 @@ def test_p2p_buffer(self) -> None: torch.cuda.synchronize() dist.destroy_process_group() + @skipIfRocm + @skip_if_lt_x_gpu(2) + @parametrize("gather_dim", [0, 1]) + def test_fused_all_gather_matmul(self, gather_dim: int) -> None: + B = 8 + M = 64 + N = 16 + K = 32 + BUFFER_SIZE = B * M * K // self.world_size * 4 + + self._init_process_group(BUFFER_SIZE) + group = dist.group.WORLD + rank = self.rank + world_size = self.world_size + + torch.manual_seed(42 + rank) + A_shard = torch.rand(B, M // self.world_size, K, device="cuda") + Bs = [torch.rand(K, N, device="cuda") for _ in range(3)] + + ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name ) + with p2p_usage_counter() as counter: + ag_output_1, mm_outputs_1 = torch.ops.cuda_p2p.fused_all_gather_matmul( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + assert counter["fused_all_gather_matmul"] == 1 + + assert torch.allclose(ag_output_0, ag_output_1) + assert ag_output_0.stride() == ag_output_1.stride() + for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1): + assert torch.allclose(mm_output_0, mm_output_1) + assert mm_output_0.stride() == mm_output_1.stride() + + dist.barrier() + torch.cuda.synchronize() + dist.destroy_process_group() + + @skipIfRocm + @skip_if_lt_x_gpu(2) + @parametrize("scatter_dim", [0, 1]) + def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: + B = 8 + M = 64 + N = 16 + K = 32 + BUFFER_SIZE = B * M * N // self.world_size * 4 * 2 + + self._init_process_group(BUFFER_SIZE) + group = dist.group.WORLD + rank = self.rank + world_size = self.world_size + + torch.manual_seed(42 + rank) + A = torch.rand(B, M, K, device="cuda") + B = torch.rand(K, N, device="cuda") + + output_0 = _fused_matmul_reduce_scatter_fallback( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + with p2p_usage_counter() as counter: + output_1 = torch.ops.cuda_p2p.fused_matmul_reduce_scatter( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + assert
counter["fused_matmul_reduce_scatter"] == 1 + + assert torch.allclose(output_0, output_1) + assert output_0.stride() == output_1.stride() + + dist.barrier() + torch.cuda.synchronize() + dist.destroy_process_group() + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 8f70ee2f0b7d..22d8b0fbbdce 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -73,7 +73,7 @@ def test_assert_invalid_mesh_tensor(self): device_mesh = DeviceMesh(self.device_type, mesh) @with_comms - def test_get_group(self): + def test_get_group_and_get_all_groups(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") @@ -82,16 +82,17 @@ def test_get_group(self): tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] - self.assertEqual(len(mesh_2d.get_group()), 2) - self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp")) - self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp")) - self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp")) self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp")) self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group()) self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group()) + groups = mesh_2d.get_all_groups() + self.assertEqual(len(groups), 2) + self.assertTrue(tp_mesh.get_group() in groups) + self.assertTrue(dp_mesh.get_group() in groups) + @with_comms def test_get_local_rank_raises_exception(self): mesh_shape = (2, self.world_size // 2) @@ -126,7 +127,7 @@ def test_device_mesh_2d(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] for dim, dim_group in enumerate(dim_to_subgroups): @@ -191,7 +192,7 @@ def test_from_group_with_invalid_mesh(self): DeviceMesh.from_group(global_pg, "cuda", invalid_mesh) device_mesh = init_device_mesh(self.device_type, (2, 2)) - groups = device_mesh.get_group() + groups = device_mesh.get_all_groups() invalid_mesh = (0, 1, 2, 3) # 1D mesh when we need 2D regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups" with self.assertRaisesRegex(ValueError, regex): @@ -208,6 +209,15 @@ def test_raises_invalid_device_type(self): "cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp") ) + @with_comms + def test_set_mesh_dim_group_options(self): + device_type = "cuda" if torch.cuda.is_available() else "cpu" + _mesh_resources._set_mesh_dim_group_options(1, "fake", None) + + mesh_tensor = torch.arange(4).reshape(2, 2) + mesh = DeviceMesh(device_type, mesh_tensor) + self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") + class DeviceMeshTestNDim(DTensorTestBase): @property @@ -221,7 +231,7 @@ def test_device_mesh_nd(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): self.assertTrue(dim < mesh_tensor.ndim) @@ -420,16 +430,16 @@ def world_size(self): @with_comms def test_raises_no_mesh_dim_found(self): - with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): + with self.assertRaisesRegex( + RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" 
+ ): mesh = init_device_mesh(self.device_type, (2, 4)) child_mesh = mesh["DP"] @with_comms def test_raises_invalid_mesh_dim_name(self): - child_mesh_dim_name = "PP" - with self.assertRaisesRegex( - KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist." - ): + child_mesh_dim_name = ("PP",) + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): mesh_dim_names = ("DP", "TP") mesh = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=mesh_dim_names @@ -437,7 +447,7 @@ def test_raises_invalid_mesh_dim_name(self): child_mesh = mesh[child_mesh_dim_name] @with_comms - def test_get_item(self): + def test_get_item_2d(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( @@ -467,9 +477,41 @@ def test_get_item_1d(self): dp_mesh = mesh["dp"] self.assertEqual(dp_mesh, mesh) - with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"): + with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): dp_mesh = mesh["dim0"] + @with_comms + def test_get_item_3d(self): + mesh_shape = (2, 2, 2) + mesh_dim_names = ("Replicate", "Shard", "TP") + mesh_3d = init_device_mesh( + self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names + ) + + tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] + tp_group_idx = int(self.rank / 2) + self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) + + shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] + shard_group_idx = self.rank % 2 + self.rank // 4 * 2 + self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) + + replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] + replicate_group_idx = self.rank % 4 + self.assertEqual( + mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] + ) + + # We support both UX for nD slicing. 
+ # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] + hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] + hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] + hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] + hsdp_group_idx = self.rank % 2 + self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) + self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) + @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) @@ -640,9 +682,9 @@ def test_all_gather_uneven(self): ) unpadded_list = [ ( - unpad_tensor(big_tensor_chunks[i], shard_dim, pad_sizes[i]) + unpad_tensor(big_tensor, shard_dim, pad_sizes[i]) if pad_sizes[i] > 0 - else big_tensor_chunks[i] + else big_tensor ) for i, big_tensor in enumerate(big_tensor_chunks) ] @@ -762,7 +804,7 @@ def test_broadcast_nd(self): local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ @@ -779,7 +821,7 @@ def test_scatter_nd(self): mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups - dim_to_subgroups = mesh.get_group() + dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b31a2f717537..db44f1ce915d 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1084,12 +1084,14 @@ def _(ctx): # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( """local "L['self']" ID_MATCH""" - ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( - f"""{expected_guard_source} "L['self'].net" ID_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( - f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" + f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" + ).check( + f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" ).run( GUARDS_FILE.getvalue() ) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 8383101d2093..cd126cc0d358 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -265,13 +265,17 @@ def num_keys_total(self): class TCPStoreTest(TestCase, StoreTestBase): + _use_libuv = False + def _create_store(self): - store = create_tcp_store() + store = create_tcp_store(use_libuv=self._use_libuv) store.set_timeout(timedelta(seconds=300)) return store def _create_store_with_ws(self, addr, world_size): - return create_tcp_store(addr, world_size, wait_for_workers=False) + return create_tcp_store( + addr, world_size, wait_for_workers=False, use_libuv=self._use_libuv + ) def test_address_already_in_use(self): err_msg_reg = "^The server socket has failed to listen on any local " @@ -282,8 +286,14 @@ def test_address_already_in_use(self): # Use noqa to silence flake8. 
# Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore(addr, port, 1, True) # noqa: F841 - store2 = dist.TCPStore(addr, port, 1, True) # noqa: F841 + store1 = dist.TCPStore( + addr, port, 1, True, use_libuv=self._use_libuv + ) # noqa: F841 + store2 = dist.TCPStore( + addr, port, 1, True, use_libuv=self._use_libuv + ) # noqa: F841 + self.assertEqual(store1.libuvBackend, self._use_libuv) + self.assertEqual(store2.libuvBackend, self._use_libuv) @retry_on_connect_failures def test_multitenancy(self): @@ -293,8 +303,14 @@ def test_multitenancy(self): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True) # type: ignore[call-arg] # noqa: F841 - store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True) # type: ignore[call-arg] # noqa: F841 + store1 = dist.TCPStore( + addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv + ) # type: ignore[call-arg] # noqa: F841 + store2 = dist.TCPStore( + addr, port, 1, True, multi_tenant=True, use_libuv=self._use_libuv + ) # type: ignore[call-arg] # noqa: F841 + self.assertEqual(store1.libuvBackend, self._use_libuv) + self.assertEqual(store2.libuvBackend, self._use_libuv) @skip_if_win32() @retry_on_connect_failures @@ -308,6 +324,7 @@ def test_init_pg_and_rpc_with_same_socket(self): # We internally use a multi-tenant TCP store. Both PG and RPC should successfully # initialize even when using the same socket address. + os.environ["USE_LIBUV"] = "1" if self._use_libuv else "0" dist.init_process_group( backend="gloo", init_method="env://", @@ -325,7 +342,10 @@ def test_init_pg_and_rpc_with_same_socket(self): rpc_backend_options=backend_opts, ) + del os.environ["USE_LIBUV"] + assert "USE_LIBUV" not in os.environ rpc.shutdown() + dist.destroy_process_group() @skip_if_win32() def test_take_over_listen_socket(self): @@ -334,8 +354,16 @@ def test_take_over_listen_socket(self): addr, port, *_ = listen_sock.getsockname() listen_fd = listen_sock.detach() - store = dist.TCPStore(addr, port, 1, is_master=True, master_listen_fd=listen_fd) + store = dist.TCPStore( + addr, + port, + 1, + is_master=True, + master_listen_fd=listen_fd, + use_libuv=self._use_libuv, + ) + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("key", "value") self.assertEqual(b"value", store.get("key")) @@ -373,7 +401,11 @@ def test_numkeys_delkeys(self): def _create_client(self, index, addr, port, world_size): client_store = dist.TCPStore( - addr, port, world_size=world_size, timeout=timedelta(seconds=10) + addr, + port, + world_size=world_size, + timeout=timedelta(seconds=10), + use_libuv=self._use_libuv, ) self.assertEqual(b"value", client_store.get("key")) client_store.set(f"new_key{index}", f"new_value{index}") @@ -387,6 +419,7 @@ def _create_client(self, index, addr, port, world_size): def _multi_worker_helper(self, world_size): addr = DEFAULT_HOSTNAME server_store = self._create_store_with_ws(addr, world_size) + self.assertEqual(server_store.libuvBackend, self._use_libuv) server_store.set("key", "value") port = server_store.port @@ -402,6 +435,7 @@ def test_multi_worker_with_nonfixed_world_size(self): def test_append(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("foo", "po") store.append("foo", "tato") store.append("bar", "po") @@ -411,12 +445,14 @@ def 
test_append(self): def test_multi_set(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.multi_set(["foo", "bar"], ["po", "tato"]) self.assertEqual(b"po", store.get("foo")) self.assertEqual(b"tato", store.get("bar")) def test_multi_get(self): store = self._create_store() + self.assertEqual(store.libuvBackend, self._use_libuv) store.set("foo", "po") store.set("bar", "tato") v0, v1 = store.multi_get(["foo", "bar"]) @@ -429,7 +465,14 @@ def test_store_timeout_on_missing_clients(self): r"Timed out after \d+ seconds waiting for clients. \d+/\d+ clients joined.", ): # world_size is 2 so it should timeout - dist.TCPStore("localhost", 0, 2, True, timeout=timedelta(seconds=2)) + dist.TCPStore( + "localhost", + 0, + 2, + True, + timeout=timedelta(seconds=2), + use_libuv=self._use_libuv, + ) # when wait_for_workers is not set, then there should be no exception raised dist.TCPStore( @@ -439,10 +482,13 @@ def test_store_timeout_on_missing_clients(self): True, timeout=timedelta(seconds=2), wait_for_workers=False, + use_libuv=self._use_libuv, ) class LibUvTCPStoreTest(TCPStoreTest): + _use_libuv = True + def _create_store(self): store = create_tcp_store(use_libuv=True) store.set_timeout(timedelta(seconds=300)) @@ -453,6 +499,33 @@ def _create_store_with_ws(self, addr, world_size): addr, world_size, wait_for_workers=False, use_libuv=True ) + def test_take_over_listen_socket(self): + """ + Override the take_over_listen_socket test in TCPStoreTest. + Reason: we have not thoroughly tested libuv TCPStore initialization using an + open socket, so we decided not to support this use case for now. + TODO (xilunwu): enable this use case + """ + listen_sock: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_sock.bind(("localhost", 0)) + addr, port, *_ = listen_sock.getsockname() + listen_fd = listen_sock.detach() + + err_msg_reg = ( + "^The libuv TCPStore backend does not support " + "initialization with an listen fd" + ) + + with self.assertRaisesRegex(NotImplementedError, err_msg_reg): + store = dist.TCPStore( + addr, + port, + 1, + is_master=True, + master_listen_fd=listen_fd, + use_libuv=self._use_libuv, + ) + class PrefixTCPStoreTest(TestCase, StoreTestBase): def setUp(self): @@ -768,7 +841,7 @@ def test_extended_methods_fallbacks(self): class TestMultiThreadedWait(MultiThreadedTestCase): - # TODO: Use less hacky means of instantiating stores. + # TODO (xilunwu): Use less hacky means of instantiating stores. # Note, stores accumulate values per test. stores = [ dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1), @@ -776,9 +849,9 @@ class TestMultiThreadedWait(MultiThreadedTestCase): dist.PrefixStore( "pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) ), - create_tcp_store(), + create_tcp_store(use_libuv=False), create_tcp_store(use_libuv=True), - dist.PrefixStore("pre", create_tcp_store()), + dist.PrefixStore("pre", create_tcp_store(use_libuv=False)), dist.PrefixStore("pre", create_tcp_store(use_libuv=True)), ] @@ -871,7 +944,12 @@ def handler(a, b): self.assertTrue(rank_res[1], "rank1") -class InitPgWithUvStore(TestCase): +class InitPgWithNonUvStore(TestCase): + """ + This test shows how to use the legacy TCPStore (non-libuv) backend since libuv is now + the default backend.
+ """ + def tearDown(self): super().tearDown() os.environ.pop("USE_LIBUV", None) @@ -884,13 +962,13 @@ def test_with_url_param(self): "gloo", rank=0, world_size=1, - init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=1", + init_method=f"tcp://{DEFAULT_HOSTNAME}:{port}?use_libuv=0", ) self._run_test() def test_with_env_var(self): port = common.find_free_port() - os.environ["USE_LIBUV"] = "1" + os.environ["USE_LIBUV"] = "0" os.environ["MASTER_ADDR"] = DEFAULT_HOSTNAME os.environ["MASTER_PORT"] = str(port) dist.init_process_group("gloo", rank=0, world_size=1, init_method="env://") @@ -904,7 +982,7 @@ def _run_test(self): while isinstance(store, dist.PrefixStore): store = store.underlying_store self.assertTrue(isinstance(store, dist.TCPStore)) - self.assertTrue(store.libuvBackend) + self.assertFalse(store.libuvBackend) dist.destroy_process_group() diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index f94a4a6e5283..ca935ea69bc8 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -135,6 +135,7 @@ def test_onnxrt(self): def test_tvm(self): self._check_backend_works("tvm") self._check_backend_works("tvm", options={"scheduler": None}) + self._check_backend_works("tvm", options={"opt_level": 0}) def test_list_backends(self): self.assertIn("inductor", torch._dynamo.list_backends()) diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index 3bbf7270b06b..0e813c883785 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -8,7 +8,7 @@ import torch import torch._dynamo.test_case from torch._dynamo import bytecode_analysis, bytecode_transformation -from torch._dynamo.testing import skipIfNotPy311 +from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312 class BytecodeTests(torch._dynamo.test_case.TestCase): @@ -414,6 +414,119 @@ def test_remove_dead_code_with_exn_table_entries(self): self.assertEqual(tab[0].end, 4) self.assertEqual(tab[0].target, 6) + def test_bytecode_from_template(self): + def fn(d1): + for k, v in d1.items(): + d2[k] = v + + varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"} + insts = bytecode_transformation.bytecode_from_template(fn, varname_map) + for inst in insts: + self.assertIsNone(inst.starts_line) + if inst.opname.startswith("LOAD"): + self.assertNotIn(inst.argval, varname_map) + if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"): + self.assertIsNone(inst.arg) + self.assertFalse(inst.opname.startswith("RETURN")) + + @skipIfNotPy311 + def test_bytecode_from_template_noprefix(self): + # Test that 3.11+ prefix instructions are removed + def gen_fn(): + cl = None + + def fn(): + return cl + + return fn + + fn = gen_fn() + + dis_insts = list(dis.get_instructions(fn)) + names = {inst.opname for inst in dis_insts} + self.assertIn("RESUME", names) + self.assertIn("COPY_FREE_VARS", names) + + insts = bytecode_transformation.bytecode_from_template(fn) + names = {inst.opname for inst in insts} + self.assertNotIn("RESUME", names) + self.assertNotIn("COPY_FREE_VARS", names) + + def test_bytecode_from_template_noreturn1(self): + # Test that functions with multiple returns will have their + # returns replaced with jumps to the end + def fn(): + if x: + return y + z = 3 + return z + + dis_insts = list(dis.get_instructions(fn)) + dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts)) + self.assertGreater(len(dis_returns), 1) + self.assertTrue(dis_insts[-1].opname.startswith("RETURN")) + + insts = 
bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(len(dis_insts), len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + if i1 is insts[-1]: + continue + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + # Should work with 3.10, but testing with 3.11+ is sufficient. + # In 3.8, `fn` ends with a RETURN_VALUE. + @skipIfNotPy311 + def test_bytecode_from_template_noreturn2(self): + # Test function that doesn't end with RETURN_VALUE + def fn(): + if x: + return x + if x: + return x + raise RuntimeError + + dis_insts = list(dis.get_instructions(fn)) + self.assertFalse(dis_insts[-1].opname.startswith("RETURN")) + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(insts[-2].opname, dis_insts[-1].opname) + self.assertEqual(len(dis_insts) + 1, len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + @skipIfNotPy312 + def test_bytecode_from_template_noreturn_const(self): + # Test 3.12+ RETURN_CONST + def fn(): + if x: + return 1 + return 0 + + dis_insts = list(dis.get_instructions(fn)) + dis_return_consts = list( + filter(lambda x: x.opname == "RETURN_CONST", dis_insts) + ) + self.assertGreater(len(dis_return_consts), 1) + self.assertTrue(dis_insts[-1].opname == "RETURN_CONST") + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + insts_i = 0 + for i, inst in enumerate(dis_insts): + if inst.opname == "RETURN_CONST": + self.assertEqual(insts[insts_i].opname, "LOAD_CONST") + insts_i += 1 + if insts_i != len(insts) - 1: + self.assertIn("JUMP", insts[insts_i].opname) + self.assertIs(insts[insts_i].target, insts[-1]) + insts_i += 1 + class BytecodeHookTests(torch._dynamo.test_case.TestCase): def test_bytecode_hook(self): diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 651c392f5dd2..47f8e8eeb863 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -497,7 +497,7 @@ def forward(self, x): a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") - with torch.cuda.amp.autocast(dtype=torch.torch.float64): + with torch.cuda.amp.autocast(dtype=torch.float64): c_float64 = torch.mm(a_float32, b_float32) return c_float64 @@ -796,7 +796,7 @@ def forward(self, x): self.assertEqual(exported.dtype, real_dtype) self.assertEqual(exported.device.index, 0) - self.assertEqual(exported.dtype, torch.torch.float16) + self.assertEqual(exported.dtype, torch.float16) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_autocast_arguments_binding(self): diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 890edca40ccc..c13fcd31dab7 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -304,13 +304,12 @@ def f3(x): self.assertEqual(cnt.frame_count, 0) def test_torch_guards_stack_frame_register_inlining_disable(self): - y = torch.nn.Parameter(torch.tensor([0.25, 0.25])) x = torch.tensor([0.5, 0.5]) class encoder(torch.nn.Module): def __init__(self, y): super().__init__() - self.register_parameter("param", y) + self.a = y @torch._dynamo.disable def helper(self, x, y): @@ -318,9 +317,9 @@ def helper(self, x, y): def forward(self, a, *args): x = a + a - 
return self.helper(x, self.param) + return self.helper(x, self.a) - e = encoder(y) + e = encoder(2.0) seen_frames = [] import contextlib @@ -465,6 +464,44 @@ def fn(a, b, c): self.assertEqual(cnt.frame_count, 1) + def test_assume_constant_result_on_user_defined_fn(self): + @torch._dynamo.assume_constant_result + def const_fn(n, s): + return torch.full([n], s) + + def fn(B): + B = const_fn(B.size(0), 13) + X = B * 2 + return X.tolist() + + B_list = [8] * 32 + + B = torch.tensor(B_list, dtype=torch.int32) + torch._dynamo.decorators.mark_static(B, 0) + + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + self.assertEqual( + fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) + ) + + def test_assume_constant_result_on_computation_with_graph_input(self): + @torch._dynamo.assume_constant_result + def check(y): + return y[0].item() == 1 + + def fn(x, y): + if check(y): + return x + 2 + else: + return x + 1 + + y = torch.tensor([1]) + x = torch.tensor(1) + + self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 4ceed0fad3dd..a3c63ef66152 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,17 +78,11 @@ def make_dynamic_cls(cls): del test if TEST_Z3: - # this only fails when z3 is available - unittest.expectedFailure( - # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: https://github.com/sympy/sympy/issues/25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 - ) - - # TODO model is somehow not being freed when z3 is available - unittest.expectedFailure( - DynamicShapesMiscTests.test_parameter_free_dynamic_shapes # noqa: F821 - ) + if not config.inline_inbuilt_nn_modules: + # TODO model is somehow not being freed when z3 is available + unittest.expectedFailure( + DynamicShapesMiscTests.test_parameter_free_dynamic_shapes # noqa: F821 + ) unittest.expectedFailure( # Test is only valid without dynamic shapes @@ -99,6 +93,13 @@ def make_dynamic_cls(cls): DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest( # noqa: F821 DynamicShapesExportTests.test_retracibility_dynamic_shapes # noqa: F821 ) +# Also take more than 30m as of 15cc9f2e7e7b2b175f24755925dc38d4d430905d +DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes = slowTest( # noqa: F821 + DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes # noqa: F821 +) +DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes = slowTest( # noqa: F821 + DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes # noqa: F821 +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py new file mode 100644 index 000000000000..1cf31f9edc36 --- /dev/null +++ b/test/dynamo/test_exceptions.py @@ -0,0 +1,234 @@ +# Owner(s): ["module: dynamo"] + +import torch +import torch._dynamo.config + +import torch._dynamo.test_case +import torch._functorch.config +import torch.utils.checkpoint + + +class ExceptionTests(torch._dynamo.test_case.TestCase): + def test_exception(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError + except Exception: + x = torch.sigmoid(x) + + return x 
+ + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception2(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError + except (NotImplementedError, AttributeError) as e: + x = torch.sigmoid(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception3(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except AssertionError: + x = torch.sigmoid(x) + except NotImplementedError: + x = torch.cos(x) + finally: + x = torch.cos(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_with_another_exception(self): + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + try: + x = torch.cos(x) + raise AssertionError + except AssertionError: + x = torch.cos(x) + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_else(self): + def gn(x): + return torch.cos(x) + + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + x = gn(x) + except Exception: + x = torch.sigmoid(x) + else: + x = torch.cos(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + # TODO(anijain2305) - does not work with fullgraph=True + def test_exception_with_another_exception2(self): + def gn(x): + try: + x = torch.cos(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + raise + + def fn(x): + try: + x = torch.cos(x) + gn(x) + except Exception: + pass + return x + + x = torch.randn(4) + ref = fn(x) + # Can't use fullgraph=True because RERAISE is not supported + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn(x) + + # TODO(anijain2305) - does not work with fullgraph=True + def test_exception_with_ctx_manager(self): + def fn(x): + x = torch.cos(x) + try: + with torch.no_grad(): + x = torch.sin(x) + raise NotImplementedError("Not implemented") + except NotImplementedError as e: + x = torch.sigmoid(x) + return x + + x = torch.randn(4) + ref = fn(x) + # Can't use fullgraph=True because WITH_EXCEPT_START is not supported + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_raised_from_child(self): + def gn(): + raise NotImplementedError("foo") + + def fn(x): + x = torch.cos(x) + try: + x = torch.sin(x) + gn() + x = torch.sin(x) + except Exception: + x = torch.sigmoid(x) + + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_nn_module_getattr(self): + class A: + def __init__(self): + self._b = 20 + + def __getattr__(self, name): + fixed_name = "_" + name + if fixed_name in self.__dict__: + return self.__dict__[fixed_name] + raise AttributeError(f"{name} absent") + + class B(A): + def __init__(self): + self.a = 10 + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: +
return 30 + + obj = B() + + def fn(x): + return x * obj.a * obj.b * obj.c + + x = torch.ones(4) + ref = fn(x) + print(ref) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) + def test_custom_getattr_on_module_exception(self): + class Foo(torch.nn.Module): + def __init__(self, a=3): + super().__init__() + self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2)) + + def __getattr__(self, name): + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "a_copy": + return self.a + raise + + def forward(self, x): + return x * self.a * self.a_copy + + mod = Foo() + opt_mod = torch.compile(mod, backend="eager", fullgraph=True) + + x = torch.ones(4) + self.assertEqual(mod(x), opt_mod(x)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9f1417e23247..dbf983faabb7 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -1509,6 +1509,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: graph, guards = torch._dynamo.export(model)(inp) self.assertEqual(model(inp), graph(inp)) + def test_export_with_constant_in_unspecialized_nn_module(self): + class Module(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.y = y + + @torch._dynamo.assume_constant_result + def check(self): + return self.y[0].item() == 1 + + def forward(self, x): + # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo + self.device = x.device + + if self.check(): + return x + 1 + else: + return x + 2 + + model = Module(torch.tensor([1])) + inp = torch.ones(3, 4) + graph, _ = torch._dynamo.export(model)(inp) + self.assertEqual(model(inp), graph(inp)) + def test_export_decomp(self): def f(x): return x.t() + x.t() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 472e9c56bae6..e2baebf60321 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1614,6 +1614,11 @@ def test_ndarray_builtin_functions(x): def test_numpy_dtype_argument_to_function(x): return np.ones_like(x, dtype=np.float64) + @make_test + def test_numpy_dtype_call_in_function(x): + dt = np.dtype("float") + return np.full_like(x, 2.4, dtype=dt) + @make_test def test_numpy_linalg(x): return np.linalg.norm(x.numpy(), axis=0) @@ -2192,6 +2197,84 @@ def inner(): self.assertTrue(same(program(input1, input2), input1 + input1)) + @parametrize("int_or_float", ("int", "float")) + def test_np_constant_collections_as_input(self, int_or_float): + info_func = getattr(np, f"{int_or_float[0]}info") + dt_string_arg = f"{int_or_float}16" + np_dt_attr = getattr(np, dt_string_arg) + + dt_args = [dt_string_arg, np_dt_attr] + arg_variants_iter = itertools.chain( + dt_args, map(np.dtype, dt_args), map(info_func, dt_args) + ) + + def func(a, b, info_or_dt): + return a + info_func(info_or_dt).max + + opt_fn = torch.compile(func) + + a = torch.randn(2) + b = torch.randn(2) + eager_result = func(a, b, dt_args[0]) + + for arg in arg_variants_iter: + opt_result = opt_fn(a, b, arg) + self.assertTrue(same(opt_result, eager_result)) + + @parametrize( + "typ, info_func", + [ + (int, np.iinfo), + (float, np.finfo), + ], + name_fn=lambda t, _: t.__name__, + ) + def test_np_constant_collections_guards(self, typ, info_func): + def func_info(a, info): + return a + info.max + + 
def func_dtype(a, dt): + return a + info_func(dt).max + + dt_args = [ + np.dtype(typ), + np.ones((1,), dtype=typ).dtype, + np.dtype(np.dtype(typ).name), + np.dtype(typ.__name__), + ] + cnts_1 = torch._dynamo.testing.CompileCounter() + opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype) + a = torch.zeros(3, dtype=typ) + for arg in dt_args: + r = opt_fn_dtype(a, arg) + # each should produce an identical arg + self.assertEqual(cnts_1.frame_count, 1) + + cnts_2 = torch._dynamo.testing.CompileCounter() + opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info) + info_args = [info_func(dt) for dt in dt_args] + for arg in info_args: + r = opt_fn_info(a, arg) + + # each should produce an identical arg + self.assertEqual(cnts_2.frame_count, 1) + + if typ is float: + dt_extra = np.dtype(np.float16) + else: + dt_extra = np.dtype(np.int16) + info_extra = info_func(dt_extra) + + eager_result_dtype = func_dtype(a, dt_extra) + compile_result_dtype = opt_fn_dtype(a, dt_extra) + self.assertEqual(cnts_1.frame_count, 2) + self.assertEqual(eager_result_dtype, compile_result_dtype) + + eager_result_info = func_info(a, info_extra) + compile_result_info = opt_fn_info(a, info_extra) + self.assertEqual(cnts_2.frame_count, 2) + self.assertEqual(eager_result_info, compile_result_info) + def test_compare_constant_and_tensor(self): for op in [ operator.lt, @@ -2229,6 +2312,24 @@ def test(x, y): test(True, False) test(torch.ones(4, dtype=torch.float32), 1.1) + def test_index(self): + def fn(x, t): + v = operator.index(x) + torch.mul(t, v) + + def test(a, b): + self.assertEqual(opt_fn(a, b), fn(a, b)) + + for dynamic in [True, False]: + torch._dynamo.reset() + opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn) + t = torch.ones(1) + test(10, t) + test(-100, t) + test(10, t) + test(False, t) + test(True, t) + def test_truth(self): def fn(x, y): return operator.truth(x) and bool(y) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 9b86a90b02f3..c934cf55e8f5 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2746,6 +2746,26 @@ def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm + def test_hessian_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.hessian(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_hessian(self): counters.clear() @@ -2880,6 +2900,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_hessian_argnums(self): counters.clear() @@ -3032,6 +3053,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """ return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""", ) + @unittest.expectedFailure def test_hessian_disable_capture(self): counters.clear() @@ -3058,6 +3080,26 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) + def test_jacrev_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.jacrev(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = 
torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_jacrev(self): counters.clear() @@ -3134,6 +3176,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_jacrev_two_tensors_argnums(self): counters.clear() @@ -3216,6 +3259,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacrev_has_aux(self): counters.clear() @@ -3300,6 +3344,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacrev_disable_capture(self): counters.clear() @@ -4246,6 +4291,26 @@ def wrapper_fn(x, y): self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) + def test_jacfwd_graph_break(self): + counters.clear() + + def wrapper_fn(x): + return torch.func.jacfwd(torch.sin)(x) + + x = torch.randn(4, 3) + expected = wrapper_fn(x) + got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) + self.assertEqual(expected, got) + self.assertEqual(len(counters["graph_break"]), 2) + self.assertEqual( + { + "'skip function disable in file _dynamo/decorators.py'": 1, + "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, + }, + {munge_exc(k): v for k, v in counters["graph_break"].items()}, + ) + + @unittest.expectedFailure def test_jacfwd(self): counters.clear() @@ -4329,6 +4394,7 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) + @unittest.expectedFailure def test_jacfwd_two_tensors_argnums(self): counters.clear() @@ -4418,6 +4484,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_has_aux(self): counters.clear() @@ -4512,6 +4579,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_randomness(self): counters.clear() @@ -4615,6 +4683,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) + @unittest.expectedFailure def test_jacfwd_disable_capture(self): counters.clear() @@ -5118,10 +5187,10 @@ def wrapper_fn(x): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): - l_self_tensor_constant0 = L_self_tensor_constant0 + def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): + l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ - alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None + alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) @@ -5140,16 +5209,16 @@ def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): actual, """\ class GraphModule(torch.nn.Module): - def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): - getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ - getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ + def forward(self, 
L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): + l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ + l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ l_flat_tangents_1_ = L_flat_tangents_1_ - _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None + _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None - mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None + mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None return (mul_tensor,) """, ) diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py new file mode 100644 index 000000000000..f7ba32bc15f3 --- /dev/null +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -0,0 +1,62 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo import config +from torch._dynamo.testing import make_test_cls_with_patches + +try: + from . 
import ( + test_aot_autograd, + test_functions, + test_higher_order_ops, + test_misc, + test_modules, + # test_repros, + ) +except ImportError: + import test_aot_autograd + import test_functions + import test_higher_order_ops + import test_misc + import test_modules + + +test_classes = {} + + +def make_inline_inbuilt_nn_modules_cls(cls): + suffix = "_inline_inbuilt_nn_modules" + + cls_prefix = "InlineInbuiltNNModules" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "inline_inbuilt_nn_modules", True), + xfail_prop="_expected_failure_inline_inbuilt_nn_modules", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +tests = [ + test_misc.MiscTests, + test_functions.FunctionTests, + test_modules.NNModuleTests, + test_higher_order_ops.HigherOrderOpTests, + test_higher_order_ops.FuncTorchHigherOrderOpTests, + test_aot_autograd.AotAutogradFallbackTests, + # test_repros.ReproTests, +] +for test in tests: + make_inline_inbuilt_nn_modules_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 6412d015d56b..9014be6f7557 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -3,6 +3,7 @@ import torch._dynamo from torch._dynamo.test_minifier_common import MinifierTestBase +from torch.testing._internal.common_utils import skipIfNNModuleInlined requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") @@ -111,6 +112,7 @@ def test_after_dynamo_cuda_accuracy_backend_passes(self): ) # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @skipIfNNModuleInlined() @requires_cuda def test_cpu_cuda_module_after_dynamo(self): backend_name = "relu_compile_error_TESTING_ONLY" diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 3ed06a55c837..02f7c68aa1a9 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -33,6 +33,7 @@ import torch.onnx.operators import torch.utils._pytree as pytree +import torch.utils.cpp_extension from torch import Tensor from torch._C import FileCheck from torch._dynamo import allow_in_graph @@ -223,6 +224,38 @@ def fn(x): with self.assertRaises(TypeError): fn(torch.randn(16)) + def test_cpp_extension_recommends_custom_ops(self): + cpp_source = """ + #include + at::Tensor foobar(const at::Tensor& x) { + return x.clone(); + } + """ + module = torch.utils.cpp_extension.load_inline( + name="mylib", + cpp_sources=cpp_source, + functions="foobar", + verbose=True, + ) + + x = torch.ones(2, 2, requires_grad=True) + counters.clear() + + @torch.compile(backend="eager") + def f(x): + return module.foobar(x) + + with self.assertWarnsOnceRegex( + UserWarning, ".*https://pytorch.org/docs/main/notes/custom_operators.html.*" + ): + f(x) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = list(counters["graph_break"].keys())[0] + self.assertExpectedInline( + first_graph_break, + """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. 
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/docs/main/notes/custom_operators.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""", + ) + def test_callpacked(self): def call_packed(args): a, b, c = args @@ -665,9 +698,9 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", ) @@ -724,11 +757,11 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None - getitem_4: "f32[3]" = foo_default[0] - getitem_5: "f32[3]" = foo_default[1]; foo_default = None + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", ) @@ -816,9 +849,9 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg0_1 = arg1_1 = None + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = None return ()""", ) @@ -4231,6 +4264,59 @@ def fn_has_breaks(x): opt_fn(x) self.assertEqual(cnts.frame_count, 2) + def test_id_guarded_object(self): + class UDO: + @torch.compile(backend="eager") + def call(self, x, ref_id): + self_id = id(self) + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + # Make sure we do recompile when id(self) is executed on + # different self objects. + x = torch.ones(2) + obj1 = UDO() + obj1_id = id(obj1) + self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) + + obj2 = UDO() + # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. + self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) + + def test_id_guarded_module(self): + class M(torch.nn.Module): + def forward(self, x, ref_id): + self_id = id(self) + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + cnts = torch._dynamo.testing.CompileCounter() + + # Make sure we do recompile when id(self) is executed on + # different self objects. 
+ x = torch.ones(2) + m1 = M() + m1_id = id(m1) + opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1) + self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) + self.assertEqual(opt_m1(x, m1_id), torch.ones(2)) + + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 1) + + m2 = M() + opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2) + # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. + self.assertEqual(opt_m2(x, m1_id), torch.zeros(2)) + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.op_count, 2) + def test_id_of_nn_module(self): class M(torch.nn.Module): def forward(self, x, ref_id): @@ -6754,7 +6840,7 @@ def fn(): x += 1 return x - opt_fn = torch._dynamo.optimize("eager")(fn) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) self.assertEqual(opt_fn(), torch.tensor([2.0])) def test_nested_sequential_with(self): @@ -9195,8 +9281,8 @@ def test_shape_env_equal_constructor(self): ShapeEnv not equal: field values don't match: ==> settings: values don't match. - > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False) - > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False) + > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) + > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False) """, ) self._replay_and_check(main) @@ -9223,7 +9309,7 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9257,7 +9343,7 @@ def test_shape_env_equal_unbacked(self): > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} + > Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} > Right: {} """, ) @@ -9334,8 +9420,8 @@ def test_shape_env_equal_evaluate_expr_replacement(self): > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. 
- > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9372,8 +9458,8 @@ def test_shape_env_equal_evaluate_expr_refinement(self): > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} + > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} """, ) self._replay_and_check(main) @@ -9398,10 +9484,7 @@ def test_shape_env_equal_runtime_assert(self): ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(Mod(u0, 3), 0)]} - > Right: {} -==> divisible: values don't match. - > Left: {Mod(u0, 3)} + > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} @@ -10321,6 +10404,24 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_module_dunder_dict(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.foo = 1 + self.bar = 2 + self.baz = 3 + + def forward(self, x): + if "foo" in self.__dict__: + return x * self.bar + return x * self.baz + + mod = MyModule() + x = torch.randn(10) + opt_mod = torch.compile(mod, backend="eager", fullgraph=True) + self.assertEqual(mod(x), opt_mod(x)) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index b2c1581d7e86..e6a6fc6dab58 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -101,6 +101,15 @@ def fn(obj: BaseModelOutput): self._common(fn, 2) + @maybe_skip + def test_mo_getattr_missing(self): + def fn(obj: BaseModelOutput): + if getattr(obj, "asdf", None) is not None: + obj.asdf += 1 + return obj.attentions + 1 + + self._common(fn, 1) + @maybe_skip def test_mo_getitem(self): def fn(obj: BaseModelOutput): @@ -166,6 +175,59 @@ def fn(obj): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + @maybe_skip + def test_mo_init2(self): + # this ModelOutput subclass runs a different __post_init__ codepath + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + def fn(x): + obj = MyDataClass(x=x) + return obj + + inp = torch.randn(3, 3) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + @maybe_skip + def test_mo_init_with_disable(self): + # Can result in "non-function or method super: " + # graph breaks (although it may not be the first) + # Minimal repro for https://github.com/pytorch/pytorch/issues/126028 + @dataclasses.dataclass + class MyDataClass(ModelOutput): + x: torch.FloatTensor = None + + @torch._dynamo.disable(recursive=False) + 
def fn(x): + return MyDataClass(x=x) + + inp = torch.randn(3, 3) + opt_fn = torch._dynamo.optimize("eager")(fn) + self.assertEqual(fn(inp).x, opt_fn(inp).x) + + @maybe_skip + def test_mo_newkey(self): + obj = BaseModelOutput() + + def fn(obj): + return obj["wwww"] + 1 + + inp = torch.randn(3, 3) + obj["wwww"] = inp + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(obj), opt_fn(obj)) + + @maybe_skip + def test_mo_from_outside(self): + def fn(obj): + return obj.attentions + 1 + + obj = BaseModelOutput(attentions=torch.randn(3, 3)) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + self.assertEqual(fn(obj), opt_fn(obj)) + @maybe_skip def test_HF_bert_model_output(self): class BertPooler(torch.nn.Module): diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index b22f02ee2fcc..dbfef8af4386 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3,6 +3,8 @@ import collections import copy import itertools +import os +import tempfile import traceback import types import unittest @@ -16,10 +18,10 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch.nn.functional as F +from torch._dynamo.debug_utils import same_two_models from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.testing import expectedFailureDynamic, same -from torch._dynamo.utils import ifdynstaticdefault from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import Parameter, UninitializedParameter @@ -1105,37 +1107,6 @@ def forward(self, x): return self.m(x) -class ModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(4, 4) - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - -class UnspecInlinableModule(torch.nn.Module): - torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule - - def forward(self, x): - return torch.sin(x) - - -class UnspecModuleWithIntAttr(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = UnspecInlinableModule() - self.step = 10 - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + 1 - self.step += 1 - return self.layer(x) + self.step - - def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( @@ -1389,31 +1360,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) - def test_nn_module_unspec_int_attr(self): - for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]: - mod = module_class() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod)) - x = torch.randn(3, 4) - - # Compiling self.step as static. - ref1 = mod(x) - res1 = opt_mod(x) - self.assertTrue(torch.allclose(ref1, res1)) - self.assertEqual(cnt.frame_count, 1) - - # Compiling self.step as dynamic. - ref2 = mod(x) - res2 = opt_mod(x) - self.assertTrue(torch.allclose(ref2, res2)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - - # No re-compilation! 
- ref3 = mod(x) - res3 = opt_mod(x) - self.assertTrue(torch.allclose(ref3, res3)) - self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) - # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module1(self): @@ -2739,6 +2685,49 @@ def fn(x): self.assertEqual(test_functions._variable, 1) self.assertEqual(res, 3 * torch.ones(10)) + @unittest.skipIf( + "inductor" not in torch._dynamo.list_backends(), + "inductor backend is not available", + ) + def test_save_and_load_inductor(self): + mod = MockModule() + opt_mod = torch.compile(mod, backend="inductor") + inp = torch.randn(10, 10) + opt_mod(inp) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + loaded_model(inp) + self.assertTrue(same_two_models(loaded_model, mod, [inp])) + self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) + + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + loaded_model(inp) + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) + + def test_save_and_load_all_backends(self): + mod = MockModule() + inp = torch.randn(10, 10) + for backend in torch._dynamo.list_backends(): + try: + opt_mod = torch.compile(mod, backend=backend) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + opt_mod(inp) + opt_success = torch._inductor.metrics.generated_kernel_count == 0 + torch._dynamo.reset() # force recompiles + torch._inductor.metrics.generated_kernel_count = 0 + loaded_model(inp) + loaded_success = torch._inductor.metrics.generated_kernel_count == 0 + self.assertEqual(opt_success, loaded_success) + except torch._dynamo.exc.BackendCompilerFailed: + pass + def test_monkeypatching_forward(self): class FakeModule(torch.nn.Module): def forward(self, x): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index ae317a78d96f..8dd1b91f43f7 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -2,6 +2,7 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg) """ + # Owner(s): ["module: dynamo"] import collections import contextlib @@ -31,6 +32,7 @@ import torch._functorch.config import torch.library +import torch.utils._pytree as pytree from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import CompileCounter, rand_strided, same @@ -1150,13 +1152,12 @@ def test_reformer_eval(self): def test_reformer_train(self): with torch.enable_grad(): cnt = self._reformer(nopython=False) - # cant inline torch.autograd.Function means graph break - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline(cnt.frame_count, """1""") - self.assertExpectedInline(cnt.op_count, """5""") - else: - self.assertExpectedInline(cnt.frame_count, """1""") - self.assertExpectedInline(cnt.op_count, """5""") + expected_op_count = ( + """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5""" + ) + + self.assertExpectedInline(cnt.frame_count, """1""") + self.assertExpectedInline(cnt.op_count, expected_op_count) @disable_translation_validation_if_dynamic_shapes def 
test_longformer_chunk(self): @@ -1609,7 +1610,10 @@ def test_issue175(self): opt_model(inp) opt_model(inp) self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 12) + + self.assertEqual( + 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + ) def test_exec_import(self): def fn1(): @@ -3051,7 +3055,7 @@ def f(x): with self.assertRaisesRegex(AssertionError, "torch.Size"): opt_f(args) self.assertEqual( - torch._dynamo.utils.counters["unimplemented"][ + torch._dynamo.utils.counters["graph_break"][ "assert with non-string message" ], 1, @@ -5026,6 +5030,78 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) + def test_hasattr_builtin(self): + class MyClass: + foo: int = 1 + + def func(x, m): + if getattr(type(m), "foo", 0): + return x + MyClass.foo + return x + + opt_func = torch.compile(func, backend="eager", fullgraph=True) + m = MyClass() + x = torch.zeros(()) + self.assertEqual(func(x, m), opt_func(x, m)) + self.assertEqual(func(x, 0), opt_func(x, 0)) + + def test_grad(self): + def fn(x, y): + x._grad = y + return x.grad.data + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, y), opt_fn(x, y)) + + def test_nn_module_stack_bc(self): + from torch._dynamo.mutation_guard import GenerationTracker + + def compiler(gm, *args): + module_stacks = [ + node.meta.get("nn_module_stack", None) for node in gm.graph.nodes + ] + module_stacks, _ = pytree.tree_flatten(module_stacks) + module_stacks = [x for x in module_stacks if isinstance(x, str)] + for stack in module_stacks: + self.assertTrue("_module" not in stack) + return gm.forward + + class SubMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod1 = SubMod() + self.submod2 = SubMod() + + def forward(self, x): + return self.submod1(x) + self.submod2(x) + + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + + with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + + # an example similar to Pippy usecase + mod = Mod() + GenerationTracker.tag(mod.submod1) + GenerationTracker.mark_class_dynamic(type(mod.submod1)) + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index 43fe1ba1ece4..3ced7859cd7e 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -170,6 +170,8 @@ def test_do_not_skip_side_effects(self): global _variable, _variable_2 for mode in range(1, 7): + torch._dynamo.reset() + _variable = 0 _variable_2 = 0 diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index ea44a5e0771d..c27118c74fde 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -77,6 +77,14 @@ def format(self, record): metadata["stack"] = "STACK" if "compilation_metrics" in metadata: metadata["compilation_metrics"] = "METRICS" + if "describe_storage" in metadata: + metadata["describe_storage"]["describer_id"] = "ID" + if "describe_tensor" in metadata: + 
metadata["describe_tensor"]["describer_id"] = "ID" + if "view_func" in metadata["describe_tensor"]: + metadata["describe_tensor"]["view_func"] = "VIEW_FUNC" + if "describe_source" in metadata: + metadata["describe_source"]["describer_id"] = "ID" return json.dumps(metadata) @@ -136,6 +144,9 @@ def test_schedule(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -157,6 +168,9 @@ def test_cudagraphs(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -182,7 +196,13 @@ def fn(x, y): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "l_y_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": 
"torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -191,6 +211,9 @@ def fn(x, y): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -211,6 +234,9 @@ def test_example_fn(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -234,6 +260,9 @@ def test_dynamo_error(self): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, 
"attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) @@ -263,6 +292,9 @@ def throw(x): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -310,10 +342,16 @@ def forward(self, x): {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1024, 1024], "l__self___layers_0": [1024, 1024], "l__self___layers_1": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 
0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -350,6 +388,9 @@ def fn(x): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -379,11 +420,23 @@ def fn(a, b): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 800}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 20], "is_leaf": true, "stride": [20, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 2400}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} 
+{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": ["s0", "s1"], "l_b_": ["s1", "s3"], "matmul": ["s0", "s3"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -414,11 +467,17 @@ def inner(x, ys, zs): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 96887da09ea3..1bb571ccd0e3 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -433,6 +433,43 @@ def fn(x): res = fn(input) 
self.assertIsInstance(res, LocalSubclass) + def test_torch_function_list_args(self): + HANDLED_FUNCTIONS = {} + + class MyClass: + def __init__(self, foo): + self.foo = foo + + @classmethod + def __torch_function__( + cls, + func, + types, + args=(), + kwargs=None, + ): + if kwargs is None: + kwargs = {} + if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 + [ # noqa: C419 + issubclass(t, (torch.Tensor, MyClass)) for t in types + ] + ): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + def _stack(input, dim=0, *, out=None): + return MyClass(sum([x.foo for x in input])) + + HANDLED_FUNCTIONS[torch.stack] = _stack + + @torch.compile(backend="eager", fullgraph=True) + def fn(v0, v1): + return torch.stack([v0, v1]) + + ret = fn(MyClass(1), MyClass(1)) + self.assertEqual(ret.foo, 2) + @parametrize( "comparison", [ diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 317fd15195ba..d5fdc006949e 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -332,7 +332,6 @@ def fn(x): # specialization is allowed) opt_fn(x) - @unittest.expectedFailure def test_conv1d_symint_padding(self): kernel = torch.randn(1, 1, 4) @@ -341,7 +340,6 @@ def func(x): out = F.conv1d(x, kernel, padding=padding, stride=2) return out - # TODO: NameError: name 's1' is not defined when dynamic=True opt_func = torch.compile(func) x = torch.randn(1, 1, 175) diff --git a/test/dynamo/test_view.py b/test/dynamo/test_view.py new file mode 100644 index 000000000000..2d63e86af162 --- /dev/null +++ b/test/dynamo/test_view.py @@ -0,0 +1,41 @@ +# Owner(s): ["module: dynamo"] +import torch + +import torch._dynamo +import torch._dynamo.test_case + + +@torch._dynamo.config.patch("capture_scalar_outputs", True) +class ViewTests(torch._dynamo.test_case.TestCase): + def test_view_to_2d(self): + @torch.compile(fullgraph=True, backend="eager") + def f(t, _u0): + u0 = t[0].item() + u1 = t[1].item() + torch._check_is_size(u0) + torch._check_is_size(u1) + n = u0 * u1 + a = torch.randn(n) + return a.view(-1, _u0) + + t = torch.tensor([2, 4], dtype=torch.int32) + f(t, 2) + + def test_view_to_1d(self): + @torch.compile(fullgraph=True, backend="eager") + def f(t, _n): + u0 = t[0].item() + u1 = t[1].item() + torch._check_is_size(u0) + torch._check_is_size(u1) + a = torch.randn(u0, u1) + return a.view(_n) + + t = torch.tensor([2, 4], dtype=torch.int32) + f(t, 8) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_max_norm deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_basic deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git 
a/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor b/test/dynamo_expected_failures/TestEmbeddingNN.test_embedding_sparse_empty_tensor deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset b/test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report b/test/dynamo_expected_failures/TestExperimentalUtils.test_profiler_pattern_matcher_json_report deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype0 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype0 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype1 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype1 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype2 b/test/dynamo_expected_failures/TestFromBuffer.test_basic_little_dtype2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check b/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git 
a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim b/test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterDict b/test/dynamo_expected_failures/TestNN.test_ParameterDict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_ParameterList b/test/dynamo_expected_failures/TestNN.test_ParameterList deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_Sequential_iadd b/test/dynamo_expected_failures/TestNN.test_Sequential_iadd deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting b/test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag b/test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR b/test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR deleted file mode 100644 index 
e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_linear_broadcasting b/test/dynamo_expected_failures/TestNN.test_linear_broadcasting deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/benchmarks/operator_benchmark/c2/__init__.py b/test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion similarity index 100% rename from benchmarks/operator_benchmark/c2/__init__.py rename to test/dynamo_expected_failures/TestNN.test_overwrite_module_params_on_conversion diff --git a/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False b/test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization_swap_False deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/caffe2/quantization/__init__.py b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True similarity index 100% rename from caffe2/quantization/__init__.py rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_forward_swap_True diff --git a/test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta b/test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True similarity index 100% rename from test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta rename to test/dynamo_expected_failures/TestNNParametrization.test_new_spectral_norm_swap_True diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu b/test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu b/test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning b/test/dynamo_expected_failures/TestPruningNN.test_identity_pruning deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu b/test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency similarity index 100% rename from test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu rename to test/dynamo_expected_failures/TestPruningNN.test_pruning_id_consistency diff --git a/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc b/test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F b/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F deleted file mode 100644 index 
e69de29bb2d1..000000000000 diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 669c3d91e849..eeee3685e1fb 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -248,6 +248,8 @@ aten::_foreach_log2 aten::_foreach_log2.out aten::_foreach_log2_ aten::_foreach_log_ +aten::_foreach_max +aten::_foreach_max.out aten::_foreach_maximum.List aten::_foreach_maximum.List_out aten::_foreach_maximum.Scalar @@ -386,6 +388,7 @@ aten::_int_mm aten::_int_mm.out aten::_is_all_true aten::_is_any_true +aten::_jagged_to_padded_dense_forward aten::_lazy_clone aten::_linalg_check_errors aten::_linalg_det @@ -475,6 +478,7 @@ aten::_nnpack_spatial_convolution.out aten::_nnz aten::_pack_padded_sequence aten::_pack_padded_sequence.out +aten::_padded_dense_to_jagged_forward aten::_pdist_backward aten::_pdist_backward.out aten::_pdist_forward @@ -643,8 +647,6 @@ aten::adaptive_max_pool3d_backward.grad_input aten::addbmm aten::addbmm.out aten::addr_ -aten::alias_copy -aten::alias_copy.out aten::allclose aten::angle aten::angle.out diff --git a/test/export/test_converter.py b/test/export/test_converter.py index ab6c3b802418..362e0a6b2ba3 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,25 +1,575 @@ # Owner(s): ["oncall: export"] +import unittest +from typing import Dict, Tuple + import torch + +import torch.utils._pytree as pytree + from torch._dynamo.test_case import TestCase from torch._export.converter import TS2EPConverter - +from torch.export import ExportedProgram from torch.testing._internal.common_utils import run_tests +requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") + class TestConverter(TestCase): + def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram: + ts_model = torch.jit.script(mod) + ep = TS2EPConverter(ts_model, inp).convert() + ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) + orig_out, _ = pytree.tree_flatten(mod(*inp)) + + # Check module. + if isinstance(mod, torch.nn.Module): + self.assertEqual( + ep.module().state_dict().keys(), + mod.state_dict().keys(), + ) + + # Check results. + self.assertEqual(len(ep_out), len(orig_out)) + for ep_t, orig_t in zip(ep_out, orig_out): + if isinstance(ep_t, torch.Tensor): + self.assertEqual(ep_t.shape, orig_t.shape) + self.assertTrue(torch.allclose(ep_t, orig_t)) + else: + self.assertEqual(ep_t, orig_t) + return ep + def test_ts2ep_converter_basic(self): - class Module(torch.nn.Module): + class MSingle(torch.nn.Module): def forward(self, x, y): return x + y - m = Module() + class MMulti(torch.nn.Module): + def forward(self, x, y): + x = x.cos() + 1 + y = y.sin() - 1 + return x, y + inp = (torch.ones(1, 3), torch.ones(1, 3)) + self._check_equal_ts_ep_converter(MSingle(), inp) + self._check_equal_ts_ep_converter(MMulti(), inp) - ts_model = torch.jit.script(m) - ep = TS2EPConverter(ts_model, inp).convert() + def test_ts2ep_converter_container_output(self): + # Output is a List. + class MOutputList(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return [a, b] + + # Output is a Tuple. + class MOutputTuple(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return (a, b) + + # Output is a Dict. 
+ class MOutputDict(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x * x + b = y + y + return {"data": {"mul": a, "add": b}} + + inp = (torch.tensor(4), torch.tensor(4)) + + self._check_equal_ts_ep_converter(MOutputList(), inp) + self._check_equal_ts_ep_converter(MOutputTuple(), inp) + self._check_equal_ts_ep_converter(MOutputDict(), inp) + + def test_aten_dim(self): + class Module(torch.nn.Module): + def forward(self, x): + num_dim = x.dim() + return torch.ones(num_dim) + + inp = (torch.ones(1, 3),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten_len(self): + class Module(torch.nn.Module): + def forward(self, x): + length = len(x) + return torch.ones(length) + + inp = (torch.ones(2, 3),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___getitem___list(self): + class Module(torch.nn.Module): + def forward(self, x): + y = torch.split(x, 2) + return y[0] + + inp = (torch.rand((3, 2)),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___getitem___dict(self): + class Module(torch.nn.Module): + def forward(self, x): + y = torch.split(x, 2) + d_int = {0: y[0], 1: y[1]} + d_str = {"0": y[0], "1": y[1]} + d_bool = {True: y[0], False: y[1]} + d_float = {0.1: y[0], 2.3: y[1]} + return d_int[0], d_str["0"], d_bool[True], d_float[0.1] + + inp = (torch.rand((3, 2)),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_device(self): + class Module(torch.nn.Module): + def forward(self, x): + device = x.device + return torch.ones(2, 3, device=device) + + inp = (torch.rand(3, 4),) + self._check_equal_ts_ep_converter(Module(), inp) + + @requires_cuda + def test_prim_device_cuda(self): + class Module(torch.nn.Module): + def forward(self, x): + device = x.device + return torch.ones(2, 3, device=device) + + inp = (torch.rand((3, 4), device="cuda:0"),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_dtype(self): + class Module(torch.nn.Module): + def forward(self, x): + dtype = x.dtype + return torch.ones(2, 3, dtype=dtype) + + for dtype in [ + torch.float32, + torch.double, + ]: + inp = (torch.rand((3, 4), dtype=dtype),) + self._check_equal_ts_ep_converter(Module(), inp) + + for dtype in [ + torch.uint8, + torch.int8, + torch.int32, + ]: + inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_convert_if_basic(self): + class M(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + if x: + return y * y + else: + return y + y + + inp = (torch.tensor(True), torch.tensor(4)) + ep = self._check_equal_ts_ep_converter(M(), inp) + + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) + + def test_convert_if_multiple_out(self): + class M(torch.nn.Module): + def true_fn(self, y, z): + return (z * z, z + z) + + def false_fn(self, y, z): + return (y * y * y, y + y) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + z = y * y + + if x: + res = self.true_fn(y, z) + else: + res = self.false_fn(y, z) + + return res[0] + res[1] + + inp = (torch.tensor(True), torch.tensor(4)) + ep = self._check_equal_ts_ep_converter(M(), inp) + + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) + + def test_profiler__record_function(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + handle = 
torch.ops.profiler._record_function_enter_new("foo", None) + y = x * 2 + 4 + torch.ops.profiler._record_function_exit(handle) + return y + + x = torch.randn(10, 10) + self._check_equal_ts_ep_converter(Module(), (x,)) + + def test_aten_floordiv(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x // 2 + + x = torch.randn(10, 10) + self._check_equal_ts_ep_converter(Module(), (x,)) + + def test_aten___is__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return x is y, z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___isnot__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return x is not y, z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_aten___not__(self): + class Module(torch.nn.Module): + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[bool, torch.Tensor]: + z = x + 1 + return not (x is not y), z + + inp = (torch.randn(10, 10), torch.rand(10, 10)) + self._check_equal_ts_ep_converter(Module(), inp) + + def test_ts2ep_converter_unpack(self): + class MUnpackList(torch.nn.Module): + def forward(self, x): + x, y = torch.split(x, 2) + return x + y + + class MUnpackTuple(torch.nn.Module): + def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): + x, y = x_tuple + x = x.cos() + return x + y + + inp = (torch.ones(4),) + self._check_equal_ts_ep_converter(MUnpackList(), inp) + inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) + self._check_equal_ts_ep_converter(MUnpackTuple(), inp) + + def test_convert_nn_module_with_nested_param(self): + class M(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + class NestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.m = M(dim) + + def forward(self, x: torch.Tensor): + return self.linear(self.m(x)) + + class SuperNestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + self.m = NestedM(dim) + + def forward(self, x: torch.Tensor): + return self.linear(self.m(x)) + + inp = (torch.ones(3),) + orig_m = NestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + orig_m = SuperNestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + def test_convert_nn_module_with_nested_buffer(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + x + + class NestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = M() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + self.m(x) + + class SuperNestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m = NestedM() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + self.m(x) + + inp = (torch.ones(1),) + orig_m = NestedM() + ep = self._check_equal_ts_ep_converter(orig_m, inp) + orig_m = SuperNestedM() + ep = 
self._check_equal_ts_ep_converter(orig_m, inp) + + def test_convert_nn_module_with_nested_if_and_buffer(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + return self.w + x + + class NestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m1 = M() + self.m2 = M() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.w + self.m1(x) + else: + return self.w + self.m2(x) + + # Super nested, parameters neeed to lifted + # multiple times. + class SuperNestedM(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.m1 = NestedM() + self.m2 = NestedM() + self.register_buffer("w", torch.randn(1)) + + def forward(self, x: torch.Tensor): + if torch.max(x) > 1: + return self.w + self.m1(x) + else: + return self.w + self.m2(x) + + # Super nested module testing. + inp = (torch.ones(1),) + orig_m = SuperNestedM() + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 1 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + def test_convert_nn_module_with_nested_if_and_param(self): + class M(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + class NestedM(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = M(dim) + self.m2 = M(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Super nested, parameters neeed to lifted + # multiple times. + class SuperNestedM1(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = NestedM(dim) + self.m2 = NestedM(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.max(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Super nested, even the input needs to be + # lifted recursively due to value propogation optimiztaion. + class SuperNestedM2(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.m1 = NestedM(dim) + self.m2 = NestedM(dim) + self.linear = torch.nn.Linear(dim, dim) + + def forward(self, x: torch.Tensor): + if torch.sum(x) > 1: + return self.linear(self.m1(x)) + else: + return self.linear(self.m2(x)) + + # Basic module testing. + inp = (torch.ones(3),) + orig_m = M(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # Nested module testing. + inp = (torch.ones(3),) + orig_m = NestedM(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # Super nested module testing. + inp = (torch.ones(3),) + orig_m = SuperNestedM1(3) + ep = self._check_equal_ts_ep_converter(orig_m, inp) + + t = inp[0] + t -= 0.8 + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) + + # # Super nested module testing. 
+ # inp = (torch.ones(3),) + # orig_m = SuperNestedM2(3) + # ep = self._check_equal_ts_ep_converter(orig_m, inp) + + # t = inp[0] + # t -= 0.8 + # torch.testing.assert_close( + # ep.module()(*inp), + # orig_m(*inp), + # ) + + def test_ts2ep_converter_contains(self): + class MIn(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.dtype in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + class MNotIn(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.dtype in [-1] + + class MTensorIn(torch.nn.Module): + def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): + return x in x_dict + + inp = (torch.tensor(4),) + self._check_equal_ts_ep_converter(MIn(), inp) + self._check_equal_ts_ep_converter(MNotIn(), inp) + + inp = (torch.tensor(4), {torch.tensor(4): "foo"}) + self._check_equal_ts_ep_converter(MTensorIn(), inp) + inp = (torch.tensor(1), {torch.tensor(4): "foo"}) + self._check_equal_ts_ep_converter(MTensorIn(), inp) + + def test_ts2ep_converter_custom_op(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + torch.library.define( + "mylib::foo", + "(Tensor x) -> Tensor", + lib=lib, + ) + + # PyTorch custorm op implementation + @torch.library.impl( + "mylib::foo", + "CompositeExplicitAutograd", + lib=lib, + ) + def foo_impl(x): + return x + x + + # Meta function of the custom op. + @torch.library.impl_abstract( + "mylib::foo", + lib=lib, + ) + def foo_meta(x): + return x + x + + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.mylib.foo(x) + + inp = (torch.randn(3, 3),) + m = M() + self._check_equal_ts_ep_converter(m, inp) + + def test_convert_func_without_param(self): + def func1(x, y): + return x + y + + def func2(x, y): + if x.sum() > 0: + return x + y + else: + return x - y + + inp = ( + torch.tensor(1), + torch.tensor(1), + ) + self._check_equal_ts_ep_converter(func1, inp) + + ep = self._check_equal_ts_ep_converter(func2, inp) - torch.testing.assert_close(ep.module()(*inp)[0], m(*inp)) + t = inp[0] + t -= 1 + torch.testing.assert_close( + ep.module()(*inp), + func2(*inp), + ) if __name__ == "__main__": diff --git a/test/export/test_export.py b/test/export/test_export.py index 5c1dbe2602ca..19acbbca39f1 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -149,6 +149,7 @@ class Inp: NON_STRICT_SUFFIX = "_non_strict" RETRACEABILITY_SUFFIX = "_retraceability" +SERDES_SUFFIX = "_serdes" PREDISPATCH_SUFFIX = "_pre_dispatch" @@ -160,6 +161,10 @@ def is_retracebility_test(test_name): return test_name.endswith(RETRACEABILITY_SUFFIX) +def is_serdes_test(test_name): + return test_name.endswith(SERDES_SUFFIX) + + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestDynamismExpression(TestCase): def test_export_inline_constraints(self): @@ -1969,6 +1974,137 @@ def forward(self, x, y, z): dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)} export(Foo(), inputs, dynamic_shapes=dynamic_shapes) + def test_refine_dynamic_shapes_from_suggested_fixes(self): + from torch.export.dynamic_shapes import ( + refine_dynamic_shapes_from_suggested_fixes, + ) + + def helper(model, inputs, dynamic_shapes): + # export, fail, parse & refine suggested fixes, re-export + try: + export(Foo(), inps, dynamic_shapes=dynamic_shapes) + raise Exception("should have raised constraint violation error") + except torch._dynamo.exc.UserError as exc: + new_shapes 
= refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + export(Foo(), inps, dynamic_shapes=new_shapes) + return new_shapes + + # specialize dims + derived dims + class Foo(torch.nn.Module): + def forward(self, x, y, z): + x0 = x + y[1:] + z[2:] + x1 = x @ torch.randn(4, 4) + return x0, x1 + + inps = ( + torch.randn( + 4, + ), + torch.randn( + 5, + ), + torch.randn( + 6, + ), + ) + dx = Dim("dx", max=16) + dynamic_shapes = {"x": (dx,), "y": (dx + 1,), "z": (dx + 2,)} + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0], 4) + self.assertEqual(new_shapes["z"][0], 6) + + # refine lower, upper bound + class Foo(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] >= 6 and y.shape[0] <= 16: + return x * 2.0, y + 1 + + inps = (torch.randn(16), torch.randn(12)) + dynamic_shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0].min, 6) + self.assertEqual(new_shapes["y"][0].max, 16) + + # divisiblity, will introduce new root + class Foo(torch.nn.Module): + def forward(self, x): + if x.shape[0] >= 9: + return x.reshape([-1, 3]) + + inps = ( + torch.randn( + 15, + ), + ) + dynamic_shapes = ((Dim("dx"),),) + new_shapes = helper(Foo(), inps, dynamic_shapes) + dim = new_shapes[0][0] + root = dim.root + self.assertEqual(dim.fn(2), 6) + self.assertEqual(root.min, 3) + + # turn dim into derived dim/relation + class Foo(torch.nn.Module): + def forward(self, x, y): + return x + y[4:] + + inps = (torch.randn(6, 4), torch.randn(10, 4)) + dynamic_shapes = { + "x": (Dim("dx0"), Dim("dx1")), + "y": (Dim("dy0"), Dim("dy1")), + } + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual(new_shapes["x"][0], new_shapes["y"][0].root) # dy0 = dx0 + 4 + self.assertEqual(new_shapes["y"][0].fn(5), 9) + self.assertEqual(new_shapes["x"][1], new_shapes["y"][1]) # dx1 = dy1 + + # nested dynamic shapes spec + class Foo(torch.nn.Module): + def forward(self, x, y): + x0 = x[0]["data"] + x[1] + x[2][2:] + x1 = y["a"] @ torch.randn(4, 4) + x2 = y["b"] @ torch.randn(6, 6) + return x0, x1, x2 + + inps = ( + [ + {"data": torch.randn(4, 4)}, + torch.randn(4, 4), + torch.randn(6, 4), + ], + { + "a": torch.randn(8, 4), + "b": torch.randn(9, 6), + }, + ) + dynamic_shapes = { + "x": [ + {"data": (Dim("dx00"), Dim("dx01"))}, + (Dim("dx10"), Dim("dx11")), + (Dim("dx20"), Dim("dx21")), + ], + "y": { + "a": (Dim("dya0"), Dim("dya1")), + "b": (Dim("dyb0"), Dim("dyb1")), + }, + } + new_shapes = helper(Foo(), inps, dynamic_shapes) + self.assertEqual( + new_shapes["x"][0]["data"][0], new_shapes["x"][1][0] + ) # dx10 = dx00 + self.assertEqual( + new_shapes["x"][2][0].root, new_shapes["x"][0]["data"][0] + ) # dx20 = dx00 + 2 + self.assertEqual(new_shapes["x"][2][0].fn(10), 12) + self.assertEqual( + new_shapes["x"][0]["data"][1], new_shapes["x"][1][1] + ) # dx11 = dx01 + self.assertEqual(new_shapes["y"]["a"][1], 4) + self.assertEqual(new_shapes["y"]["b"][1], 6) + self.assertEqual(new_shapes["y"]["b"][0].__name__, "dyb0") # unchanged + def test_dynamic_shapes_spec_with_pytree(self): from torch.export import Dim, export from torch.utils._pytree import tree_map @@ -3312,6 +3448,15 @@ def forward(self, x): "torch.ops.aten._assert_scalar.default", 1, exactly=True ).run(ep.graph_module.code) + ep = ep.run_decompositions() + + FileCheck().check_count( + "torch.ops.aten.sym_constrain_range.default", 1, exactly=True + ).run(ep.graph_module.code) + FileCheck().check_count( + 
"torch.ops.aten._assert_scalar.default", 1, exactly=True + ).run(ep.graph_module.code) + def test_non_arg_name_dynamic_shapes_api(self): class Foo(torch.nn.Module): def forward(self, a, b): @@ -5039,8 +5184,9 @@ def forward(self, x): export(f, (inputs,), dynamic_shapes=dynamic_shapes) def test_disable_forced_specializations(self): - # case 1 - # check disable_forced_specializations flag behaves correctly + # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags + # both behave correctly, avoiding forced specializations and deferring to runtime. + # case 1: modulo guards from torch.export import dims class Mod4Reshape(torch.nn.Module): @@ -5055,31 +5201,36 @@ def forward(self, x): r".*dx = .* must be specialized to 10 because the guards generated for it are too complex(.*\n)*" r".*dy = .* must be specialized to 72 because the guards generated for it are too complex(.*\n)*", ): - torch.export._trace._export( + export( Mod4Reshape(), inputs, dynamic_shapes={"x": (dx, dy)}, - strict=False, - _disable_forced_specializations=False, ) - ep = torch.export._trace._export( + + torch.export._trace._export( # just check this successfully compiles Mod4Reshape(), inputs, dynamic_shapes={"x": (dx, dy)}, strict=False, _disable_forced_specializations=True, ) + ep = torch.export._trace._export( + Mod4Reshape(), + inputs, + dynamic_shapes={"x": (dx, dy)}, + _allow_complex_guards_as_runtime_asserts=True, + ) out1 = ep.module()(torch.randn(8, 7)) self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape) - out2 = ep.module()(torch.randn(4, 3)) - self.assertEqual(out2.shape, torch.ones(3, 4, 1).shape) + out2 = ep.module()(torch.randn(12, 11)) + self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape) with self.assertRaisesRegex( RuntimeError, - r"shape .*7, 4, -1.* is invalid for input of size 64", + r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'", ): ep.module()(torch.randn(8, 8)) # fail - # case 2 + # case 2: 2d reshape class FreeReshape(torch.nn.Module): def forward(self, x, y, z): return x.reshape([-1]) + y.reshape([-1]) + z # s0*s1 = s2*s3 = s4 @@ -5090,9 +5241,9 @@ def forward(self, x, y, z): torch.randn(48), ) dynamic_shapes = { - "x": [Dim(f"dx{i}") for i in range(2)], - "y": [Dim(f"dy{i}") for i in range(2)], - "z": [Dim(f"dz{i}") for i in range(1)], + "x": [Dim(f"dx{i}", min=2) for i in range(2)], + "y": [Dim(f"dy{i}", min=2) for i in range(2)], + "z": [Dim(f"dz{i}", min=4) for i in range(1)], } with self.assertRaisesRegex( # this will force specialize torch._dynamo.exc.UserError, @@ -5100,32 +5251,85 @@ def forward(self, x, y, z): r".*dx0 = .* must be specialized to 6 because the guards generated for it are too complex(.*\n)*" r".*dx1 = .* must be specialized to 8 because the guards generated for it are too complex(.*\n)*", ): - torch.export._trace._export( + export( FreeReshape(), inputs, dynamic_shapes=dynamic_shapes, - strict=False, - _disable_forced_specializations=False, ) - ep = torch.export._trace._export( + torch.export._trace._export( FreeReshape(), inputs, dynamic_shapes=dynamic_shapes, strict=False, _disable_forced_specializations=True, ) + ep = torch.export._trace._export( + FreeReshape(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48)) self.assertEqual(out1.shape, torch.ones(48).shape) out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40)) 
self.assertEqual(out2.shape, torch.ones(40).shape) with self.assertRaisesRegex( RuntimeError, - r"The size of tensor a .* must match the size of tensor b .* at non-singleton dimension 0", + r"Runtime assertion failed for expression Eq\(s0\*s1 \- s2\*s3, 0\) on node 'eq.*'", ): # fail only at runtime ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail + # case 3: 3d reshape (previously failing with different issue) + class Reshape3d(torch.nn.Module): + def forward(self, x, y): + return x.reshape([-1]) + y # s0*s1*s2 = s3 + + inputs = ( + torch.randn(4, 3, 2), + torch.randn(24), + ) + dynamic_shapes = { + "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)), + "y": (Dim("dy", min=8),), + } + with self.assertRaisesRegex( # this will force specialize + torch._dynamo.exc.UserError, + r".*Specializations unexpectedly required(.*\n)*" + r"Suggested fixes:(.*\n)*" + r".*dx0 = 4(.*\n)*" + r".*dx1 = 3(.*\n)*" + r".*dx2 = 2(.*\n)*" + r".*dy = 24(.*\n)*", + ): + export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + ) + + torch.export._trace._export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + _disable_forced_specializations=True, + ) + ep = torch.export._trace._export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126)) + self.assertEqual(out1.shape, torch.ones(126).shape) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(s0\*s1\*s2 \- s3, 0\) on node 'eq.*'", + ): # fail only at runtime + ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail + def test_disable_forced_specializations_errors(self): - # check error messages with disable_forced_specializations=False/True + # check error messages with disable_forced_specializations = False/True class Foo(torch.nn.Module): def forward(self, w, x, y, z): return w.reshape([-1]) + x, y + z # simple: s0*s1 = s2, s3 = s4 @@ -5142,7 +5346,7 @@ def forward(self, w, x, y, z): "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "z": [Dim("dz")], # suggested fix should be to match these up. } - with self.assertRaisesRegex( # if disable=False, suggested fixes should specialize 3, 4, 12. + with self.assertRaisesRegex( # if allow = False, suggested fixes should specialize 3, 4, 12. torch._dynamo.exc.UserError, r".*Specializations unexpectedly required(.*\n)*" r"Suggested fixes:(.*\n)*" @@ -5172,6 +5376,108 @@ def forward(self, w, x, y, z): _disable_forced_specializations=True, ) + def test_reshape_view_helper(self): + # see: https://github.com/pytorch/pytorch/issues/126607 + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(1), -1) + # torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?) 
+ # Ne(s0, 20), so that reshape isn't no-op + # Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16] + # then split_dim -> [20, s0, 16] + # check that these show up in graph + return torch.nn.functional.softmax( + x, dim=0 + ) # don't think softmax actually creates any issues, just part of original test + + model = Model() + x = torch.rand(1024, 20, 16) + dynamic_shapes = {"x": {0: Dim("batch")}} + ep = torch.export._trace._export( + model, + (x,), + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0, 20\)", + ): + ep.module()(torch.randn(20, 20, 16)) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)", + ): + ep.module()(torch.randn(400, 20, 16)) + ep.module()(torch.randn(42, 20, 16)) + + def test_allow_explicit_guards_as_runtime_asserts(self): + # check that explicit guards are treated as runtime assertions + class Foo(torch.nn.Module): + def forward(self, x, y): + # check that negation of first guard also shows up as runtime assertion + if x.shape[0] == y.shape[0]: # False + return x + y + elif x.shape[0] == y.shape[0] ** 3: # False + return x + 2, y + 3 + elif x.shape[0] ** 2 == y.shape[0] * 3: # True + return x * 2.0, y * 3.0 + + inputs = (torch.randn(6), torch.randn(12)) + dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]} + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ) + # check forward pass + out0, out1 = ep.module()(torch.randn(9), torch.randn(27)) + self.assertEqual(out0.shape, torch.ones(9).shape) + self.assertEqual(out1.shape, torch.ones(27).shape) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0 \- s1, 0\)", + ): # fail only at runtime + ep.module()(torch.randn(4), torch.randn(4)) # fail + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(s0 \- s1\**3, 0\)", + ): + ep.module()(torch.randn(64), torch.randn(4)) # fail + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(s0\**2 \- 3\*s1, 0\)", + ): + ep.module()(torch.randn(10), torch.randn(9)) # fail + + # this should be set with command line flag TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1, + # but dynamo checks that at torch import time, so setting os.environ makes no difference + # instead, manually patch dynamo config and test. 
+ # test that setting this flag removes runtime asserts + from torch._dynamo import config as _dynamo_config + + with _dynamo_config.patch( + do_not_emit_runtime_asserts=True, + ): + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=dynamic_shapes, + _allow_complex_guards_as_runtime_asserts=True, + ).run_decompositions() + + self.assertEqual( + [ + node.target == torch.ops.aten._assert_scalar.default + for node in ep.graph.nodes + ].count(True), + 0, + ) + def test_constant_aliasing(self): class M1(torch.nn.Module): def __init__(self, m2, foo): diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index bd11cd7f8366..52848134721f 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -23,14 +23,12 @@ def mocked_serder_export(*args, **kwargs): def make_dynamic_cls(cls): - suffix = "_serdes" - cls_prefix = "SerDesExport" test_class = testing.make_test_cls_with_mocked_export( cls, cls_prefix, - suffix, + test_export.SERDES_SUFFIX, mocked_serder_export, xfail_prop="_expected_failure_serdes", ) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index b8ed2ef69f53..012b35c910b5 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -806,6 +806,49 @@ def forward(self, x): dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}} self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes) + def test_multiple_getitem(self): + class M(torch.nn.Module): + def forward(self, x): + a, b = torch.topk(x, 2) + a = a * 2 + return a, b + + ep = torch.export.export(M(), (torch.ones(3),)) + + # insert another getitem node + for node in ep.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor: + getitem_0 = node.args[0] + with ep.graph.inserting_before(getitem_0): + getitem_copy = ep.graph.node_copy(getitem_0) + mul_node = ep.graph.call_function( + torch.ops.aten.mul.Tensor, (getitem_copy, 2) + ) + mul_node.meta = copy.copy(getitem_copy.meta) + node.args = (getitem_0, mul_node) + + deserialized_ep = deserialize(serialize(ep)) + + inp = (torch.randn(3),) + orig_res = ep.module()(*inp) + res = deserialized_ep.module()(*inp) + self.assertTrue(torch.allclose(orig_res[0], res[0])) + self.assertTrue(torch.allclose(orig_res[1], res[1])) + + # The deserialized graph should have deduped getitem calls + self.assertExpectedInline( + deserialized_ep.graph_module.code.strip("\n"), + """\ +def forward(self, x): + topk_default = torch.ops.aten.topk.default(x, 2); x = None + getitem = topk_default[0] + getitem_1 = topk_default[1]; topk_default = None + mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2) + mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor); getitem = mul_tensor = None + return (mul, getitem_1) + """, + ) + @parametrize( "name,case", get_filtered_export_db_tests(), diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 3e4a78c61769..42c87bf4c10e 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -188,6 +188,19 @@ def forward(self, obj_attr, x, n): return (add,)""", ) + def test_method_schema(self): + tq = _empty_tensor_queue() + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + fake_obj = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq) + self.assertExpectedInline( + str(fake_obj.push.schema), + """push(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, Tensor _1) -> NoneType _0""", + ) + self.assertExpectedInline( + str(fake_obj.pop.schema), + 
"""pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0) -> Tensor _0""", + ) + @parametrize("pre_dispatch", [True, False]) def test_attribute(self, pre_dispatch): class MyModule(torch.nn.Module): @@ -255,15 +268,8 @@ def forward(self, token, obj_attr, x): ) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_input(self, pre_dispatch, fakify_script_obj): + def test_input(self, pre_dispatch): cc = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = cc._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) class MyModule(torch.nn.Module): def __init__(self): @@ -295,19 +301,11 @@ def forward(self, x, cc): # aot_export_function runs the program twice # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function # We also have a re-tracing test, which doubles the count. - if fakify_script_obj: - self.assertEqual(self.foo_add_tensor_counter, 4) + self.assertEqual(self.foo_add_tensor_counter, 4) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_input_as_custom_op_argument(self, pre_dispatch, fakify_script_obj): + def test_input_as_custom_op_argument(self, pre_dispatch): cc = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = cc._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) class MyModule(torch.nn.Module): def __init__(self): @@ -322,16 +320,13 @@ def forward(self, x, cc): torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear() # Even though a C++ implementation for takes_foo.default is registered, # we still need the python implementation for takes_foo.default to trace with FakeFoo. 
- if fakify_script_obj: - with self.assertRaisesRegex( - RuntimeError, "no python implementation is found" - ): - self._test_export_same_as_eager( - MyModule(), - (torch.ones(2, 3), cc), - strict=False, - pre_dispatch=pre_dispatch, - ) + with self.assertRaisesRegex(RuntimeError, "no python implementation is found"): + self._test_export_same_as_eager( + MyModule(), + (torch.ones(2, 3), cc), + strict=False, + pre_dispatch=pre_dispatch, + ) torch.ops._TorchScriptTesting.takes_foo.default.py_impl( torch._C.DispatchKey.Meta @@ -364,8 +359,7 @@ def forward(self, token, x, cc): ) @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_torchbind_alias(self, pre_dispatch, fakify_script_obj): + def test_torchbind_alias(self, pre_dispatch): class F2(torch.nn.Module): def __init__(self, foo): super().__init__() @@ -378,12 +372,6 @@ class F1(torch.nn.Module): def __init__(self): super().__init__() self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = self.alpha._type().qualified_name() - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) self.beta = self.alpha self.gamma = self.alpha self.foo = F2(self.gamma) @@ -402,8 +390,7 @@ def forward(self, x): # TODO(pianpwk): look into this @unittest.expectedFailure @parametrize("pre_dispatch", [True, False]) - @parametrize("fakify_script_obj", [True, False]) - def test_torchbind_input_and_alias(self, pre_dispatch, fakify_script_obj): + def test_torchbind_input_and_alias(self, pre_dispatch): # alias as model attribute class F3(torch.nn.Module): def forward(self, x, foo): @@ -411,12 +398,6 @@ def forward(self, x, foo): return x + self.foo.add_tensor(x) foo = torch.classes._TorchScriptTesting._Foo(10, 20) - if not fakify_script_obj: - qual_name = foo._type().qualified_name() # type: ignore[att-defined] - if torch._library.fake_class_registry.has_fake_class(qual_name): - torch._library.fake_class_registry.deregister_fake_class( - "_TorchScriptTesting::_Foo" - ) self._test_export_same_as_eager( F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch ) @@ -939,8 +920,10 @@ def size(self): def tearDown(self): torch._dynamo.reset() - def test_compile_script_object_input(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_script_object_input(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -984,23 +967,25 @@ def forward(self, tq, x): # does not return L_tq_ as output. This is because it's able # to detect that L_tq_ is an input therefore don't return # it as graph output. 
Related logic is in dynamo/codegen.py - self.assertExpectedInline( - backend.graphs[0].code.strip(), - """\ -def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): - l_tq_ = L_tq_ - l_x_ = L_x_ - cos = l_x_.cos() - call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None - sin = l_x_.sin(); l_x_ = None - call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None - call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop') - call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None - x_sin = call_torchbind_2 - 1; call_torchbind_2 = None - return (x_sin,)""", - ) + if backend == "eager": + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ + def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): + l_tq_ = L_tq_ + l_x_ = L_x_ + cos = l_x_.cos() + call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None + sin = l_x_.sin(); l_x_ = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None + call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None + x_sin = call_torchbind_2 - 1; call_torchbind_2 = None + return (x_sin,)""", + ) - def test_compile_script_object_input_guards(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_script_object_input_guards(self, backend): class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -1013,7 +998,7 @@ def forward(self, tq, x): return x_sin, tq mod = Model() - cnt = torch._dynamo.testing.CompileCounter() + cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) x = torch.randn(2, 3) tq1 = _empty_tensor_queue() @@ -1084,8 +1069,10 @@ def forward(self, tq, x): torch.compile(mod, backend=cnt)(tq3, x) self.assertEqual(cnt.frame_count, 2) - def test_compile_error_on_input_aliasing_contents(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_input_aliasing_contents(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -1103,7 +1090,11 @@ def forward(self, tq, x): with self.assertRaisesRegex(RuntimeError, "is alising"): torch.compile(mod, backend=backend)(tq1, x) - def test_compile_error_on_script_obj_setattr(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_script_obj_setattr(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() + def setattr_f(tq): tq.a = 1 return tq @@ -1111,19 +1102,25 @@ def setattr_f(tq): with self.assertRaisesRegex( RuntimeError, "call method __setattr__ on script object is not safe" ): - torch.compile(setattr_f, backend="eager")(_empty_tensor_queue()) + torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_script_obj_missing_attr(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() - def test_compile_error_on_script_obj_missing_attr(self): def setattr_f(tq): return tq._not_defined_attr with self.assertRaisesRegex( RuntimeError, "doesn't define method _not_defined_attr" ): - torch.compile(setattr_f, backend="eager")(_empty_tensor_queue()) + torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) - def test_compile_body_aliasing_contents(self): - backend = 
EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_body_aliasing_contents(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() def f(tq, x): x1 = x.view(-1) @@ -1138,7 +1135,7 @@ def f(tq, x): f(_empty_tensor_queue(), x), torch.compile(f, backend=backend)(_empty_tensor_queue(), x), ) - if not torch._dynamo.is_compiling(): + if not torch._dynamo.is_compiling() and backend == "eager": self.assertExpectedInline( backend.graphs[0].code.strip(), """\ @@ -1156,8 +1153,10 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject): return (sub, add)""", ) - def test_compile_error_on_non_fakified_method(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_error_on_non_fakified_method(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() def f(tq, x): x1 = x.view(-1) @@ -1175,7 +1174,8 @@ def f(tq, x): ): torch.compile(f, backend=backend)(_empty_tensor_queue(), x) - def test_compile_obj_as_hop_input(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_as_hop_input(self, backend): def f(tq, x): def fn(tq, x): tq.push(x) @@ -1187,10 +1187,11 @@ def fn(tq, x): _assertEqualScriptObject( self, f(_empty_tensor_queue(), x), - torch.compile(f, backend="eager")(_empty_tensor_queue(), x), + torch.compile(f, backend=backend)(_empty_tensor_queue(), x), ) - def test_compile_obj_closure(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_closure(self, backend): def f(x): def inner_f(x): tq.push(x.sin()) @@ -1204,7 +1205,8 @@ def inner_f(x): x = torch.randn(3, 2) _assertEqualScriptObject(self, f(x), opt_f(x)) - def test_compile_global_obj(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_global_obj(self, backend): global _TENSOR_QUEUE_GLOBAL_TEST _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() @@ -1212,7 +1214,7 @@ def f(x): _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin()) return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST - opt_f = torch.compile(f, backend="eager") + opt_f = torch.compile(f, backend=backend) x = torch.randn(3, 2) eager_ret = f(x) opt_ret = opt_f(x) @@ -1239,8 +1241,10 @@ def f(tq, x): ) self.assertEqual(cnt.frame_count, 4) - def test_compile_obj_attributes(self): - backend = EagerAndRecordGraphs() + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_attributes(self, backend): + if backend == "eager": + backend = EagerAndRecordGraphs() class Model(torch.nn.Module): def __init__(self): @@ -1254,21 +1258,23 @@ def forward(self, x): x = torch.randn(2, 3) opt_f = torch.compile(Model(), backend=backend) _assertEqualScriptObject(self, Model()(x), opt_f(x)) - self.assertEqual(len(backend.graphs), 1) - # lifted as input. In the future, we would want to cosolidate this - # with non-strict behavior, where they're set as attributes. - self.assertExpectedInline( - backend.graphs[0].code.strip(), - """\ -def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): - l_self_tq = L_self_tq - l_x_ = L_x_ - call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None - call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None - return (call_torchbind_1,)""", - ) + if backend == "eager": + self.assertEqual(len(backend.graphs), 1) + # lifted as input. In the future, we would want to cosolidate this + # with non-strict behavior, where they're set as attributes. 
+ self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ + def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): + l_self_tq = L_self_tq + l_x_ = L_x_ + call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None + return (call_torchbind_1,)""", + ) - def test_compile_obj_torchbind_op(self): + @parametrize("backend", ["eager", "aot_eager"]) + def test_compile_obj_torchbind_op(self, backend): def f(tq, x): torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1) @@ -1276,7 +1282,7 @@ def f(tq, x): torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) return tq.pop(), tq.pop() + tq.size(), tq - opt_f = torch.compile(f, backend="eager") + opt_f = torch.compile(f, backend=backend) x = torch.randn(2) _assertEqualScriptObject( self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) @@ -1334,6 +1340,7 @@ def __obj_unflatten__(cls, flattend_foo): instantiate_parametrized_tests(TestExportTorchbind) +instantiate_parametrized_tests(TestCompileTorchbind) if __name__ == "__main__": run_tests() diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 3ca58e8fff79..383287db421a 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -312,6 +312,31 @@ def forward(self, x): export_module.module(), unflattened, (torch.randn((2, 3)),) ) + @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") + def test_unflatten_preserve_with_unused_input(self): + class M1(torch.nn.Module): + def forward(self, x, a, b): + return x + a, b + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.m1 = M1() + + def forward(self, x, y): + a, b = torch.topk(y, 2) + return self.m1(x, a, b)[0] + + ep = torch.export.export( + M(), + (torch.randn(2), torch.randn(5)), + preserve_module_call_signature=("m1",), + strict=False, + ) + ep.graph.eliminate_dead_code() + unflattened = unflatten(ep) + self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5))) + def test_unflatten_wrong_input(self): class Mod(torch.nn.Module): def __init__(self): @@ -747,6 +772,28 @@ def forward(self, x): unep = unflatten(ep) self.assertTrue(torch.allclose(unep(*inps), m(*inps))) + def test_attr_as_submod_input(self): + class layer(torch.nn.Module): + def forward(self, x, const) -> torch.Tensor: + return x + const + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("const", torch.ones(4, 8)) + self.layers = torch.nn.ModuleList([layer() for _ in range(2)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = layer(x, self.const) + return x + + mod = M() + x = torch.randn(4, 8) + ep = export(mod, (x,)) + unflattened = unflatten(ep) + torch.testing.assert_close(unflattened(x), mod(x)) + if __name__ == "__main__": run_tests() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 81b85a4fe42f..88927e8bf7ce 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -135,7 +135,8 @@ ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)), ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), 
("aten::sym_constrain_range", datetime.date(2023, 12, 31)), - ("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)), + ("aten::_efficient_attention_forward", datetime.date(2024, 7, 1)), + ("aten::_efficient_attention_backward", datetime.date(2024, 7, 1)), ("onednn::qconv1d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py new file mode 100644 index 000000000000..a9b1d00b9929 --- /dev/null +++ b/test/functorch/test_ac.py @@ -0,0 +1,302 @@ +# Owner(s): ["oncall: pt2"] +import random + +import torch +import torch._functorch.config as config +from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.utils.flop_counter import FlopCounterMode + + +def compile_with_ac(f, memory_budget): + return torch.compile(f, backend="aot_eager_decomp_partition") + + +def get_act_mem(f): + out = f() + out.backward() + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + +def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + +def create_pair(B_I, O): + # results in B_I * O memory, requires B_I * B_I * O flops + # arithmetic intensity of B_I + x = torch.randn(B_I * 512, B_I * 512, requires_grad=True) + w = torch.randn(B_I * 512, O * 512, requires_grad=True) + return x, w + + +def get_mem_and_flops(f, memory_budget=None): + # Returns megabytes rounded to 1 decimal point and FLOPs + # Note that each value of size (512, 512, torch.float32) is 1 MiB + torch._dynamo.reset() + with config.patch(activation_memory_budget=memory_budget): + if memory_budget is not None: + f = torch.compile(f, backend="aot_eager_decomp_partition") + + # We round this to nearest 10th of a megabyte. + return round(get_act_mem(f), 1), get_bw_flops(f) + + +class MemoryBudgetTest(TestCase): + def setUp(self): + super().setUp() + torch.set_default_device("cuda") + + def test_rematerializes_cheap(self): + def f(x, w): + x = x.cos() + x = torch.mm(x, w) + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + w = torch.randn(512, 512, requires_grad=True) + + def call(): + return f(x, w) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 1.0) + mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0) + # Recomputing `.cos()` is not free here. 
+ self.assertEqual(mem_10, 1.0) + self.assertEqual(eager_flops, flops_10) + mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5) + # We can just recompute `x.cos()` here to only depend on the inputs + self.assertEqual(mem_5, 0.0) + self.assertEqual(flops_5, eager_flops) + + def test_matmul_even_chain(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + for budget in range(0, 11): + mem, flops = get_mem_and_flops(call, memory_budget=budget / 10) + if budget <= 5: + # We start saving the matmuls + self.assertEqual(mem, budget) + self.assertEqual(flops, eager_flops + (5 - budget)) + elif budget < 10: + # We're only recomputing the `cos` operations + self.assertEqual(mem, 5.0) + self.assertEqual(flops, eager_flops) + elif budget == 10: + self.assertEqual(mem, 10.0) + self.assertEqual(flops, eager_flops) + + def test_matmul_uneven_chain(self): + # This function is constructed so that we are saving one input of size + # [512, in_dim] for each w + # In addition, every matmul has a same ratio of compute to "memory + # saved", so this test is essentially testing our knapsack solving + + def f(x, ws): + xs = [torch.mm(x, w).cos() for w in ws] + return sum([x.sum() for x in xs]) + + x = torch.randn(512, 512, requires_grad=True) + + def make_weights(w_shapes): + ws = [] + for idx, dim in enumerate(w_shapes): + ws.append(torch.randn(512, dim * 512, requires_grad=True)) + return ws + + def make_weights_chain(w_shapes): + ws = [] + for idx, _ in enumerate(w_shapes): + old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512 + new_dim = w_shapes[idx] * 512 + ws.append(torch.randn(old_dim, new_dim, requires_grad=True)) + return ws + + weight_configs = [ + ( + [11, 3, 4, 2], + [ + 18, # 11 + 4 + 3 + 17, # 11 + 4 + 2 + 16, # 11 + 3 + 2 + 15, # 11 + 4 + 14, # 11 + 3 + 13, # 11 + 2 + 11, # 11 + 2 + 7, # 4 + 3 + 6, # 4 + 2 + 5, # 3 + 2 + ], + ), + ( + [3, 5, 11, 17, 14], + [ + 42, # 17 + 14 + 9 + 30, # 11 + 15 + 5 + 19, # 11 + 5 + 3 + 8, # 5 + 3 + 3, # 3 + ], + ), + ] + random.seed(0) + random_arr = [random.randint(0, 50) for _ in range(10)] + exact_sums = [] + for i in range(10): + random.shuffle(random_arr) + exact_sums.append(sum(random_arr[:i])) + weight_configs.append((random_arr, exact_sums)) + + for weight_shapes, exact_solves in weight_configs: + ws = make_weights(weight_shapes) + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + total_mem = sum(weight_shapes) + self.assertEqual(eager_mem, sum(weight_shapes)) + for mem_achieved in exact_solves: + mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem) + self.assertEqual(mem, mem_achieved) + + def test_prioritize_cheaper_matmul(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + x1, w1 = create_pair(1, 4) + x2, w2 = create_pair(2, 2) + + def call(): + return f([x1, x2], [w1, w2]) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 8) + self.assertEqual(eager_flops, 24) + comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5) + self.assertEqual(comp_mem, 4) + # We are recomputing x1 @ w1 here! 
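[Editor's note, not part of the patch] A worked version of the accounting behind the "+ 4" asserted below, using the cost model stated in create_pair (B_I * O memory, B_I * B_I * O normalized flops to recompute):
    # pair 1 = create_pair(1, 4): activation 1*4 = 4 MiB, recompute cost 1*1*4 = 4
    # pair 2 = create_pair(2, 2): activation 2*2 = 4 MiB, recompute cost 2*2*2 = 8
    # With memory_budget=0.5 (4 of the 8 MiB), keeping pair 2 and recomputing the
    # cheaper pair 1 adds 4 normalized flops, hence comp_flops == eager_flops + 4.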
+ self.assertEqual(comp_flops, eager_flops + 4) + + @config.patch(activation_memory_budget_runtime_estimator="profile") + def test_profile(self): + def f(x, ws): + x = x.cos() + for w in ws: + x = torch.mm(x, w).cos() + return x.sum() + + x = torch.randn(512, 512, requires_grad=True) + ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] + + def call(): + return f(x, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + mem, flops = get_mem_and_flops(call, memory_budget=0.2) + # We start saving the matmuls + self.assertEqual(mem, 2) + self.assertEqual(flops, eager_flops + 3) + + def test_prioritize_cheaper_matmul2(self): + def f(xs, ws): + xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] + return sum([x.sum() for x in xs]) + + data = [(4, 4), (6, 2), (2, 6)] + xs, ws = zip(*[create_pair(a, b) for a, b in data]) + + def call(): + return f(xs, ws) + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, 40) + self.assertEqual(eager_flops, 320) + mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem) + # Save w1 and w2 + self.assertEqual(mem, 28) + # We're recomputing w3 (the cheap one!) + self.assertEqual(flops - eager_flops, 2 * 2 * 6) + mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem) + # Save w2. Note that even though saving w1 gets us closer to our memory + # limit, w2 is actually *more* FLOPs than w1! + self.assertEqual(mem, 12) + self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4) + + def test_attention_vs_linear(self): + def f(x, w): + orig_shape = x.shape + x = x.reshape(1, 1, x.shape[0], x.shape[1]) + # I know this isn't technically right lol + x = torch.nn.functional.scaled_dot_product_attention( + x, x, x, is_causal=False + ).reshape(*orig_shape) + x = torch.mm(x, w) + x = x.cos() + return x.sum() + + def try_seq_length(S, D, expected_recompute): + x = torch.randn(S * 512, D * 512, requires_grad=True) + w = torch.randn(D * 512, D * 512, requires_grad=True) + + def call(): + return f(x, w) + + with FlopCounterMode(display=False) as mode: + call() + mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm] + attn_flops = mode.get_total_flops() - mm_flops + mm_flops /= 512**3 * 2 + attn_flops /= 512**3 * 2 + + eager_mem, eager_flops = get_mem_and_flops(call) + self.assertEqual(eager_mem, S * D * 2) + + mem, flops = get_mem_and_flops( + call, memory_budget=0.6 + ) # Force it to recompute one of mm or attn + self.assertEqual(mem, S * D) + if expected_recompute == "attn": + expected_flops = attn_flops + else: + expected_flops = mm_flops + self.assertEqual(flops - eager_flops, expected_flops) + + # General behind this test is that if sequence length * 2 > D, then + # attention is more expensive than the linear. 
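[Editor's note, not part of the patch] A rough derivation of that crossover, counting only matmul work and ignoring softmax and other elementwise cost, in the same normalized FLOP units as above:
    # linear:    (S*512, D*512) @ (D*512, D*512)        -> ~S * D**2     normalized flops
    # attention: Q @ K^T and P @ V, each (S, S) x D wide -> ~2 * S**2 * D normalized flops
    # attention dominates when 2 * S**2 * D > S * D**2, i.e. when 2 * S > D,
    # which matches the expected_recompute picks below (recompute the cheaper of the two).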
+ try_seq_length(1, 1, "mm") + try_seq_length(1, 3, "attn") + try_seq_length(2, 2, "mm") + try_seq_length(2, 1, "mm") + try_seq_length(2, 5, "attn") + try_seq_length(4, 7, "mm") + try_seq_length(4, 9, "attn") + + +if __name__ == "__main__": + # I'm using the cuda memory allocator to verify memory allocations + if HAS_CUDA and not TEST_WITH_ROCM: + run_tests() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index a3ebb9eb08a5..7bce7d558abb 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -11,7 +11,7 @@ import unittest import warnings from contextlib import nullcontext -from functools import partial +from functools import partial, wraps from typing import Any, Callable, Dict, List, Optional, Union from unittest.mock import patch @@ -26,6 +26,7 @@ from functorch.compile import ( aot_function, aot_module, + aot_module_simplified, compiled_function, compiled_module, default_decompositions, @@ -39,11 +40,7 @@ ) from functorch.experimental import control_flow from torch._decomp import decomposition_table -from torch._functorch.aot_autograd import ( - aot_export_joint_simple, - aot_export_module, - aot_module_simplified, -) +from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module from torch._higher_order_ops.out_dtype import out_dtype from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node @@ -71,6 +68,7 @@ skipIfRocm, skipIfTorchDynamo, TestCase, + xfail_inherited_tests, xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db @@ -288,7 +286,62 @@ def is_in_base(t, maybe_tensors): return False +def skipIfDynamoInput(reason): + """ + Skip TestAOTAutograd if running with dynamo input + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if isinstance(self, TestAOTAutogradWithDynamo): + self.skipTest( + f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" + ) + else: + func(self, *args, **kwargs) + + return wrapper + + return decorator + + class TestAOTAutograd(AOTTestCase): + def run_autograd( + self, + f: Callable, + fw_graph_cell: List[Optional[Callable]], + decompositions: Optional[Dict], + keep_input_mutations: bool, + dynamic: bool, + ): + """ + Runs aot_autograd with the specified settings on f. 
+ """ + if isinstance(f, nn.Module): + compiled_f = aot_module( + f, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic, + ) + else: + compiled_f = aot_function( + f, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + dynamic=dynamic, + ) + return compiled_f + # test_mutation will: # - Ensure that inputs are non-leaves, so our graphs can mutate them # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) @@ -349,28 +402,9 @@ def verify_aot_autograd( graph_inps = inp graph_inps_copy = inp_copy fw_graph_cell = [None] - if isinstance(f, nn.Module): - compiled_f = aot_module( - f, - fw_compiler=make_boxed_compiler( - partial(extract_graph, graph_cell=fw_graph_cell) - ), - bw_compiler=nop, - decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations, - dynamic=dynamic, - ) - else: - compiled_f = aot_function( - f, - fw_compiler=make_boxed_compiler( - partial(extract_graph, graph_cell=fw_graph_cell) - ), - bw_compiler=nop, - decompositions=decompositions, - keep_inference_input_mutations=keep_input_mutations, - dynamic=dynamic, - ) + compiled_f = self.run_autograd( + f, fw_graph_cell, decompositions, keep_input_mutations, dynamic + ) ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) self.assertEqual(ref_grad, test_grad) @@ -537,6 +571,9 @@ def f(a, b): ] self.verify_aot_autograd(f, inp, keep_inp_mutations=True) + @skipIfDynamoInput( + "Test doesn't make sense with dynamo, which changes order of mutations" + ) def test_set__and_data_mutation_good(self): def f(a, b): # The data mutation happens *after* the set_(). This is ok (see the graph below) @@ -601,6 +638,9 @@ def f(a): f, inp, test_mutation=True, keep_inp_mutations=True ) + @skipIfDynamoInput( + "Test doesn't make sense with dynamo, which changes order of mutations" + ) def test_set__not_allowed(self): def f(a, b): with torch.no_grad(): @@ -678,8 +718,6 @@ def f(a): out_ref = f(ref_view) out_test = f_compiled(test_view) - print(ref) - print(test) self.assertEqual(ref, test) def test_input_mutation_modifies_autograd_meta_of_aliases(self): @@ -1919,6 +1957,7 @@ def forward(self, primals_1, primals_2): # One gets a data mutation, the other gets a metadata mutation. # We need to make sure that the metadata mutation gets propagated # back to the original input. 
+ @skipIfDynamoInput("Dynamo removes runtime error") def test_input_data_and_metadata_mutation_aliases_other_input(self): # a and b are aliased def f(a, b): @@ -2524,6 +2563,7 @@ def forward(self, primals_1, primals_2): return [as_strided_scatter, add, add_1]""", ) # noqa: B950 + @skipIfDynamoInput("Fails with dynamo") def test_input_mutation_aliases_bases_out_of_order(self): # This tests our calling convention: if b and d are aliased, then the outer calling convention # that we send to the compiled forward becomes: @@ -5291,6 +5331,34 @@ def f(a, b): self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) + def test_aot_dispatch_output_requires_grad_in_no_grad(self): + def fn(x): + out1 = x.sin() + with torch.enable_grad(): + out2 = x.cos() + return out1, out2 + + inp_fns = [ + lambda: torch.ones(10, requires_grad=True), + lambda: torch.ones(10, requires_grad=False), + ] + + compiled_f = aot_function(fn, nop) + for inp_fn in inp_fns: + with torch.no_grad(): + ref_x = inp_fn() + ref_out = fn(ref_x) + x = inp_fn() + out = compiled_f(x) + for r, o in zip(ref_out, out): + self.assertEqual(r.requires_grad, o.requires_grad) + if ref_x.requires_grad: + with torch.enable_grad(): + (ref_out[0] + ref_out[1]).sum().backward() + (out[0] + out[1]).sum().backward() + self.assertEqual(ref_x.grad, x.grad) + assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3) + class TestAOTModuleSimplified(AOTTestCase): def test_aot_module_simplified(self): @@ -5821,5 +5889,64 @@ def test_aot_autograd_symbolic_module_exhaustive( instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for) +@xfail_inherited_tests( + [ + "test_set__and_data_mutation_bad", + "test_subclass_metadata_mutation_req_grad_True", + "test_subclass_metadata_mutation_req_grad_False", + ] +) +@skipIfTorchDynamo("This test suite already uses dynamo") +class TestAOTAutogradWithDynamo(TestAOTAutograd): + """ + These are the same as TestAOTAutograd tests, but we run dynamo first to get a graph module. + """ + + def assertExpectedInline(self, *args, **kwargs): + # These will have different outputs because dynamo returns a different graph module + # But we don't really care about that assertion when testing with dynamo, + # only that the outputs match, etc. 
+ pass + + # Compiler to passes to dynamo + def run_autograd( + self, + f: Callable, + fw_graph_cell: List[Optional[Callable]], + decompositions: Optional[Dict], + keep_input_mutations: bool, + dynamic: bool, + ): + """ + Runs dynamo and aot_autograd with the specified settings + """ + + def dynamo_compiler(gm, inputs, **kwargs): + result = aot_module_simplified( + gm, + inputs, + fw_compiler=make_boxed_compiler( + partial(extract_graph, graph_cell=fw_graph_cell) + ), + bw_compiler=nop, + decompositions=decompositions, + keep_inference_input_mutations=keep_input_mutations, + # Dynamic is calculated from whether the inputs have fake tensors + ) + return result + + def torch_compile_wrapper(*args, **kwargs): + torch._dynamo.reset() + fn = torch.compile(f, backend=dynamo_compiler) + try: + result = fn(*args, **kwargs) + except torch._dynamo.exc.BackendCompilerFailed as e: + # So that assertRaises works properly + raise e.inner_exception from e + return result + + return torch_compile_wrapper + + if __name__ == "__main__": run_tests() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e8664cb1e98d..f538c5af78ce 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -681,9 +681,43 @@ def test_while_loop_simple_with_linear_compile_check_graph(self): torch.compile(fn, backend=backend)(*inp) self.assertEqual(len(backend.graphs), 1) gm = backend.graphs[0] - self.assertExpectedInline( - gm.code.strip(), - """\ + if torch._dynamo.config.inline_inbuilt_nn_modules: + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter): + l_iter_ = L_iter_ + l_x_ = L_x_ + l_self_buffers_dec_ = L_self_buffers_dec_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + cond_fn_0 = self.cond_fn_0 + body_fn_0 = self.body_fn_0 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None + getitem = while_loop[0] + getitem_1 = while_loop[1]; while_loop = None + return (getitem, getitem_1)""", # noqa: B950 + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None + gt = sub > 0; sub = None + return gt""", # noqa: B950 + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = l_iter_ - 1; l_iter_ = None + linear = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None + return (sub, linear)""", # 
noqa: B950 + ) + else: + self.assertExpectedInline( + gm.code.strip(), + """\ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): l_iter_ = L_iter_ l_x_ = L_x_ @@ -696,23 +730,23 @@ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): getitem = while_loop[0] getitem_1 = while_loop[1]; while_loop = None return (getitem, getitem_1)""", # noqa: B950 - ) - self.assertExpectedInline( - gm.cond_fn_0.code.strip(), - """\ + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None gt = sub > 0; sub = None return gt""", # noqa: B950 - ) - self.assertExpectedInline( - gm.body_fn_0.code.strip(), - """\ + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): sub = l_iter_ - 1; l_iter_ = None linear = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None return (sub, linear)""", # noqa: B950 - ) + ) def test_while_loop_nested2_traced(self): fn, inp = WHILE_LOOP_TESTS["nested2"] diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index c767810beb85..8107f865f7bc 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -77,7 +77,6 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, - xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -2342,7 +2341,8 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - @xfailIfTorchDynamo + # it won't fail as jacrev/jacfwd were not inlined (see #128255) + # @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index c8df820c7c9c..4766b4cddabb 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -608,6 +608,7 @@ def abs_if_complex(t): "nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)} ), tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), tol1( "nn.functional.multi_head_attention_forward", @@ -2031,6 +2032,10 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): tol2( "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)} ), + tol1( + "nn.functional.conv_transpose2d", + {torch.float32: tol(atol=5e-04, rtol=5e-04)}, + ), tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}), ), @@ -2366,6 +2371,8 @@ def fn(input, weight, bias): "linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)} ), tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), + tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), + tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), ), ) def test_vmap_autograd_grad(self, device, dtype, op): diff --git 
a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 967152945af5..737927a60f80 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -25,6 +25,7 @@ } xfail_functorch_batched_decomposition = { + "aten::alias_copy", "aten::diagonal_copy", "aten::is_same_size", "aten::unfold_copy", diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py new file mode 100644 index 000000000000..8e1a7c5ae2cd --- /dev/null +++ b/test/fx/test_fx_xform_observer.py @@ -0,0 +1,61 @@ +# Owner(s): ["module: fx"] + +import os +import tempfile + +import torch +from torch.fx import subgraph_rewriter, symbolic_trace +from torch.fx.passes.graph_transform_observer import GraphTransformObserver + +from torch.testing._internal.common_utils import TestCase + + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_fx.py TESTNAME\n\n" + "instead." + ) + + +class TestGraphTransformObserver(TestCase): + def test_graph_transform_observer(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + + def replacement(x): + return torch.relu(x) + + traced = symbolic_trace(M()) + + log_url = tempfile.mkdtemp() + + with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob: + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + self.assertTrue("relu" in ob.created_nodes) + self.assertTrue("neg" in ob.erased_nodes) + + current_pass_count = GraphTransformObserver.get_current_pass_count() + + self.assertTrue( + os.path.isfile( + os.path.join( + log_url, + f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.svg", + ) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + log_url, + f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.svg", + ) + ) + ) diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py new file mode 100644 index 000000000000..ff6418238f8e --- /dev/null +++ b/test/fx/test_partitioner_order.py @@ -0,0 +1,53 @@ +# Owner(s): ["module: fx"] + +import unittest + +from typing import Mapping + +import torch +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.testing._internal.common_utils import TestCase + + +class DummyDevOperatorSupport(OperatorSupport): + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return True + + +class DummyPartitioner(CapabilityBasedPartitioner): + def __init__(self, graph_module: torch.fx.GraphModule): + super().__init__( + graph_module, + DummyDevOperatorSupport(), + allows_single_node_partition=True, + ) + + +class AddModule(torch.nn.Module): + def forward(self, x): + y = torch.add(x, x) + z = torch.add(y, x) + return z + + +class TestPartitionerOrder(TestCase): + # partitoner test to check graph node order + def test_partitioner_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + node_order = [n.name for n in partion_nodes[0]] + for _ in range(10): + traced_m = torch.fx.symbolic_trace(m) + new_partion = DummyPartitioner(traced_m).propose_partitions() + new_partion_nodes = [list(partition.nodes) for partition in new_partion] 
+ new_node_order = [n.name for n in new_partion_nodes[0]] + self.assertTrue(node_order == new_node_order) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index ea53b57e8209..cd8b80e8e886 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -25,7 +25,7 @@ ) from torch.testing._internal.torchbind_impls import init_torchbind_implementations -from torch.utils.hooks import RemovableHandle +from torch.utils.hooks import RemovableHandle # noqa: TCH001 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support") diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index dc4fe6fcbf7d..fb15fa01d318 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -12,6 +12,7 @@ import torch import torch._export import torch._inductor +import torch._inductor.config import torch.nn as nn from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters @@ -421,9 +422,10 @@ class LinearModel(torch.nn.Module): def __init__(self, device): super().__init__() self.weight = torch.randn(10, 10, device=device).to(dtype) + self.bias = torch.randn(10, device=device).to(dtype) def forward(self, y): - return torch.nn.functional.linear(y, self.weight) + return torch.nn.functional.linear(y, self.weight, self.bias) example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),) @@ -968,29 +970,19 @@ class Model(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, primals_1, primals_2, primals_5): - view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128]) + def forward(self, primals_5): + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) primals_5 = None permute = torch.ops.aten.permute.default(view, [0, 2, 1]) clone = torch.ops.aten.clone.default( permute, memory_format=torch.contiguous_format ) - permute = None - view_1 = torch.ops.aten.reshape.default(clone, [-1, 4]) - clone = None - permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0]) - primals_1 = None - addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1) - primals_2 = None - return addmm - - s0 = 727828 - s1 = 512 - example_inputs = ( - torch.rand(2, 4, device=self.device), - torch.rand(2, device=self.device), - torch.rand(s0, s1, device=self.device), - ) + return clone + + # let y_grid = 65537 + s0 = 16777472 + s1 = 8 + example_inputs = (torch.rand(s0, s1, device=self.device),) self.check_model(Model(), example_inputs) def test_cond_simple(self): @@ -1311,7 +1303,6 @@ def forward(self, x): return self.foo + x example_inputs = (torch.rand(4, 4, device=self.device),) - torch._export.aot_compile(Model(self.device), example_inputs) self.check_model(Model(self.device), example_inputs) def test_non_tensor_input(self): @@ -1323,14 +1314,19 @@ def fn(a, b, alpha=1.0): with self.assertRaises(RuntimeError): torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0}) - so_path = torch._export.aot_compile( - torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False - ) - kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) - res = kernel_runner.run([a, b]) - self.assertTrue(isinstance(res, list)) - self.assertTrue(len(res) == 1) - self.assertEqual(fn(a, b, alpha=2.0), res[0]) + for simdlen in [0, None]: + with torch._inductor.config.patch({"cpp.simdlen": simdlen}): + so_path = torch._export.aot_compile( + 
torch.ops.aten.add, + args=(a, b), + kwargs={"alpha": 2.0}, + same_signature=False, + ) + kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) + res = kernel_runner.run([a, b]) + self.assertTrue(isinstance(res, list)) + self.assertTrue(len(res) == 1) + self.assertEqual(fn(a, b, alpha=2.0), res[0]) def test_buffer_mutation_2(self): class Model(torch.nn.Module): @@ -1389,6 +1385,26 @@ def forward(self, inp_pos, k, v): self.check_model(model, example_inputs) self.code_check_count(model, example_inputs, "empty_strided", 2) + def test_buffer_mutation_4(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "_tensor_constant0", + torch.randint(1, size=[38], dtype=torch.int64, device="cpu"), + ) + + def forward(self, x): + return x + self._tensor_constant0.to(torch.device(type="cuda", index=0)) + + example_inputs = ( + torch.randint(1, size=[38], dtype=torch.int64, device="cuda"), + ) + torch._export.aot_compile(Model(), example_inputs) + @requires_multigpu() def test_replicate_on_devices(self): if self.device != "cuda": @@ -3065,7 +3081,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES = { # test_failures, xfail by default, set is_skip=True to skip - "test_large_grid": fail_cuda(), "test_normal_functional": fail_abi_compatible_cuda(is_skip=True), # no runtime checks for non_abi_compatible mode "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True), diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index a8e6392892f7..1a25e81ebf27 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -56,7 +56,7 @@ def __init__(self, in_channels, out_channels, device, **kwargs): self.use_scalar = scalar tensor_size = [1 for _ in range(self.conv.weight.ndim)] tensor_size[1] = self.conv.weight.size(0) - self.tensor = ( + self.tensor = torch.nn.Parameter( add_tensor if add_tensor is not None else torch.rand(tensor_size).to(device) @@ -136,7 +136,11 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): nn.Conv2d, pytorch_op, False, - add_tensor=torch.rand(32, 1, 32).to(self.device), + add_tensor=torch.rand( + 32, + 1, + 32, + ).to(self.device), expect_success=False, ) @@ -156,7 +160,7 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): nn.Conv2d, pytorch_op, False, - add_tensor=torch.tensor([2]).to(torch.int).to(self.device), + add_tensor=torch.tensor([2]).to(torch.float64).to(self.device), expect_success=False, ) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 994786740a65..3ef39adeed3d 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -12,8 +12,8 @@ from torch._dynamo import reset from torch._dynamo.utils import counters from torch._inductor import config, metrics +from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import ( - AsyncCompile, cuda_compile_command, CUDACodeCache, FxGraphCachePickler, @@ -195,9 +195,9 @@ def put(self, filename, data): num_put += 1 cache_module = ( - "triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" + "triton.fb.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( @@ -465,6 +465,81 @@ def fn(x, y): 
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + @config.patch({"fx_graph_cache": True}) + def test_cache_with_nt(self): + def gen_nt(r): + values = torch.randn(r, 16) + offsets = torch.tensor([0, 2, 3, 6, 13, r]) + return torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(nt): + if nt.values().size(0) % 16 == 0: + return nt.sin() + return nt.cos() + + inp1 = gen_nt(19) + inp2 = gen_nt(20) + + counters.clear() + torch.compile(fn)(inp1) + torch.compile(fn)(inp2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + torch.compile(fn)(inp1) + torch.compile(fn)(inp2) + # TODO(oulgen): This doesnt actually produce a cache hit. + # Despite pickling the exact same object, pickle produces different + # results. + # self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + # self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + + @config.patch({"fx_graph_cache": True}) + def test_cache_with_symint_non_arg_guard(self): + def fn(x, ref_id): + self_id = 22 + if self_id == ref_id: + x = torch.mul(x, 1.0) + else: + x = torch.mul(x, 0) + return x + + x = torch.ones(2) + + counters.clear() + torch.compile(fn, fullgraph=True, dynamic=True)(x, 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + torch.compile(fn, fullgraph=True, dynamic=True)(x, 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + + @config.patch({"fx_graph_cache": True}) + def test_cache_guard(self): + def f(x, val): + if val > 5: + return x.sin() + else: + return x.cos() + + x = torch.ones(2) + a = torch.compile(f, dynamic=True)(x, 6) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.reset() + counters.clear() + b = torch.compile(f, dynamic=True)(x, 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + + self.assertNotEqual(a, b) + class TestFxGraphCacheHashing(TestCase): def test_tensor_constants(self): @@ -527,7 +602,7 @@ def test_hash_fake_tensors(self): FxGraphCachePickler.dumps(torch.randn(3)[1:]), FxGraphCachePickler.dumps(torch.randn(3)[1:]), ) - self.assertNotEqual( + self.assertEqual( FxGraphCachePickler.dumps(torch.randn(3)[1:]), FxGraphCachePickler.dumps(torch.randn(2)), ) @@ -586,16 +661,16 @@ def test_hash_kwargs(self): ordering of the kwargs dict and any set arguments. """ # Dict order of the kwargs should not affect hashes. - details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}) - details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}) + details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, []) + details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), ) # Different kwarg values should affect hashes. 
- details1 = FxGraphHashDetails(None, [], {"a": 0}) - details2 = FxGraphHashDetails(None, [], {"a": 1}) + details1 = FxGraphHashDetails(None, [], {"a": 0}, []) + details2 = FxGraphHashDetails(None, [], {"a": 1}, []) self.assertNotEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), @@ -605,16 +680,16 @@ def test_hash_kwargs(self): # sorting and creating a new set seems to change the order. set1 = {"a", "b", "c", "d", "e", "f", "g"} set2 = set(sorted(set1)) # noqa: C414 - details1 = FxGraphHashDetails(None, [], {"a": set1}) - details2 = FxGraphHashDetails(None, [], {"a": set2}) + details1 = FxGraphHashDetails(None, [], {"a": set1}, []) + details2 = FxGraphHashDetails(None, [], {"a": set2}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), ) # But different set contents should affect hashes. - details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}) - details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}) + details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}}, []) + details2 = FxGraphHashDetails(None, [], {"a": {1, 2}}, []) self.assertNotEqual( FxGraphCachePickler.dumps(details1), FxGraphCachePickler.dumps(details2), @@ -625,11 +700,11 @@ def test_hash_config_changes(self): Test that different config settings affect hashes. """ with config.patch({"max_autotune": False}): - details1 = FxGraphHashDetails(None, [], {}) - details2 = FxGraphHashDetails(None, [], {}) + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) with config.patch({"max_autotune": True}): - details3 = FxGraphHashDetails(None, [], {}) + details3 = FxGraphHashDetails(None, [], {}, []) self.assertEqual( FxGraphCachePickler.dumps(details1), diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 87299d796f6c..e09928cf5576 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import functools +import io import logging import re import sys @@ -11,8 +12,9 @@ import torch import torch.nn as nn from torch import _inductor as inductor -from torch._dynamo import compiled_autograd +from torch._dynamo import compiled_autograd, config from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.testing._internal.logging_utils import logs_to_string @@ -54,10 +56,14 @@ def hook3(gI, gO): class TestCompiledAutograd(TestCase): def setUp(self) -> None: super().setUp() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def tearDown(self) -> None: super().tearDown() + torch._logging.set_logs(compiled_autograd_verbose=False) + config.compiled_autograd = False compiled_autograd.reset() def check_output_and_recompiles( @@ -230,6 +236,170 @@ def fn(): self.check_output_and_recompiles(fn) + def test_torch_compile_api_inductor(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn) + 
actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_aot_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="aot_eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_api_eager(self): + def fn(): + torch.manual_seed(123) + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + + res = [] + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + result.backward() + res.append(model[0].weight.grad) + res.append(model[0].bias.grad) + model.zero_grad() + return res + + expected = fn() + with config.patch(compiled_autograd=True): + compiled_fn = torch.compile(fn, backend="eager") + actual = compiled_fn() + self.assertEqual(expected, actual) + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_multiple_torch_compile(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + def fn(): + result = model(x).sum() + result.backward() + + model2 = torch.nn.Linear(4, 4) + x2 = torch.randn([1, 4]) + + def fn2(): + result = model2(x2).sum() + result.backward() + + no_ca1 = torch.compile(fn) + no_ca1() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + counters.clear() + + with config.patch(compiled_autograd=True): + with_ca = torch.compile(fn2) + with_ca() + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + counters.clear() + + no_ca2 = torch.compile(fn) + no_ca2() + self.assertEqual(counters["compiled_autograd"]["captures"], 0) + + def test_torch_compile_graph_break(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def fn(): + result = model(x).sum() + result.backward() + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_graph_break2(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + @torch._dynamo.disable() + def inner_fn(loss): + loss.backward() + + def fn(): + result = model(x).sum() + inner_fn(result) + + with config.patch(compiled_autograd=True): + opt_fn = torch.compile(fn) + opt_fn() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + + def test_torch_compile_only_backward_call(self): + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.Sigmoid(), + ) + x = torch.randn([1, 4]) + + result = model(x).sum() + with config.patch(compiled_autograd=True): + opt_bwd = torch.compile(lambda: result.backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + def test_dynamo_boxed(self): def get_placeholders(gm_): placeholders = [] @@ -289,7 +459,7 @@ def test_inputs_aliasing_bytecode_attr_mutations(self): param_proxy, activ_proxy = proxies buf = activ_proxy * 2 
torch.ops.inductor.accumulate_grad_.default(param_proxy, buf) - compiled_fn = compiler.end_capture(buf) + runtime_wrapper, compiled_fn = compiler.end_capture(buf) def bytecode_hook(code, out_code): import dis @@ -326,7 +496,9 @@ def bytecode_hook(code, out_code): torch._dynamo.reset() handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) try: - compiled_fn(inputs=[param, activ], sizes=(), hooks=()) + runtime_wrapper( + compiled_fn=compiled_fn, inputs=[param, activ], sizes=(), hooks=() + ) finally: handle.remove() @@ -409,7 +581,7 @@ def model(x): self.check_output_and_recompiles(fn) - def test_output_nodes(self): + def test_output_nodes_all_leaves(self): def fn(): y = torch.randn(1, 4, requires_grad=True) z = torch.randn(1, 4, requires_grad=True) @@ -421,7 +593,7 @@ def model(x): x = torch.randn([1, 4]) result = model(x).sum() - gy, gz = torch.autograd.grad(result, [y, z]) + gy, gz = torch.autograd.grad(result, inputs=[y, z]) assert y.grad is None assert z.grad is None yield gy @@ -429,6 +601,111 @@ def model(x): self.check_output_and_recompiles(fn) + def test_output_nodes_some_leaves(self): + def fn(): + class UnreachableBwd(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + raise RuntimeError + + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(UnreachableBwd.apply(y) * z) + + for _ in range(3): + x = torch.randn([1, 4]) + + result = model(x).sum() + gz = torch.autograd.grad(result, inputs=[z]) + assert y.grad is None + assert z.grad is None + yield gz + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_all_leaves(self): + def fn(): + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y)) + + for _ in range(3): + x = torch.randn([1, 4]) + result = model(x).sum() + out = result.backward() + assert out is None + assert y.grad is not None + assert z.grad is not None + yield y.grad + yield z.grad + y.grad = None + z.grad = None + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_some_leaves(self): + def fn(): + class UnreachableBwd(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, gO): + raise RuntimeError + + y = torch.randn(1, 4, requires_grad=True) + z = torch.randn(1, 4, requires_grad=True) + a = torch.randn(1, 4, requires_grad=True) + + def model(x): + return torch.sigmoid(x * y * z * UnreachableBwd.apply(a)) + + for _ in range(3): + x = torch.randn([1, 4]) + result = model(x).sum() + out = result.backward(inputs=[y, z]) + assert out is None + assert y.grad is not None + assert z.grad is not None + assert a.grad is None + yield y.grad + yield z.grad + y.grad = None + z.grad = None + + self.check_output_and_recompiles(fn) + + def test_no_output_nodes_different_leaves_will_recompile(self): + def fn(): + def fwd(x, y, z): + out = x * y # MulBackward0 + out2 = out * z # MulBackward0 + return out2.sum() # SumBackward0 + + x = torch.randn(5, requires_grad=True) + y = torch.randn(5, requires_grad=True) + z = torch.randn(5, requires_grad=True) + loss = fwd(x, y, z) + torch.compile(lambda: torch.autograd.backward(loss, inputs=[x]))() + yield x.grad + x.grad = None + + loss = fwd(x, y, z) + torch.compile(lambda: torch.autograd.backward(loss, inputs=[y]))() + yield y.grad + + # Guarded by TensorArg id, mismatch on last 
MulBackward0 + self.check_output_and_recompiles(fn, 2) + def test_dynamic_shapes(self): def fn(): model = torch.nn.Sequential( @@ -1490,6 +1767,147 @@ def fn(inputs): out = compiled_fn(activations) self.assertTrue(len(activations) == 0) + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_division(self): + from torch._dynamo.testing import reduce_to_scalar_loss + + model = torch.nn.Linear(10, 10, dtype=torch.float16).cuda() + inputs = torch.randn(10, 10, dtype=torch.float16).cuda() + out = model(inputs) + loss = reduce_to_scalar_loss(out) + + stderr_msgs = io.StringIO() + with mock.patch("sys.stderr", stderr_msgs), compiled_autograd.enable( + compiler_fn + ): + torch._inductor.config.triton.cudagraphs = True + loss.backward() + torch._inductor.config.triton.cudagraphs = False + + self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue()) + + def test_cudagraphs_cpu_graph(self): + from torch._dynamo.testing import reduce_to_scalar_loss + + model = torch.nn.Linear(10, 10, dtype=torch.float16) + inputs = torch.randn(10, 10, dtype=torch.float16) + out = model(inputs) + loss = reduce_to_scalar_loss(out) + + with compiled_autograd.enable(compiler_fn): + torch._inductor.config.triton.cudagraphs = True + loss.backward() + torch._inductor.config.triton.cudagraphs = False + + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_sdpa(self): + query = torch.rand( + 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True + ) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + out = torch.nn.functional.scaled_dot_product_attention(query, key, value) + + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + opt_bwd = torch.compile(lambda: out.sum().backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): + class MyFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + cpu_tensor = torch.tensor(5) + ctx.save_for_backward(x, cpu_tensor) # visible to c++/autograd + ctx.cpu_scalar = 5 # opaque to c++/autograd + return x.sum() + + @staticmethod + def backward(ctx, gO): + x, cpu_tensor = ctx.saved_tensors + expand = gO * torch.ones_like(x) + return expand * cpu_tensor * ctx.cpu_scalar + + x = torch.randn(10, requires_grad=True, device="cuda") + out = MyFn.apply(x) + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + opt_bwd = torch.compile(lambda: out.backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. + # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. 
+ self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + + @unittest.skipIf(not HAS_CUDA, "requires cuda") + def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self): + cpp_source = """ +struct CustomOpAutogradFunction : public torch::autograd::Function { + static constexpr bool is_traceable = true; + + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& x) { + const auto& cpu_tensor = torch::tensor(1); + ctx->save_for_backward({x, cpu_tensor}); + ctx->saved_data["cpu_scalar"] = 1; + return x; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_output) { + const auto& saved_variables = ctx->get_saved_variables(); + assert(saved_variables.size() == 2); + torch::Tensor x = saved_variables[0]; + torch::Tensor cpu_tensor = saved_variables[1]; + int cpu_scalar = ctx->saved_data["cpu_scalar"].toInt(); + auto expand = grad_output[0] * torch::ones_like(x); + torch::autograd::variable_list grad_inputs(1); + grad_inputs[0] = expand * cpu_tensor * cpu_scalar; // autograd engine asserts that tensors are on same device + return grad_inputs; + } +}; + +torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { + return CustomOpAutogradFunction::apply(x); +} + +TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { + m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); +} + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op", + cpp_sources=cpp_source, + functions="custom_op_backed_by_autograd_fn", + verbose=True, + ) + + x = torch.randn(2, 2, requires_grad=True, device="cuda") + with config.patch(compiled_autograd=True), inductor_config.patch( + "triton.cudagraphs", True + ): + out = torch.ops.test_cudagraphs_cpu_scalar_used_in_cpp_custom_op.custom_op_backed_by_autograd_fn( + x + ) + opt_bwd = torch.compile(lambda: out.sum().backward()) + opt_bwd() + + self.assertEqual(counters["compiled_autograd"]["captures"], 1) + # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + def test_verbose_logs_graph(self): torch._logging.set_logs(compiled_autograd_verbose=True) @@ -1673,7 +2091,18 @@ def wrap_test_class(orig_cls): return cls -known_graph_breaks_tests = {} +known_graph_breaks_tests = { + "test_hook_none", # uses assert in hook + "test_post_accumulate_grad_hook_e2e", # optim.Adam manually graph breaks + "test_tensor_hooks_inplace", # uses assert in hook + "test_tensor_hooks_inplace_over_view", # uses assert in hook + "test_grad_fn_prehooks", # uses assert in hook + "test_grad_fn_prehooks_multiple_outputs", # uses assert in hook + "test_grad_fn_prehooks_remove_hooks", # uses handle.remove() in hook + "test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook + "test_hooks", # uses assert in hook + "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose +} # These groups of tests aren't supported yet known_failures_re = re.compile( @@ -1691,23 +2120,14 @@ def wrap_test_class(orig_cls): "test_saved_variable_saved_original_inplace_detach", # AssertionError: RuntimeError not raised "test_saving_variable_to_disk", # Cannot call numel() on tensor with symbolic sizes/strides "test_setitem_mask", # torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you're - "test_tensor_hooks_inplace_over_view", # 
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} - "test_tensor_hooks_inplace", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} "test_wrapped_number_saved_variable_hooks", # RuntimeError: this hook should not be called - "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # data dependent operator: aten.allclose.default "test_accumulate_grad_tensor_reference", # backend='inner_compiler' raised: "test_anomaly_grad_warnings", # "one of the variables needed for gradient computation has been modified by an... "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd - "test_backward_with_inputs", # specifying inputs= with .backward() not yet implemented for compiled autograd "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd "test_custom_function_exception", # "Simulate error on backward pass" does not match "type object 'SimulateBackwa... "test_grad_batched_grad", # Cannot access storage of BatchedTensorImpl - "test_grad_unreachable_discovery", # specifying inputs= with .backward() not yet implemented for compiled autograd "test_index_backward_does_not_save_tensor", # dynamic shape operator: aten.nonzero.default - "test_post_accumulate_grad_hook_e2e", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_gets_cleaned_up", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_multiple_hooks", # tensor_post_acc_grad_hooks not implemented for compiled autograd - "test_post_accumulate_grad_hook_multiple_tensors", # tensor_post_acc_grad_hooks not implemented for compiled autograd "test_post_accumulate_grad_hook_ordering", # tensor_post_acc_grad_hooks not implemented for compiled autograd "test_post_accumulate_grad_hook_returns_not_None", # "hooks should return None." does not match "test_reentrant_child_error", # "Simulate error" does not match "type object 'ReentrantFunc' has no attribute... 
@@ -1739,21 +2159,20 @@ def wrap_test_class(orig_cls): "test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object - "test_hook_edge_case_when_called_with_grad", # RuntimeError: specifying inputs= with .backward() not yet - "test_hooks", # torch._dynamo.exc.Unsupported: inline in skipfiles + "test_hook_edge_case_when_called_with_grad", # retains_grad_hooks NYI "test_inplace_on_view_backward", # RuntimeError: compiled_autograd does not support create_graph - "test_multi_grad_any_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd - "test_multi_grad_all_hooks", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_multi_grad_any_hooks", # register_multi_grad_hook NYI + "test_multi_grad_all_hooks", # retains_grad_hooks NYI "test_nested_anomaly_detect_nan", # RuntimeError: compiled_autograd does not support create_graph "test_nested_anomaly_printstack_cleanup", # RuntimeError: compiled_autograd does not support create_graph "test_once_differentiable", # RuntimeError: compiled_autograd does not support create_graph - "test_prehook_ordering", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_prehook_ordering", # retains_grad_hooks NYI "test_retain_grad", # RuntimeError: retains_grad_hooks not implemented for compiled autograd "test_saved_variable_packing_unpacking_saved_original_with_hooks", # RuntimeError: compiled_autograd "test_select_sum", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients "test_unrelated_inputs", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients - "test_will_engine_execute_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd - "test_backward_to_node", # RuntimeError: specifying inputs= with .backward() not yet implemented for compiled autograd + "test_will_engine_execute_node", # retains_grad_hooks NYI + "test_backward_to_node", # retains_grad_hooks NYI "test_anomaly_detect_nan", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function aten.add.Tensor( "test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable( "test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance @@ -1770,11 +2189,7 @@ def wrap_test_class(orig_cls): "test_deep_reentrant", # torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of "test_dont_materialize_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone "test_function_returns_undefined_tensor", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function - "test_grad_fn_prehooks", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} - "test_grad_fn_prehooks_multiple_outputs", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: - "test_grad_fn_prehooks_remove_hooks", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: RemovableHandle.remove "test_grad_mode_restored_reentrant", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue - "test_hook_none", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNotNone 
"test_invalid_gradients", # AssertionError: "expected shape" does not match "The size of tensor a (5) must match "test_mark_non_differentiable_mixed", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertTrue "test_materialize_grads", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} @@ -1794,7 +2209,6 @@ def wrap_test_class(orig_cls): "test_set_materialize_non_diff_grads", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsNone "test_setup_context_when_forward_has_default_args", # torch._dynamo.exc.Unsupported: call_function args "test_simple_reentrant", # torch._dynamo.exc.Unsupported: call_method SkipFunctionVariable() sum [] {} - "test_tensor_hooks_inplace_multiple_outputs", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} "test_lobpcg", # torch._dynamo.exc.Unsupported: 'call_function LOBPCGAutogradFunction.backward in skip_files "test_backward_dict_grad_for_nontensor", # AssertionError: "non-Tensor-like types" does not match "'skip function "test_backward_dict_invalid_keys", # AssertionError: "to have keys {'x'}" does not match "'skip function @@ -1807,9 +2221,9 @@ def wrap_test_class(orig_cls): "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # AssertionError: "None or Tensor" "test_backward_tensorlist_input_requires_list_grads_with_same_numel", # AssertionError: "3 gradients "test_save_for_backward_inputs_are_namedtuple", # torch._dynamo.exc.Unsupported: 'skip function - "test_autograd_function_backed_op", # RuntimeError: compiled_args not implemented "test_setitem", # AssertionError: Tensor-likes are not close! "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) + "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads } if not HAS_CUDA: diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0888f3ad47a1..8bf9b1e6a61f 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) -from torch.testing._internal.common_utils import IS_MACOS, slowTest +from torch.testing._internal.common_utils import IS_MACOS, slowTest, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import HAS_CPU @@ -68,7 +68,17 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): ("cpp_wrapper",), is_skip=True ), } - +if TEST_WITH_ROCM: + test_failures_cpp_wrapper.update( + { + "test_linear_packed": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + "test_linear_packed_dynamic_shapes": test_torchinductor.TestFailure( + ("cpp_wrapper"), is_skip=True + ), + } + ) if config.abi_compatible: xfail_list = [ "test_conv2d_binary_inplace_fusion_failed_cpu", @@ -84,6 +94,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_maxpool2d_linear_dynamic_cpu", "test_qconv2d_relu_cpu", "test_qlinear_cpu", + "test_qlinear_add_cpu", "test_qlinear_dequant_promotion_cpu", "test_qlinear_relu_cpu", ] @@ -114,6 +125,7 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, + skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -122,6 +134,8 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func + if skip: + func = 
unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -169,6 +183,7 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} + skip: str = None for item in [ BaseTest("test_add_complex"), @@ -227,7 +242,9 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), + BaseTest( + "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" + ), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -301,18 +318,21 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), + skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -368,6 +388,7 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, + skip=item.skip, ) test_torchinductor.copy_tests( diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 5a552aa15c96..79ace98c1b96 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1884,6 +1884,14 @@ def test_ops_masked_with_bool_input(self): self.assertEqual(res_aten_eager, res) check_metrics_vec_kernel_count(1) + def test_bitwise_right_shift(self): + x = torch.randint(-1, 0, (1, 1, 1), device="cpu", dtype=torch.int64) + bit_num = 31 + res_aten_eager = torch.bitwise_right_shift(x, bit_num) + cfn = torch.compile(torch.bitwise_right_shift) + res = cfn(x, bit_num) + self.assertEqual(res_aten_eager, res) + @patch("torch.cuda.is_available", lambda: False) def test_scatter_using_atomic_add(self): def fn(a, dim, index, b): @@ -2224,7 +2232,6 @@ def get_index(): graph_lowering = GraphLowering( torch.fx.GraphModule(submodules, _graph), shape_env=None, - num_static_inputs=0, ) def set_opt_dtype(graph): @@ -2335,7 +2342,6 @@ def get_index(): graph_lowering = GraphLowering( torch.fx.GraphModule(submodules, _graph), shape_env=None, - num_static_inputs=0, ) with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( graph_lowering @@ -3755,6 +3761,20 @@ def fn(arg0_1): exactly=True, ).run(code) + def test_repeated_exp(self): + def fn(x): + y = x.sigmoid() + return y + 1, y.sum(-1) + + x = torch.randn(1000, 1000) + opt_fn = torch.compile(fn) + _, code = run_and_get_cpp_code(opt_fn, x) + FileCheck().check_count( + ".exp()", + 1, + exactly=True, + ).run(code) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index bab01927fac6..495a6362497d 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -211,6 +211,7 @@ class BaseTest(NamedTuple): BaseTest("test_reduction1"), # Reduction BaseTest("test_relu"), # multiple inputs BaseTest("test_repeat_interleave_2"), + BaseTest("test_roi_align"), BaseTest("test_scalar_input"), BaseTest("test_scaled_dot_product_attention"), BaseTest("test_scaled_dot_product_efficient_attention"), diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 
386fb36a635e..8365d216f82c 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -417,8 +417,8 @@ def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr): block_start = pid * XBLOCK offsets = block_start + tl.arange(0, XBLOCK) mask = offsets < xnumel - x = tl.load(in_out_ptr0 + offsets, mask=mask) - y = tl.load(in_ptr0 + offsets, mask=mask) + x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0) + y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0) output = x + y tl.store(in_out_ptr0 + offsets, output, mask=mask) @@ -1181,6 +1181,63 @@ def outer_reduce(x): self.assertEqual(outer_reduce(a), out) self.assertTrue("for roffset" not in code) + def test_epilogue_fusion_with_view(self): + class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.linear = torch.nn.Linear(262144, 100) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = x.view(x.size(0), -1) + return self.relu(self.linear(x)) + + m = ToyModel().to(device="cuda:0") + input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0") + from torch._inductor.utils import fresh_inductor_cache + + with fresh_inductor_cache(): + cm = torch.compile(m, mode="max-autotune") + out = cm(input_tensor) + out2 = m(input_tensor) + self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + + def test_reflection_pad_loop_order(self): + def fn(x, y): + a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect") + b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect") + return a + b + + cfn = torch.compile(fn) + a = torch.rand((10, 10, 10), device="cuda") + b = torch.rand((10, 10, 10), device="cuda") + expect = fn(a, b) + actual, code = run_and_get_code(cfn, a, b) + self.assertEqual(expect, actual) + + # Expect the code iterates in contiguous order, and is not tiled + kernel_code = "\n".join(code[0].split("\n")[50:64]) + self.assertExpectedInline( + kernel_code, + """\ +@triton.jit +def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 4000 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex % 20 + x1 = (xindex // 20) % 20 + x2 = (xindex // 400) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') + tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 33a179a9abc7..ac26f6a6656c 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -6,7 +6,8 @@ import torch from torch._inductor import config -from torch._inductor.codecache import AsyncCompile, CUDACodeCache +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.codecache import CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase diff --git a/test/inductor/test_cudagraph_trees.py 
b/test/inductor/test_cudagraph_trees.py index 1ac9af7bc6e7..7e8b9fce2b3b 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -648,7 +648,9 @@ def get_aligned_inputs(): with mode: inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] - compiled_f = compile_fx_inner(mod, inps, num_fixed=1, cudagraphs=True) + compiled_f = compile_fx_inner( + mod, inps, static_input_idxs=[0], cudagraphs=True + ) def get_unaligned_inputs(): return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] @@ -1729,6 +1731,189 @@ def test_storage_access_error(self): with self.assertRaisesRegex(Exception, "custom error msg"): device = x.untyped_storage() + def test_static_inputs_address_mutation_log(self): + class Goo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2, device="cuda") + + def forward(self, x) -> torch.Tensor: + return self.linear(x) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.static_tensor = torch.zeros((2, 2), device="cuda") + self.goo = Goo() + + def forward(self, x) -> torch.Tensor: + self.static_tensor.add_(torch.ones((2, 2), device="cuda")) + return self.static_tensor + x + self.goo(x) + + foo = Foo() + foo = torch.compile(foo, mode="reduce-overhead") + inp = torch.rand((2, 2), device="cuda") + + for _ in range(3): + foo(inp) + + # mutates static input tensors' addresses + foo.static_tensor = torch.ones((2, 2), device="cuda") + foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda")) + + with self.assertRaisesRegex( + Exception, + r"static input data pointer changed.\n" + r"input name: primals_2. data pointer changed from .* to .*. input stack trace: None\n" + r"input name: primals_3. data pointer changed from .* to .*. 
input stack trace:.*," + r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n\n", + ): + self.curr_node().run( + [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp] + ) + + def run_static_input_param_test(self, fn_eager, num_graphs): + with torch.device("cuda"): + fn_compiled = torch.compile(fn_eager, mode="reduce-overhead") + + def run_iter(param, fn): + fwd_output = fn(torch.ones(2, 2), param) + fwd_output.sum().backward() + grad_output = param.grad.clone().detach() + param.grad = None + return fwd_output, grad_output + + def loop(param): + exp_output, exp_grad = run_iter(param, fn_eager) + for _ in range(5): + compiled_output, compiled_grad = run_iter(param, fn_compiled) + self.assertEqual(exp_output, compiled_output) + self.assertEqual(exp_grad, compiled_grad) + + p1 = torch.nn.Parameter(torch.rand([2, 2])) + loop(p1) + + p2 = torch.nn.Parameter(torch.rand([2, 2])) + loop(p2) + + # Run p1 again to ensure we reuse the previous recording + loop(p1) + + self.assertEqual(self.get_manager().new_graph_id().id, num_graphs) + + def _module_test(self, mod): + with torch.device("cuda"): + + def fn(x, mod): + return mod(x) + + fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True) + + def run_test_iter(mod, fn): + fwd_output = fn(torch.ones(2, 2), mod) + fwd_output.sum().backward() + grad_output = mod.weight.grad.clone().detach() + mod.zero_grad() + return fwd_output, grad_output + + def run_test(): + exp_output, exp_grad = run_test_iter(mod, fn) + for _ in range(5): + compiled_output, compiled_grad = run_test_iter(mod, fn_compiled) + self.assertEqual(exp_output, compiled_output) + self.assertEqual(exp_grad, compiled_grad) + + run_test() + old = mod.weight.data + mod.weight.data = torch.rand_like(mod.weight.data) + run_test() + # Run original version to verify we reuse the other recording + mod.weight.data = old + run_test() + + # Fwd + bwd graphs for each version of the function => 4 graphs + self.assertEqual(self.get_manager().new_graph_id().id, 4) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_single_compile_param_inputs(self): + # Verify that we can record multiple cudagraphs for a single + # compiled function with param inputs + def fn(x, y): + return x * y + + # Fwd + bwd graphs for each version of the function => 4 graphs + self.run_static_input_param_test(fn, 4) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_single_compile_builtin_module(self): + # Verify that we don't recompile when changing the param of a builtin module + # and that we record another cudagraph + # Note: Linear is a builtin module so we enable that config setting above + self._module_test(torch.nn.Linear(2, 3, device="cuda")) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_custom_module(self): + # Test that we can correctly dispatch multiple graphs + # if params of a custom module change + class TestModule(torch.nn.Module): + def __init__(self, param) -> None: + super().__init__() + self.weight = param + + def forward(self, x): + return self.weight * x + + 
self._module_test( + TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda"))) + ) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_child_node(self): + # Test that we can correctly dispatch multiple graphs if a child node + # in the tree has stable input pointers change + def fn(x, p): + # Graph 1 + y = x * x + torch._dynamo.graph_break() + # Graph 2 + return y * p + + # We have 5 graphs here + # Graph 1 + # / \ + # Graph 2 w/ p1 Graph 2 w/ p2 + # and then two backward graphs + self.run_static_input_param_test(fn, 5) + + @torch._inductor.config.patch("triton.cudagraphs", True) + @torch._dynamo.config.patch("error_on_recompile", True) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_multi_dispatch_parent_node(self): + def fn(x, p): + # Graph 1 + y = x * p + torch._dynamo.graph_break() + # Graph 2 + return y + x + + # We have 6 graphs here + # Graph 1 w/ p1 Graph 1 w/ p2 + # | | + # Graph 2 (v1) Graph 2 (v2) + # There are two versions of graph 2 because + # we re-record due to different memory state after running the + # two versions of Graph 1 + # and then two backward graphs + self.run_static_input_param_test(fn, 6) + instantiate_parametrized_tests(CudaGraphTreeTests) if __name__ == "__main__": diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index c65b7585f9f3..9def5230bdba 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -158,7 +158,7 @@ def test_conv_bn_eval( out_eager = mod_eager(inp) out_optimized = mod_optimized(inp) - self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5) + self.assertEqual(out_optimized, out_eager, atol=3e-04, rtol=1e-5) out_eager.mean().backward() out_optimized.mean().backward() @@ -170,7 +170,7 @@ def test_conv_bn_eval( out_eager_bw = mod_eager(inp_bw) out_optimized_bw = mod_optimized(inp_bw) - self.assertEqual(out_eager_bw, out_optimized_bw, atol=2e-04, rtol=1e-5) + self.assertEqual(out_eager_bw, out_optimized_bw, atol=3e-04, rtol=1e-5) current_value = counters["inductor"]["efficient_conv_bn_eval"] self.assertEqual( current_value - original_value, test_class.expected_optimization_count diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index bc688ab834cb..c6f03052f37a 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -31,7 +31,10 @@ # Skip tests if Triton is not available supported_platform = skipUnless( - torch.cuda.is_available() and has_triton() and torch.version.hip is None, + torch.cuda.is_available() + and has_triton() + and torch.version.hip is None + and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) @@ -144,37 +147,29 @@ def _check_equal( ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() + if torch.isnan(compiled_error).any() and not torch.isnan(ref_error).any(): + self.assertTrue(False, "Output/Grad with NaN") if compiled_error > ref_error * fudge_factor: name = tensor_name if tensor_name is not None else "" msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." 
self.assertTrue(False, msg) - def run_test( + def _check_out_and_grad( self, - score_mod: Callable, - dtype: torch.dtype = torch.float16, - B: int = B, - H: int = H, - S: int = S, - D: int = D, + golden_out: torch.Tensor, + ref_out: torch.Tensor, + compiled_out: torch.Tensor, + q_gold: torch.Tensor, + q_ref: torch.Tensor, + q: torch.Tensor, + k_gold: torch.Tensor, + k_ref: torch.Tensor, + k: torch.Tensor, + v_gold: torch.Tensor, + v_ref: torch.Tensor, + v: torch.Tensor, ): - sdpa_partial = create_attention(score_mod) - compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) - q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) - q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) - golden_out = sdpa_partial(q_gold, k_gold, v_gold) - ref_out = sdpa_partial(q_ref, k_ref, v_ref) - compiled_out = compiled_sdpa(q, k, v) - - backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - - golden_out.backward(backward_grad.to(torch.float64)) - ref_out.backward(backward_grad) - compiled_out.backward(backward_grad) - + dtype = ref_out.dtype with torch.no_grad(): # Note, it seems like we really are less accurate than the float32 # computation, likely due to the online softmax @@ -195,11 +190,62 @@ def run_test( self._check_equal( k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" ) - v_fudge_factor = 8 * fudge_factor + v_fudge_factor = 4 * fudge_factor self._check_equal( v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" ) + def run_test( + self, + score_mod: Callable, + dtype: torch.dtype = torch.float16, + Q_B: int = B, + Q_H: int = H, + Q_S: int = S, + Q_D: int = D, + KV_B: int = B, + KV_H: int = H, + KV_S: int = S, + KV_D: int = D, + ): + sdpa_partial = create_attention(score_mod) + compiled_sdpa = torch.compile(sdpa_partial) + q = torch.randn( + (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True + ) + k = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) + v = torch.randn( + (KV_B, KV_H, KV_S, KV_D), dtype=dtype, device="cuda", requires_grad=True + ) + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) + q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + golden_out = sdpa_partial(q_gold, k_gold, v_gold) + ref_out = sdpa_partial(q_ref, k_ref, v_ref) + compiled_out = compiled_sdpa(q, k, v) + + backward_grad = torch.randn((Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda") + + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + self._check_out_and_grad( + golden_out, + ref_out, + compiled_out, + q_gold, + q_ref, + q, + k_gold, + k_ref, + k, + v_gold, + v_ref, + v, + ) + def run_dynamic_test( self, score_mod: Callable, @@ -211,24 +257,34 @@ def run_dynamic_test( ): sdpa_partial = create_attention(score_mod) # The first eager batch, shape (B, H, S, D) - q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out1 = sdpa_partial( - q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) - ) - ref_out1 = sdpa_partial(q1, k1, v1) + q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k1 = 
torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1) + q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64) + ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref) + golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold) + + backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out1.backward(backward_grad1.to(torch.float64)) + ref_out1.backward(backward_grad1) # The second eager batch, shape (B * 2, H, S / 2, D) B = int(B * 2) S = int(S / 2) - q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out2 = sdpa_partial( - q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) - ) - ref_out2 = sdpa_partial(q2, k2, v2) + q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2) + q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64) + ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref) + golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold) + + backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out2.backward(backward_grad2.to(torch.float64)) + ref_out2.backward(backward_grad2) # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing. # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation. @@ -236,20 +292,41 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 - - self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) + compiled_out1.backward(backward_grad1) + + self._check_out_and_grad( + golden_out1, + ref_out1, + compiled_out1, + q1_gold, + q1_ref, + q1, + k1_gold, + k1_ref, + k1, + v1_gold, + v1_ref, + v1, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. 
compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) + compiled_out2.backward(backward_grad2) + self._check_out_and_grad( + golden_out2, + ref_out2, + compiled_out2, + q2_gold, + q2_ref, + q2, + k2_gold, + k2_ref, + k2, + v2_gold, + v2_ref, + v2, + ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( @@ -343,6 +420,25 @@ def test_builtin_score_mods_automatic_dynamic( ): self.run_automatic_dynamic_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("score_mod", test_score_mods) + def test_builtin_score_mods_different_seqlen( + self, dtype: torch.dtype, score_mod: Callable + ): + self.run_test( + score_mod, + dtype, + B, + H, + S // 2, # Seqlen of Q is different from seqlen of K/V + D, + B, + H, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): @@ -680,11 +776,13 @@ def f(q, k, v): metrics.reset() f(q, k, v) accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize - logsumexp_bytes = 1 * 8 * 1024 * torch.float32.itemsize num_accesses = 4 # q, k, v reads, one output. - self.assertEqual( - metrics.num_bytes_accessed, accessed_bytes * num_accesses + logsumexp_bytes - ) + # TODO: Get rid of this fudge factor + # We need this fudge factor for now, since + # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow) + # 2. We also write the extraneous logsumexp + num_accesses += 2 + self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses) @supported_platform @skip("Triton bug ") # https://github.com/pytorch/pytorch/issues/124571 @@ -719,14 +817,6 @@ def test_mixed_dtypes_fails(self): ): _flex_attention(query, key, value, _identity) - @supported_platform - def test_different_sequence_length_fails(self): - query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda") - key = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - value = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") - with self.assertRaisesRegex(ValueError, "NYI: The target sequence length"): - _flex_attention(query, key, value, _identity) - @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune(self): @@ -960,10 +1050,10 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): fw_graph = self.fw_graph joint_graph = self.joint_graph - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = 
None getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0] getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1] getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index e53ab76036d6..c17d78f628a3 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -280,6 +280,7 @@ def dot_prod_attention( self._check_common(dot_prod_attention) self._check_common(checkpoint_wrapper(dot_prod_attention)) + @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 def _test_sdpa_rewriter_3(self): def dot_prod_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool @@ -296,6 +297,7 @@ def dot_prod_attention( checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True ) + @skipIfRocm # AssertionError: expected size 4==4, stride 32==64 at dim=0 def _test_sdpa_rewriter_4(self): def dot_prod_attention( query: torch.Tensor, diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py new file mode 100644 index 000000000000..678458284c4f --- /dev/null +++ b/test/inductor/test_graph_transform_observer.py @@ -0,0 +1,72 @@ +# Owner(s): ["module: inductor"] +import glob +import math +import os +import shutil +import tempfile + +import torch +import torch._dynamo +import torch._inductor.config as inductor_config +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CUDA + +try: + import pydot # noqa: F401 + + HAS_PYDOT = True +except ImportError: + HAS_PYDOT = False + + +HAS_DOT = True if shutil.which("dot") is not None else False + + +class TestGraphTransformObserver(TestCase): + @skipIfRocm + def test_sdpa_rewriter(self): + if not ( + HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT + ): + return + + def dot_prod_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(math.sqrt(key.shape[-1])) + .softmax(dim=-1) + .matmul(value) + ) + + log_url = tempfile.mkdtemp() + inductor_config.trace.log_url_for_graph_xform = log_url + inductor_config.force_disable_caches = True + compiled_fn = torch.compile(dot_prod_attention, fullgraph=True) + + tensor_shape = (4, 2, 16, 32) + q = torch.randn(tensor_shape, device="cuda") + k = torch.randn(tensor_shape, device="cuda") + v = torch.randn(tensor_shape, device="cuda") + compiled_fn(q, k, v) + + found_input_svg = False + found_output_svg = False + for filepath_object in glob.glob(log_url + "/*"): + if os.path.isfile(filepath_object): + if filepath_object.endswith("input_graph.svg"): + found_input_svg = True + elif filepath_object.endswith("output_graph.svg"): + found_output_svg = True + + self.assertTrue(found_input_svg) + self.assertTrue(found_output_svg) + + +if __name__ == "__main__": + if IS_LINUX: + run_tests() diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index b203a0f63e8b..96255c54147e 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -2,6 +2,7 @@ import collections import 
unittest +from typing import List import torch import torch._inductor @@ -22,6 +23,37 @@ requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +class TestHighwaySelfGating(torch.nn.Module): + def __init__( + self, + d_model: int, + size: int, + device="cuda", + ) -> None: + super().__init__() + self.size = size + self.device = device + self.gating_proj = torch.nn.Linear(d_model, d_model).to(self.device) + self.transform_proj = torch.nn.Linear(d_model, d_model).to(self.device) + self.gating_func = torch.nn.Sigmoid().to(self.device) + + self.d_model = d_model + + def forward( + self, + inputs: List[torch.Tensor], + ) -> torch.Tensor: + results = [] + for i in range(self.size): + x = inputs[i] + gating_proj = self.gating_proj(x) + transform_proj = self.transform_proj(x) + x = gating_proj * self.gating_func(transform_proj) + results.append(x) + + return torch.cat(results, dim=-1) + + class MyModule(torch.nn.Module): def __init__(self, z: int, has_bias: bool, device="cuda") -> None: super().__init__() @@ -221,6 +253,25 @@ def forward(self, x): return torch.cat(div, dim=1) +class TestPoitwiseOpsPostGrad(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + def forward(self, x): + inputs = torch.ops.aten.split(x.to(self.device), 500, dim=1) + x_split = torch.ops.aten.split(inputs[0].to(self.device), 50, dim=1) + y_split = torch.ops.aten.split(inputs[1].to(self.device), 50, dim=1) + tanh_1 = [torch.ops.aten.tanh(x_split[i]) for i in range(len(x_split))] + tanh_2 = [torch.ops.aten.tanh(y_split[i]) for i in range(len(y_split))] + sigmoid_1 = [torch.ops.aten.sigmoid(tanh_1[i]) for i in range(len(tanh_1))] + sigmoid_2 = [torch.ops.aten.sigmoid(tanh_2[i]) for i in range(len(tanh_2))] + relu_1 = [torch.ops.aten.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))] + relu_2 = [torch.ops.aten.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))] + add = [torch.ops.aten.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))] + return torch.cat(add, dim=1) + + @requires_cuda @torch._inductor.config.patch( pre_grad_fusion_options={ @@ -400,6 +451,75 @@ def test_pointwise_op_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "batch_aten_relu": {}, + "batch_aten_sigmoid": {}, + "batch_aten_tanh": {}, + "unbind_stack_aten_pass": {}, + }, + ) + def test_pointwise_op_fusion_post_grad(self): + counters.clear() + module = TestPoitwiseOpsPostGrad("cuda") + input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1) + self.assertEqual(counters["inductor"]["batch_aten_relu"], 1) + self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2) + ref.sum().backward() + res.sum().backward() + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "batch_linear_post_grad": { + "shape_broadcast_batch_linear": True, + "fuse_nodes_with_same_users": True, + }, + "batch_aten_mul": {"fuse_nodes_with_same_parent": False}, + "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": 
True}, + "batch_aten_add": {"fuse_nodes_with_same_parent": True}, + "normalization_aten_pass": {}, + "unbind_stack_aten_pass": {}, + }, + ) + def test_gate_fusion_post_grad(self): + counters.clear() + size = 20 + module = TestHighwaySelfGating(d_model=10, size=size) + input = [ + [ + torch.randn(10, 10, requires_grad=True, device="cuda") + for i in range(size) + ] + ] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["batch_linear_post_grad"], 2) + self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) + self.assertEqual(counters["inductor"]["batch_aten_mul"], 1) + self.assertEqual(counters["inductor"]["batch_aten_add"], 2) + self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 5) + ref.sum().backward() + res.sum().backward() + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + class TestBMMFusionModule(torch.nn.Module): def __init__(self): diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 52227c20d1ff..9b923bd1981d 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -3,6 +3,7 @@ import unittest import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import HalideCodeCache from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta from torch._inductor.test_case import run_tests, TestCase diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 299a619f9cd6..19a736160908 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -1,17 +1,32 @@ # Owner(s): ["module: inductor"] +import os +import unittest + import sympy +import torch + from torch._inductor.codegen.cpp import cexpr from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr +from torch._inductor.runtime.runtime_utils import do_bench_gpu from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_triton_code from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.utils._sympy.functions import ( + FloorDiv, + ModularIndexing, + RoundDecimal, + RoundToInt, +) + +DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" class TestIndexingSimplification(InductorTestCase): @@ -159,6 +174,73 @@ def test_indexing_join(self): self.assertEqual(simplified, FloorDiv(i0, 3)) self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) + def test_modular_indexing_pairs_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 32 + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + expected = ModularIndexing(x, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expected, actual) + self.assertNotEqual(expr2, actual) + + def test_modular_indexing_pairs_not_merged(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + a = 1024 + b = 3 # pick a 'b' 
that we can not merge + expr1 = ModularIndexing(x, 1, a) + expr2 = ModularIndexing(expr1, 1, b) + + actual = sizevars.combine_modular_indexing_pairs(expr2) + self.assertEqual(expr2, actual) + self.assertNotEqual(ModularIndexing(x, 1, b), actual) + + def test_expand_floor_div_skipped(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = FloorDiv(x, 2) + FloorDiv(y, 3) + # The expression can not be simplified since there are multiple + # FloorDiv. We return False in that case + self.assertFalse(sizevars.expand_floor_div(expr)) + + def test_expand_floor_div_applied(self): + sizevars = SizeVarAllocator() + x = sympy.Symbol("x", integer=True, positive=True) + y = sympy.Symbol("y", integer=True, positive=True) + + expr = x * 5 + FloorDiv(y, 3) + actual, denominator = sizevars.expand_floor_div(expr) + self.assertNotEqual(expr, actual) + expected = FloorDiv(x * 15 + y, 3) + self.assertEqual(expected, FloorDiv(actual, denominator)) + + @unittest.skipUnless(HAS_CUDA, "Need GPU for this test") + def test_int8_unpack(self): + @torch.compile + def f(x): + first_elements = x >> 4 + second_elements = x & 15 + unpacked = torch.stack([first_elements, second_elements], dim=-1).view( + *x.size()[:-1], -1 + ) + return unpacked * 2 + + x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device="cuda") + + triton_code = run_and_get_triton_code(f, x) + # Make sure the 2 load uses simpified indexing rather than something like + # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), + self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) + if DO_PERF_TEST: + ms = do_bench_gpu(lambda: f(x)) + print(f"{ms=:.03f}") + class ExprPrinterTests(InductorTestCase): def test_print_pow(self): @@ -168,21 +250,11 @@ def test_print_pow(self): common_cases = [ # expr, result - # Test exprs. - ( - s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), - lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", - ), - (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. 
( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow - ( - sympy.Pow(s1 + s2, -3), - lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", - ), ] gpu_cases = common_cases + [ @@ -231,12 +303,10 @@ def test_print_ceil(self): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = Round(sympy.Symbol("x", integer=True) / 2) + expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline( - texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" - ) + self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -251,45 +321,18 @@ def test_print_round_decimal(self, ndigits): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) - expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) - if ndigits >= 0: - for do_print in [pexpr, cexpr, texpr]: - self.assertEqual(do_print(expr), "x") - else: - self.assertEqual(pexpr(expr), f"round(x, {ndigits})") - for do_print in [cexpr, texpr]: - with self.assertRaisesRegex( - ValueError, "only non-negative ndigits are currently supported" - ): - do_print(expr) - def test_print_floor_div(self): - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.Symbol("s2", integer=integer) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - if integer: - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - else: - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast(s2))", - ) - - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - if integer: - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") - else: - self.assertEqual(pexpr(expr), "(s1 // (-1))") - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.Symbol("s2", integer=True) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") + + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") def test_print_Min_Max(self): cases = ( @@ -315,7 +358,6 @@ def test_print_Min_Max(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA if HAS_CPU or HAS_CUDA: run_tests("sympy") diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 7d1688b366c4..4b6c04403002 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -338,6 +338,9 @@ def foo(mod, inp): ).run(code[0]) self.assertEqual(out_eager, out) + # With inlining of inbuilt nn modules, Dynamo traces the innards of inbuilt + # module and does not modify the eager module. 
+ @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) def test_error_on_eager(self): mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device) diff --git a/test/inductor/test_kernel_benchmark.py index 23804e08f23f..ffe0300d8aad 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -6,6 +6,7 @@ from unittest.mock import patch import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_loop_ordering.py new file mode 100644 index 000000000000..5261c2325834 --- /dev/null +++ b/test/inductor/test_loop_ordering.py @@ -0,0 +1,59 @@ +# Owner(s): ["module: inductor"] + +import torch +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import same +from torch._inductor import config as inductor_config, metrics +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA + +if HAS_CUDA: + torch.set_default_device("cuda") + + +@inductor_config.patch( + { + "benchmark_kernel": True, + "triton.unique_kernel_names": True, + } +) +class LoopOrderingTest(TestCase): + def do_acc_test(self, f, *args): + expect = f(*args) + actual = torch.compile(f)(*args) + self.assertTrue(same(expect, actual, tol=1e-3)) + + def test_for_reordering_reindex(self): + """ + ComputedBuffer.iter_reoredering_reindex can cause some fusion + opportunities to be skipped. + + In this test case, Inductor generates 2 triton kernels before. + By removing ComputedBuffer.iter_reoredering_reindex, we can fuse those + two kernels into a single one. + """ + + def f(x, y): + """ + Add a matmul since inductor may force layout for output. + """ + return (x.sum(dim=-1) + 1) @ y + + A, B = 20, 30 + # Make the first 2 dimensions not able to merge on purpose so that + # ComputedBuffer.iter_reoredering_reindex will be updated.
+ x = rand_strided([A, A, B], [B, B * A + 300, 1], device="cuda") + y = torch.randn(A, A) + + self.do_acc_test(f, x, y) + self.assertEqual(1, metrics.generated_kernel_count) + expected_num_bytes = 0 + expected_num_bytes += A * A * B + A * A # for the fused reduction + expected_num_bytes += A * A * 3 # for matmul + expected_num_bytes *= x.itemsize + self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) + + +if __name__ == "__main__": + if HAS_CUDA: + run_tests() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index cef7d610ee4d..176f0dda606d 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -9,7 +9,7 @@ from torch import multiprocessing as mp, nn from torch._dynamo import reset from torch._dynamo.exc import BackendCompilerFailed -from torch._dynamo.testing import reset_rng_state +from torch._dynamo.testing import rand_strided, reset_rng_state from torch._inductor import config from torch._inductor.autotune_process import ( BenchmarkRequest, @@ -267,9 +267,9 @@ def put(self, filename, data): num_put += 1 cache_module = ( - "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" + "triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend" if config.is_fbcode() - else "triton.runtime.cache.RedisRemoteCacheBackend" + else "torch._inductor.remote_cache.RedisRemoteCacheBackend" ) with config.patch( @@ -674,12 +674,10 @@ def test_non_contiguous_input_mm(self): Make sure the triton template can work with non-contiguous inputs without crash. Check https://github.com/pytorch/pytorch/issues/125437 for more details. """ - x = torch.empty_strided( + x = rand_strided( (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") @torch.compile(mode="max-autotune") def f(x, y): @@ -687,16 +685,14 @@ def f(x, y): ref = x @ y act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_addmm(self): - b = torch.empty((768), dtype=torch.bfloat16, device="cuda") - x = torch.empty_strided( + b = torch.randn((768), dtype=torch.bfloat16, device="cuda") + x = rand_strided( (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda") @torch.compile(mode="max-autotune") def f(x, y): @@ -704,13 +700,13 @@ def f(x, y): ref = torch.addmm(b, x, y) act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_bmm(self): - x = torch.empty_strided( + x = rand_strided( (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device="cuda" ) - y = torch.empty_strided( + y = rand_strided( (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device="cuda" ) @@ -720,22 +716,14 @@ def f(x, y): ref = torch.bmm(x, y) act = f(x, y) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) def test_non_contiguous_input_mm_plus_mm(self): - x1 = torch.empty_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" - ) - y1 = 
torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + x1 = rand_strided((50257, 32768), (1, 50304), device="cuda") + y1 = rand_strided((32768, 768), (768, 1), device="cuda") - x2 = torch.empty_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device="cuda" - ) - y2 = torch.empty_strided( - (32768, 768), (768, 1), dtype=torch.bfloat16, device="cuda" - ) + x2 = rand_strided((50257, 32768), (1, 50304), device="cuda") + y2 = rand_strided((32768, 768), (768, 1), device="cuda") @torch.compile(mode="max-autotune") def f(x1, y1, x2, y2): @@ -743,7 +731,7 @@ def f(x1, y1, x2, y2): ref = x1 @ y1 + x2 @ y2 act = f(x1, y1, x2, y2) - self.assertTrue(torch.allclose(ref, act, atol=4 * 1e-3, rtol=4 * 1e-3)) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) @config.patch( max_autotune=True, diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 1ec1dd9f89e9..78c7086972eb 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -56,7 +56,7 @@ def test_python_wrapper(self): ).check_next( "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" ).check( - "buf1 = alloc_from_pool(pool1, align((4*s0) + (4*s0*((-1) + s0)))," + "buf1 = alloc_from_pool(pool1, align(4*(s0*s0))," ).run( code ) @@ -74,7 +74,7 @@ def test_cpp_wrapper(self): ).check_next( "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});" ).check( - "auto buf1 = alloc_from_pool(pool1, align((4L*s0) + (4L*s0*((-1L) + s0)))," + "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast(s0*s0)))," ).run( code ) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 756de35df84c..0490c3bcb9f3 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -10,7 +10,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._export import capture_pre_autograd_graph -from torch._inductor import config +from torch._inductor import config, metrics from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantize_pt2e import ( @@ -83,6 +83,36 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer +def cal_conv_generated_kernel_number(mod, input, dtype): + # this function decides how many kernels are generated + # while testing conv2d/3d/deconv2d + # the assumptions are: + # (1) There will be a to_dtype kernel for the input for low-precision (lp) dtypes + # (2) inductor always uses channels_last format, so there will + # be a to_channel_last kernel for the input + # (3) to_dtype and to_channel_last for the input can be fused + # (4) inductor always gets channels_last format from mkldnn_conv_pointwise(binary), + # and forces the output to have the same stride as eager.
+ # So there will be a to_contiguous kernel for the output if the eager output is contiguous + mod = copy.deepcopy(mod) + input = input.clone() + if dtype == torch.float32: + maybe_autocast = contextlib.nullcontext() + else: + maybe_autocast = torch.cpu.amp.autocast(dtype=dtype) + with torch.no_grad(), maybe_autocast: + output = mod(input) + input_kernel, output_kernel = 0, 0 + if ( + input.is_contiguous(memory_format=torch.contiguous_format) + or dtype != torch.float32 + ): + input_kernel = 1 + if output.is_contiguous(memory_format=torch.contiguous_format): + output_kernel = 1 + return input_kernel + output_kernel + + @config.patch({"freezing": True}) class TestPatternMatcherBase(TestCase): def _check_unary_is_decomposed(self, unary_fn): @@ -203,6 +233,7 @@ def _test_code_common( rtol=1.3e-6, check_quantization=False, check_dynamic=None, + num_include_ops=None, ): with torch.no_grad(): clone_inputs = self._clone_inputs(inputs) @@ -215,6 +246,12 @@ ) for op in include_ops: self.assertIn(op, source_code) + if num_include_ops is not None: + assert len(include_ops) == len(num_include_ops) + for i in range(len(include_ops)): + self.assertEqual( + source_code.count(include_ops[i]), num_include_ops[i] + ) for op in exclude_ops: self.assertNotIn(op, source_code) if check_dynamic is not None: @@ -264,6 +301,7 @@ def forward(self, x): memory_format, dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -284,10 +322,18 @@ # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv2d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=4) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv3d_unary_cpu(self): self._test_conv_unary_cpu_base(dim=5) @@ -321,6 +367,7 @@ def forward(self, x): dtypes.append(torch.float16) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: + metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 mod = mod @@ -335,6 +382,8 @@ self._test_common( mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype ) + # only 1 kernel is generated for "to" + self.assertEqual(metrics.generated_kernel_count, 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self): @@ -354,6 +403,42 @@ def forward(self, x): matcher_nodes = 1 self._test_common(mod, (v,), matcher_count, matcher_nodes) + def test_linear_add_bias(self): + class M(torch.nn.Module): + def __init__(self, dtype, unary_fn): + super().__init__() + self.linear = torch.nn.Linear(10, 64, bias=False) + self.bias = torch.randn(64).to(dtype=dtype) + self.unary_fn = unary_fn + + def forward(self, x): + x = self.linear(x) + self.bias + return self.unary_fn(x) + + dtypes = [] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + dtypes.append(torch.float16) + options = itertools.product(unary_list, dtypes) + for unary_fn, dtype in options: + metrics.reset() + mod = M(dtype, unary_fn).eval() + v = torch.randn(2, 10) + matcher_count = 3 + # Add 1 for weight packing pass, add 2 for bias folding pass.
+ matcher_nodes = unary_list[unary_fn] + 3 + if self._check_unary_is_decomposed(unary_fn): + # Has extra dtype conversion nodes for autocast. + matcher_nodes += 2 + self._test_common( + mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype + ) + self.assertEqual(metrics.generated_kernel_count, 1) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv_transpose2d_unary(self): class M(torch.nn.Module): def __init__( @@ -386,6 +471,7 @@ def forward(self, x): ) for unary_fn, memory_format, dtype in options: + metrics.reset() x_shape = (1, 3, 28, 28) mod = M(unary_fn).eval() @@ -401,6 +487,8 @@ def forward(self, x): # Has extra dtype conversion nodes for autocast. match_nodes += 2 self._test_common(mod, (v,), 2, match_nodes, check_autocast=dtype) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) def _test_conv_binary_base(self, dim=4): assert dim == 4 or dim == 5 @@ -430,19 +518,29 @@ def forward(self, x): else: return self.binary_fn(x1, x2) + dtypes = [ + torch.float, + ] + if torch.ops.mkldnn._is_mkldnn_bf16_supported(): + dtypes.append(torch.bfloat16) + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + dtypes.append(torch.float16) cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d test_memory_format = [torch.contiguous_format, cl_format] options = itertools.product( binary_list, [True, False], test_memory_format, + dtypes, ) for ( binary_fn, has_relu, memory_format, + dtype, ) in options: + metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) else: @@ -457,11 +555,21 @@ def forward(self, x): match_nodes = binary_list[binary_fn][1] if has_relu: match_nodes += 1 - self._test_common(mod, (v,), match_count, match_nodes + 2) + self._test_common( + mod, (v,), match_count, match_nodes + 2, check_autocast=dtype + ) + generated_kernel_count = cal_conv_generated_kernel_number(mod, v, dtype) + self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv2d_binary(self): self._test_conv_binary_base(dim=4) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm def test_conv3d_binary(self): self._test_conv_binary_base(dim=5) @@ -489,7 +597,7 @@ def forward(self, x, y): ) out_feature = 30 for binary_fn, input_shape, bias, dtype in options: - torch._dynamo.reset() + metrics.reset() # addmm(mm) + (linear+add) match_count = 2 match_nodes = 3 @@ -498,13 +606,20 @@ def forward(self, x, y): # view + linear + view(joint_graph+freeze pass) match_count = match_count + 5 if is_inplace else match_count + 3 match_nodes = match_nodes + 7 if is_inplace else match_nodes + 5 - mod = M(binary_fn, input_shape[-1], out_feature, bias).to(dtype).eval() - v = torch.randn(input_shape).to(dtype) + mod = M(binary_fn, input_shape[-1], out_feature, bias).eval() + v = torch.randn(input_shape) other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) - mod_c = torch.compile(mod) - out, code = run_and_get_code(mod_c, v, other) - self.assertEqual(out, mod(v, other), rtol=1e-2, atol=1e-2) - # TODO - assert fusions work code + self._test_common( + mod, + ( + v, + other, + ), + match_count, + match_nodes, + check_autocast=dtype, + ) + self.assertEqual(metrics.generated_kernel_count, 1) def test_multi_linear_share_same_input(self): # llama pattern. 
@@ -1700,6 +1815,32 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, is_qat=is_qat, ) + if torch._inductor.config.cpp_wrapper: + # For CPP wrapper + self._test_code_common( + mod, + (v,), + [ + "op_qlinear_pointwise.call", + "op_qlinear_pointwise_binary.call", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) + else: + # For python wrapper + self._test_code_common( + mod, + (v,), + [ + "torch.ops.onednn.qlinear_pointwise.default", + "torch.ops.onednn.qlinear_pointwise.binary", + ], + [], + check_quantization=True, + num_include_ops=[2, 2], + ) @skipIfNoDynamoSupport @skipIfNoONEDNN diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index e1bb0ad36d0b..3350e8e895f3 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,6 +1,4 @@ # Owner(s): ["module: functorch"] -import unittest - import torch import torch._dynamo import torch._functorch @@ -8,30 +6,14 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import ( - find_library_location, - IS_FBCODE, - IS_MACOS, - IS_SANDCASTLE, - IS_WINDOWS, -) + +from torch.testing._internal.torchbind_impls import init_torchbind_implementations class TestTorchbind(TestCase): def setUp(self): super().setUp() - if IS_MACOS: - raise unittest.SkipTest("non-portable load_library call used in test") - elif IS_SANDCASTLE or IS_FBCODE: - torch.ops.load_library( - "//caffe2/test/cpp/jit:test_custom_class_registrations" - ) - elif IS_WINDOWS: - lib_file_path = find_library_location("torchbind_test.dll") - torch.ops.load_library(str(lib_file_path)) - else: - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) + init_torchbind_implementations() def get_exported_model(self): """ diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4f65504d5696..f9d736bcd413 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -80,6 +80,7 @@ IS_X86, parametrize, serialTest, + skipIfNNModuleInlined, skipIfRocm, skipIfXpu, subtest, @@ -1306,6 +1307,15 @@ def reflection_pad_left(x, n): expect = reflection_pad_left(x, 3) self.assertEqual(expect, actual) + def test_index_propagation_device_assert_masked(self): + def fn(a): + idx = torch.arange(a.size(0), device=a.device) + padded_idx = torch.constant_pad_nd(idx, (1050, 0)) + padded_idx = torch.where(padded_idx >= 0, padded_idx, padded_idx) + return a[padded_idx] + + self.common(fn, (torch.randn(1024),)) + @skipIfRocm @config.patch(debug_index_asserts=False) def test_neg_index(self): @@ -2891,6 +2901,24 @@ def fn(a, b, scale, bias): check_lowp=True, ) + @skipIfPy312 # segfaults + @config.patch(force_mixed_mm=True) + def test_mixed_mm3(self): + def fn(a, b): + return torch.mm(a, b.to(a.dtype)) + + # (256, 256) @ (256, 256) so different block sizes are tried out during autotuning + self.common( + fn, + ( + torch.randn(256, 256), + torch.randint(-128, 127, (256, 256), dtype=torch.int8), + ), + check_lowp=True, + rtol=0.01, + atol=0.1, + ) + @with_tf32_off @config.patch(use_mixed_mm=True) def test_uint4x2_mixed_mm(self): @@ -3945,6 +3973,7 @@ def forward(self, x): self.assertEqual(eager_delta, compile_delta) + @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/128198") def test_buffer_batch_norm(self): class MyModel(torch.nn.Module): def 
__init__(self): @@ -4283,6 +4312,19 @@ def fn(x): (torch.randn([1, 2, 4, 8]),), ) + def test_repeat_as_strided(self): + # Reproducer for #127474 + + def fn(x): + view_size = (3, 2) + full = x.repeat((3, 2)) + view = torch.as_strided(full, view_size, full.stride()) + result = view + view + + return result + + self.common(fn, (torch.randn(1, 1),)) + def test_repeat_interleave(self): def fn(x): return ( @@ -5460,6 +5502,14 @@ def fn(a): for dtype in all_types(): self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) + def test_full_boolean(self): + def fn(n): + x = torch.full((1,), n >= 1024, device=self.device) + return x, x + 1 + + self.common(fn, (1024,)) + self.common(fn, (1023,)) + def test_index1(self): def fn(a, b, c): return aten.index(a, [b, c]) @@ -6621,6 +6671,11 @@ def fn(x): self.common(fn, [torch.randn(64, 64)]) + def test_new_cpp_build_logical(self): + from torch._inductor.codecache import validate_new_cpp_commands + + validate_new_cpp_commands() + def test_as_strided(self): def fn(x): return ( @@ -7769,6 +7824,95 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) + def test_avg_pool3d_backward(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [2, 2, 2], + [2, 2, 2], + [0, 0, 0], + True, + False, + None, + ) + + self.common( + fn, + [ + torch.randn([2, 4, 7, 7, 7]), + torch.randn([2, 4, 14, 14, 14]), + ], + ) + + def test_avg_pool3d_backward2(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + True, + False, + None, + ) + + self.common( + fn, + [ + torch.randn([1, 1, 20, 20, 15]), + torch.randn([1, 1, 20, 20, 15]), + ], + ) + + def test_avg_pool3d_backward3(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [1, 1, 1], + [2, 2, 2], + [0, 0, 0], + False, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 2016, 11, 11, 11]), + torch.randn([1, 2016, 21, 21, 21]), + ], + ) + assertGeneratedKernelCountEqual(self, 1) + + def test_avg_pool3d_backward4(self): + def fn(a, b): + return aten.avg_pool3d_backward( + a, + b, + [13, 13, 13], + [1, 1, 1], + [0, 0, 0], + True, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 16, 12, 12, 12]), + torch.randn([1, 16, 24, 24, 24]), + ], + check_lowp=False, + ) + assertGeneratedKernelCountEqual(self, 0) + @config.patch(search_autotune_cache=False) def test_mm_views(self): def fn(a, b): @@ -9168,7 +9312,6 @@ def func(arg0_1): graph = GraphLowering( gm, shape_env=shape_env, - num_static_inputs=0, ) with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): graph.run(*example_inputs) @@ -9788,6 +9931,7 @@ def bar_meta(x): bar_cuda, bar_xpu, bar_meta, + tags=[torch._C.Tag.needs_fixed_stride_order], ) def fn(x): @@ -9850,69 +9994,13 @@ def baz_meta(x): baz_cuda, baz_xpu, baz_meta, + tags=[torch._C.Tag.needs_fixed_stride_order], ) with torch.no_grad(): net = torch.compile(model) out = net(input_t) - @requires_gpu() - @config.patch(implicit_fallbacks=True) - def test_needs_fixed_stride_order(self): - with torch.library._scoped_library("prims", "FRAGMENT") as prims_lib: - with torch.library._scoped_library("custom", "FRAGMENT") as custom_lib: - strides = [] - - def foo_impl(x): - strides.append(x.stride()) - return x.clone() - - def foo_meta(x): - return x.clone() - - all_ops = [] - for ( - needs_fixed_stride_order, - does_not_need_fixed_stride_order, - ) in itertools.product([True, False], 
[True, False]): - tags = [] - if needs_fixed_stride_order: - tags.append(torch.Tag.needs_fixed_stride_order) - if does_not_need_fixed_stride_order: - tags.append(torch.Tag.does_not_need_fixed_stride_order) - name = f"foo_{int(needs_fixed_stride_order)}{int(does_not_need_fixed_stride_order)}" - for ns, lib in {"custom": custom_lib, "prims": prims_lib}.items(): - all_ops.append(ns + "::" + name) - lib.define(f"{name}(Tensor x) -> Tensor", tags=tags) - lib.impl(name, foo_impl, "CompositeExplicitAutograd") - lib.impl(name, foo_meta, "Meta") - - assert len(all_ops) == 8 - expect_contig_strides = { - "custom::foo_01", - "prims::foo_00", - "prims::foo_01", - } - print(all_ops) - - for qualname in all_ops: - ns, name = qualname.split("::") - op = getattr(getattr(torch.ops, ns), name) - - @torch.compile(fullgraph=True) - def f(x): - y = x.t().contiguous().t() - y = y.sin() - return op(y) - - x = torch.randn(24, 24, device=self.device) - f(x) - stride = strides[-1] - if qualname in expect_contig_strides: - self.assertEqual(stride, (24, 1)) - else: - self.assertEqual(stride, (1, 24)) - def test_buffer_use_after_remove(self): # https://github.com/pytorch/pytorch/issues/102857 @@ -10319,6 +10407,23 @@ def test_generate_rand_fp8(self): t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn) self.assertTrue(t.dtype is torch.float8_e4m3fn) + def test_large_grid(self): + # https://github.com/pytorch/pytorch/issues/123210 + def fn(primals_5): + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) + primals_5 = None + permute = torch.ops.aten.permute.default(view, [0, 2, 1]) + clone = torch.ops.aten.clone.default( + permute, memory_format=torch.contiguous_format + ) + return clone + + s0 = 16777472 + s1 = 8 + compiled_fn = torch._dynamo.optimize()(fn) + actual = compiled_fn(torch.ones(s0, s1)) + self.assertTrue((actual == 1).all()) + @dataclasses.dataclass class TestFailure: @@ -10433,7 +10538,6 @@ def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: cxt = TritonCodeGenTests.NoOpCompilerBackend() torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args) graph = GraphLowering(cxt.model) - graph.num_static_inputs = 0 kernels = [] with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): graph.run(*(cxt.example_args)) @@ -10481,7 +10585,6 @@ def fn(a: torch.Tensor) -> torch.Tensor: self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1)) torch._dynamo.reset() - @expectedFailureXPU @config.patch(assume_aligned_inputs=False) def test_codegen_config_option_dont_assume_alignment(self): def fn(x: torch.Tensor) -> torch.Tensor: @@ -10948,7 +11051,9 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_ReLU", + "triton_poi_fused_layer_norm_relu" + if torch._dynamo.config.inline_inbuilt_nn_modules + else "triton_poi_fused_LayerNorm_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index b1ccc49df499..bd036810d4c1 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -121,6 +121,7 @@ def run(*ex, **kwargs): "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_expand_dynamic_shapes": TestFailure(("cpu",)), + "test_full_boolean_dynamic_shapes": TestFailure(("cpu",)), "test_glu_dynamic_shapes": TestFailure(("cpu",)), 
"test_isinf2_dynamic_shapes": TestFailure(("cpu",)), "test_linspace1_dynamic_shapes": TestFailure(("cpu",)), @@ -135,6 +136,7 @@ def run(*ex, **kwargs): "test_zeros_dynamic_shapes": TestFailure(("cpu",)), "test_uint_dynamic_shapes": TestFailure(("cpu",)), "test_issue102546_dynamic_shapes": TestFailure(("cpu",)), + "test_repeat_as_strided_dynamic_shapes": TestFailure(("cpu",)), # # Failed to find for loop/triton kernel: # @@ -145,6 +147,7 @@ def run(*ex, **kwargs): "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")), "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")), diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 8513e928c412..5608adc94e2f 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -3,10 +3,12 @@ import importlib import math +import operator import os import sys import unittest from functools import partial +from typing import List import torch import torch.library @@ -368,6 +370,47 @@ def f(x): arg = torch.tensor(5, device=device) self.assertEqual(f(arg), cf(arg)) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._inductor.config.patch(implicit_fallbacks=True) + def test_unbacked_save_for_backwards(self, device) -> None: + @torch.library.custom_op("_test::_cat", mutates_args=()) + def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor: + return t * t.new_ones([sum(ds)]) + + @torch.library.register_fake("_test::_cat") + def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor: + [torch._check_is_size(d) for d in ds] + return t.new_empty([sum(ds)]) + + def _cat_setup_context(ctx, inputs, output): + pass + + def _cat_backward(ctx, grad): + return grad.sum(), None + + torch.library.register_autograd( + "_test::_cat", + _cat_backward, + setup_context=_cat_setup_context, + ) + + def fn(t, sizes): + r = torch.ops._test._cat(t, sizes.tolist()) + return r * t + + t = torch.randn((), requires_grad=True, device=device) + sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu") + out = fn(t, sizes) + out.sum().backward() + expect = t.grad + t.grad = None + torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)( + t, sizes + ).sum().backward() + self.assertEqual(t.grad, expect) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_unbacked_reduction(self, device): expect_fail = device == "cpu" and not IS_ARM64 @@ -649,6 +692,33 @@ def fn(a): actual = cfn(5) self.assertEqual(expect, actual) + def test_interpolate_ceil_eq(self, device): + ceiling = math.ceil + IntTrueDiv = operator.truediv + + def fn(t): + s0, s2, s3 = t.size() + x = torch.zeros( + ( + s0, + 2048, + ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)), + ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)), + ), + dtype=torch.bfloat16, + ) + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ) + + cfn = self.compile_fn(fn) + arg = torch.randn(4, 16, 18) + expect = fn(arg) + actual = cfn(arg) + self.assertEqual(expect, actual) + def test_full_recompiles(self, device): def fn(x): _, L = x.shape diff --git 
a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2a7995de4e0e..1d9c733a7302 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -440,6 +440,8 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "triu", "cummax", "cummin", + "nextafter", + "gather", "_chunk_cat", "constant_pad_nd", } diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index d8c74c0a3841..549903f47ce4 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -4,9 +4,8 @@ import unittest import torch -from torch.testing._internal.common_device_type import expectedFailureXPU -from torch.testing._internal.common_utils import IS_LINUX +from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU try: @@ -38,7 +37,7 @@ def test_triton_config(self): def _test_artificial_zgrid(self): def forward(primals_1, primals_2, primals_5): - view = torch.ops.aten.reshape.default(primals_5, [-1, 4, 128]) + view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4]) primals_5 = None permute = torch.ops.aten.permute.default(view, [0, 2, 1]) clone = torch.ops.aten.clone.default( @@ -53,8 +52,8 @@ def forward(primals_1, primals_2, primals_5): primals_2 = None return addmm - s0 = 727828 - s1 = 512 + s0 = 16777472 + s1 = 8 args = [ torch.rand([2, 4], device=GPU_TYPE), @@ -73,12 +72,11 @@ def forward(primals_1, primals_2, primals_5): ] self.assertEqual(forward(*args), foo_c(*args)) - @unittest.skip("https://github.com/pytorch/pytorch/issues/123210") - @expectedFailureXPU + @skipIfXpu def test_artificial_zgrid(self): self._test_artificial_zgrid() - @expectedFailureXPU + @skipIfXpu @config.patch("cpp_wrapper", True) def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index accab8beae6b..af788de0ab0c 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -31,11 +31,10 @@ fast_dividef as my_fast_dividef, ) - -# Define shared triton constants here. -CONSTANT_C = 4 -STRING_CONSTANT_C = "CONSTANT_C" -BOOL_CONSTANT_C = True + # Define shared triton constants here. 
+ CONSTANT_C: tl.constexpr = 4 + STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C" + BOOL_CONSTANT_C: tl.constexpr = True class KernelTests(torch._inductor.test_case.TestCase): @@ -586,6 +585,7 @@ def call_triton( self.assertEqual(int_result, resulti) @requires_cuda + @skipIfRocm def test_triton_kernel_constants(self): @triton.jit def mulC_kernel( @@ -600,7 +600,7 @@ def mulC_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) - if CONSTANT_NAME.value == STRING_CONSTANT_C: + if CONSTANT_NAME == STRING_CONSTANT_C: output = CONSTANT_C * x if BOOL_CONSTANT_C: output *= CONSTANT_C @@ -1550,6 +1550,23 @@ def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): expected, ) + @requires_cuda + @skipIfRocm + def test_triton_kernel_inference_mode(self): + def f(x, y, out): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4) + + with torch.inference_mode(): + x = torch.ones(32, device="cuda") + y = torch.ones(32, device="cuda") + out_ref = torch.zeros_like(x) + out_test = torch.zeros_like(x) + f(x, y, out_ref) + torch.compile(f)(x, y, out_test) + self.assertEqual(out_ref, out_test) + @make_mutation_test def test_cumsum(): @triton.jit diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index 24ba84ebf86a..f0d3ad829d45 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -4,6 +4,7 @@ import sys import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index cd9540382cc1..1bb9f7e82572 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -3,14 +3,12 @@ import unittest from copy import deepcopy from itertools import product -from tempfile import NamedTemporaryFile import torch import torch.nn as nn from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - IS_WINDOWS, parametrize, run_tests, skipIfCrossRef, @@ -206,33 +204,6 @@ def hook_fn( model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) model.load_state_dict(model.state_dict(), strict=True) - @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") - @swap([True, False]) - def test_register_state_dict_pre_hook_backward_compat(self): - called = False - - def my_state_dict_pre_hook(*args, **kwargs): - nonlocal called - called = True - - m = nn.Linear(1, 1) - self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) - delattr(m, "_state_dict_pre_hooks") - # Save and load, ensure we can still call state_dict - # without running into issues. - with NamedTemporaryFile() as f: - # Note that torch.save / torch.load is not recommended - # to save / load modules. 
- torch.save(m, f.name) - m = torch.load(f.name) - - # Ensure we can run state_dict without issues - _ = m.state_dict() - self.assertFalse(called) - m.register_state_dict_pre_hook(my_state_dict_pre_hook) - _ = m.state_dict() - self.assertTrue(called) - # fails swapping as LSTM installs weak references on the parameters @swap([False]) @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index f76837660302..dc4bead78242 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -21,6 +21,7 @@ parametrize as parametrize_test, run_tests, skipIfTorchDynamo, + swap, TestCase, ) @@ -549,6 +550,7 @@ def _hook_to_pickle(*args, **kwargs): class TestStateDictHooks(TestCase): + @swap([True, False]) def test_load_state_dict_pre_hook(self): m = nn.Linear(10, 10) m_state_dict = m.state_dict() @@ -613,6 +615,7 @@ def test_pickled_hook(self): m._register_load_state_dict_pre_hook(_hook_to_pickle, True) pickle.loads(pickle.dumps(m)) + @swap([True, False]) def test_load_state_dict_module_pre_hook(self): hook_called = 0 @@ -686,6 +689,7 @@ def __init__(self, mod): m.load_state_dict(state_dict) self.assertEqual(2, hook_called) + @swap([True, False]) def test_load_state_dict_post_hook(self): hook_called = 0 @@ -743,6 +747,7 @@ def load_hook_clear_incompatible(module, incompatible_keys): self.assertEqual([], ret.unexpected_keys) @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + @swap([True, False]) def test_load_state_dict_post_hook_backward_compatibility(self): def my_post_load_hook(mod, _): nonlocal called @@ -771,6 +776,89 @@ def my_post_load_hook(mod, _): m.load_state_dict(sd) self.assertTrue(called) + def _test_register_state_dict_pre_hook(self, model, submodule): + _state_dict_prefix = "foo." + state_dict_pre_hook_count = 0 + keep_var_setting = False + + def my_state_dict_pre_hook(module, prefix, keep_vars): + self.assertEqual(keep_vars, keep_var_setting) + nonlocal state_dict_pre_hook_count + state_dict_pre_hook_count += 1 + self.assertTrue(prefix.startswith(_state_dict_prefix)) + + model.register_state_dict_pre_hook(my_state_dict_pre_hook) + # Test to ensure submodules run the hook as well. 
+ submodule.register_state_dict_pre_hook(my_state_dict_pre_hook) + + def check_results(model): + nonlocal state_dict_pre_hook_count, keep_var_setting + for keep_var_setting in [True, False]: + _ = model.state_dict( + prefix=_state_dict_prefix, keep_vars=keep_var_setting + ) + self.assertEqual(2, state_dict_pre_hook_count) + state_dict_pre_hook_count = 0 + + # Test state dict works as expected after model construction + check_results(model) + # Test state dict works as expected after forward + model(torch.ones(10, 3)) + check_results(model) + + def test_register_state_dict_pre_hook(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Sequential( + nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3) + ) + + def forward(self, x): + return self.a(x) + + mod = MyModule() + self._test_register_state_dict_pre_hook(mod, mod.a) + + def test_register_state_dict_pre_hook_lazy_module(self): + class MyLazyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.LazyLinear(8) + self.layer2 = nn.LazyLinear(5) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + mod = MyLazyModule() + self._test_register_state_dict_pre_hook(mod, mod.layer1) + + @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + def test_register_state_dict_pre_hook_backward_compat(self): + called = False + + def my_state_dict_pre_hook(*args, **kwargs): + nonlocal called + called = True + + m = nn.Linear(1, 1) + self.assertTrue(hasattr(m, "_state_dict_pre_hooks")) + delattr(m, "_state_dict_pre_hooks") + # Save and load, ensure we can still call state_dict + # without running into issues. + with NamedTemporaryFile() as f: + # Note that torch.save / torch.load is not recommended + # to save / load modules. 
+ torch.save(m, f.name) + m = torch.load(f.name) + + # Ensure we can run state_dict without issues + _ = m.state_dict() + self.assertFalse(called) + m.register_state_dict_pre_hook(my_state_dict_pre_hook) + _ = m.state_dict() + self.assertTrue(called) + class TestModuleGlobalHooks(TestCase): def tearDown(self): @@ -1553,6 +1641,7 @@ def parameter_registration_hook(module, name, parameter): instantiate_parametrized_tests(TestModuleHooks) +instantiate_parametrized_tests(TestStateDictHooks) if __name__ == "__main__": run_tests() diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 1d5127b18603..0c7a141d6a7a 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -24,7 +24,6 @@ from torch.testing._internal import common_utils sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - import onnx_test_common @@ -472,7 +471,7 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -565,7 +564,7 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -650,7 +649,7 @@ def generate_example_inputs(batch: int, seq: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 3 if test_backward else 2 + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -782,8 +781,9 @@ def record_onnx_model_transform(onnx_model): result = compiled_model() self.assertEqual(len(recorded_models), 1) + # NOTE: Constant folded by optimizer self.assertTrue( - "aten_add" in [node.op_type for node in recorded_models[0].graph.node] + "Constant" in [node.op_type for node in recorded_models[0].graph.node] ) self.assertEqual(result, torch.ones(4, 8)) @@ -822,11 +822,11 @@ def example_model(x: torch.Tensor): # Part 2: Change the ONNX model seen by the transform so that # ORT receives a different model. 
+ # NOTE: the function is optimized away by optimizer def replace_relu_with_sigmoid(onnx_model): - for function in onnx_model.functions: - for node in function.node: - if node.op_type == "Relu": - node.op_type = "Sigmoid" + for node in onnx_model.graph.node: + if node.op_type == "Relu": + node.op_type = "Sigmoid" def another_example_model(x: torch.Tensor): y = torch.relu(x) diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index cc60b975a5eb..30bfd27483b9 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -26,6 +26,13 @@ def forward(self, x): return (y, z) +class SampleModelTwoInputs(torch.nn.Module): + def forward(self, x, b): + y = x + b + z = y.relu() + return (y, z) + + class _LargeModel(torch.nn.Module): def __init__(self): super().__init__() @@ -221,5 +228,101 @@ def test_serialize_succeeds_when_model_greater_than_2gb(self): serializer.serialize(onnx_program, io.BytesIO()) +class TestONNXExportWithDynamo(common_utils.TestCase): + def test_args_normalization_with_no_kwargs(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_kwargs(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_args_normalization_with_empty_dict_at_the_tail(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}), + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_enable_dynamic_shape(self): + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + SampleModelTwoInputs(), + torch.randn(1, 1, 2), + b=torch.randn(1, 1, 2), + export_options=ExportOptions(dynamic_shapes=True), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelTwoInputs(), + (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}), + dynamic_axes={"b": [0, 1, 2]}, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_raises_unrelated_parameters_warning(self): + message = ( + "f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, " + "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and " + "autograd_inlining are not supported for dynamo export at the moment." 
+ ) + + with self.assertWarnsOnceRegex(UserWarning, message): + _ = torch.onnx.export( + SampleModel(), + (torch.randn(1, 1, 2),), + dynamo=True, + ) + + def test_raises_unsupported_specific_dynamic_axes_warning(self): + message = ( + "Specified dynamic axes is not supported for dynamo export at the moment." + ) + + with self.assertWarnsOnceRegex(UserWarning, message): + _ = torch.onnx.export( + SampleModel(), + (torch.randn(1, 1, 2),), + dynamic_axes={"input": [0, 1, 2]}, + dynamo=True, + ) + + def test_saved_f_exists_after_export(self): + with common_utils.TemporaryFileName(suffix=".onnx") as path: + _ = torch.onnx.export( + SampleModel(), torch.randn(1, 1, 2), path, dynamo=True + ) + self.assertTrue(os.path.exists(path)) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index 223ff04606db..3c5526de53f9 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -6,7 +6,6 @@ import io import logging import typing -import unittest from typing import AbstractSet, Protocol, Tuple import torch @@ -17,6 +16,9 @@ from torch.onnx._internal.fx import diagnostics as fx_diagnostics from torch.testing._internal import common_utils, logging_utils +if typing.TYPE_CHECKING: + import unittest + class _SarifLogBuilder(Protocol): def sarif_log(self) -> sarif.SarifLog: diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index b9b5e9859bab..6fdbf4e92839 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -205,10 +205,10 @@ def xfail_dynamic_fx_test( Args: reason: The reason for xfailing dynamic exporting test. model_type (TorchModelType): The model type to xfail dynamic exporting test for. - When None, model type is not used to skip dynamic tests. + When None, model type is not used to xfail dynamic tests. Returns: - A decorator for skipping dynamic exporting test. + A decorator for xfailing dynamic exporting test. """ def skip_dec(func): @@ -225,6 +225,36 @@ def wrapper(self, *args, **kwargs): return skip_dec +def xfail_op_level_debug_test( + error_message: str, + model_type: Optional[TorchModelType] = None, + reason: Optional[str] = None, +): + """Xfail op level debug test. + + Args: + reason: The reason for xfailing op level debug test. + model_type (TorchModelType): The model type to xfail dynamic exporting test for. + When None, model type is not used to xfail op level debug tests. + + Returns: + A decorator for xfailing op level debug test. + """ + + def skip_dec(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if self.op_level_debug and ( + not model_type or self.model_type == model_type + ): + return xfail(error_message, reason)(func)(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return wrapper + + return skip_dec + + def skip_dynamic_fx_test(reason: str, model_type: TorchModelType = None): """Skip dynamic exporting test. 
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 4c71aafa473e..6d675d446030 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -65,7 +65,7 @@ common_methods_invocations, common_utils, ) -from torch.testing._internal.opinfo import core as opinfo_core +from torch.testing._internal.opinfo import core as opinfo_core # noqa: TCH001 # NOTE: For ATen signature modifications that will break ONNX export, @@ -147,6 +147,7 @@ def skip_torchlib_forward_compatibility( ), xfail( "__rmatmul__", + dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch", ), xfail( @@ -218,9 +219,8 @@ def skip_torchlib_forward_compatibility( reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") ), xfail( - "all", - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") + "alias_copy", + reason="OnnxExporterError: Failed to export model", ), xfail( "allclose", @@ -240,11 +240,6 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), ), - xfail( - "any", - reason=onnx_test_common.reason_onnx_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") - ), xfail( "arange", dtypes=(torch.uint8,), @@ -319,6 +314,11 @@ def skip_torchlib_forward_compatibility( "bincount", reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), ), + xfail( + "block_diag", + dtypes=onnx_test_common.COMPLEX_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), + ), xfail( "bmm", dtypes=( @@ -346,10 +346,6 @@ def skip_torchlib_forward_compatibility( "chalf", reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." ), - xfail( - "chunk", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool") - ), xfail( "chunk", dtypes=(torch.uint8, torch.int8, torch.int16,), @@ -421,8 +417,14 @@ def skip_torchlib_forward_compatibility( reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), ), xfail( - "cross", - reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"), + "diag", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), + ), + xfail( + "diagonal_copy", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), ), xfail( "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), @@ -523,6 +525,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") ), + xfail( + "gather", + reason="HandleNegativeAxis(int64_t, int64_t) IsAxisInRange(axis, tensor_rank) was \ + false. 
axis 0 is not in valid range [-0,-1]" + ), xfail( "geometric", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), @@ -532,14 +539,24 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), ), + xfail( + "index_add", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), + ), xfail( "index_fill", dtypes=onnx_test_common.COMPLEX_TYPES, reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") ), + xfail( + "index_fill", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, + reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." + ), xfail( "index_put", - dtypes=onnx_test_common.BOOL_TYPES, + dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), ), xfail( @@ -547,6 +564,11 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.uint8, torch.int8, torch.int16,), reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), ), + xfail( + "index_put", + dtypes=(torch.float16,), + reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), + ), xfail( "isnan", dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, @@ -574,6 +596,10 @@ def skip_torchlib_forward_compatibility( variant_name="grad_oriented", reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), ), + xfail( + "linalg.matrix_power", + reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." + ), xfail( "linalg.norm", reason="fixme: Assertion error: result mismatch", @@ -624,11 +650,6 @@ def skip_torchlib_forward_compatibility( dtypes=(torch.float16,), reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", ), - xfail( - "logcumsumexp", - reason=onnx_test_common.reason_onnx_does_not_support( - "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0") - ), xfail( "logical_and", dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, @@ -649,12 +670,7 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), ), - xfail( - "logsumexp", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"), - ), - xfail( + skip( "masked.logsumexp", reason="fixme: https://github.com/onnx/onnx/issues/4986", ), @@ -724,12 +740,9 @@ def skip_torchlib_forward_compatibility( xfail( "max", variant_name="reduction_with_dim", + dtypes=(torch.int64,), reason="https://github.com/onnx/onnx/issues/4986", ), - xfail( - "mean", - reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. 
input rank was 0", - ), xfail( "min", variant_name="reduction_no_dim", @@ -864,6 +877,11 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.COMPLEX_TYPES, reason="fixme: Assertion error: result mismatch", ), + xfail( + "nn.functional.cosine_embedding_loss", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), + ), xfail( "nn.functional.ctc_loss", reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), @@ -954,6 +972,20 @@ def skip_torchlib_forward_compatibility( variant_name="reflect", reason="fixme: Assertion error: result mismatch", ), + xfail( + "nn.functional.pixel_shuffle", + dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, + reason="fixme: ONNX Runtime does not support int32/64 inputs", + ), + xfail( + "nn.functional.pixel_unshuffle", + reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), + ), + xfail( + "nn.functional.poisson_nll_loss", + dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, + reason="fixme: result mismatch with NaN.", + ), xfail( "nn.functional.rrelu", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), @@ -1093,6 +1125,11 @@ def skip_torchlib_forward_compatibility( variant_name="mean", reason="ONNX doesn't support reduce='mean' option", ), + xfail( + "sgn", + dtypes=onnx_test_common.BOOL_TYPES, + reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), + ), xfail( "sign", dtypes=onnx_test_common.BOOL_TYPES, @@ -1127,35 +1164,20 @@ def skip_torchlib_forward_compatibility( reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), ), xfail( - "special.ndtr", - dtypes=(torch.float16,), + "special.log_ndtr", + dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, reason="fixme: Assertion error: result mismatch", ), xfail( - "split", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "split", - variant_name="list_args", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "split_with_sizes", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), + "special.ndtr", + dtypes=(torch.float16,), + reason="fixme: Assertion error: result mismatch", ), xfail( "square", dtypes=(torch.int8, torch.uint8, torch.int16), reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), ), - xfail( - "squeeze", - reason="fixme: Assertion error: result mismatch", - ), xfail( "squeeze", variant_name="multiple", @@ -1165,15 +1187,6 @@ def skip_torchlib_forward_compatibility( "svd_lowrank", reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), ), - xfail( - "std_mean", - reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." - ), - xfail( - "std_mean", - variant_name="unbiased", - reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple." 
- ), xfail( "stft", reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), @@ -1213,11 +1226,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), ), - xfail( - "unbind", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), xfail( "unflatten", dtypes=onnx_test_common.BOOL_TYPES, @@ -1240,16 +1248,6 @@ def skip_torchlib_forward_compatibility( dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), ), - xfail( - "unsafe_split", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), - xfail( - "unsafe_chunk", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"), - ), xfail( "where", dtypes=onnx_test_common.BOOL_TYPES, @@ -1415,8 +1413,10 @@ def skip_torchlib_forward_compatibility( ), xfail( "index_add", - matcher=lambda sample: len(sample.input.shape) < 2, - reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), ), xfail( "index_add", @@ -1425,8 +1425,10 @@ def skip_torchlib_forward_compatibility( ), xfail( "index_copy", - matcher=lambda sample: len(sample.input.shape) < 2, - reason="fixme: https://github.com/microsoft/onnxscript/issues/1212", + matcher=lambda sample: len(sample.input.shape) == 0, + reason=onnx_test_common.reason_onnx_runtime_does_not_support( + "ScatterND", "0-D tensor" + ), ), xfail( "index_copy", @@ -1457,12 +1459,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: LogSoftMax does not support empty tensor as input", ), - xfail( - "logsumexp", - matcher=lambda sample: isinstance(sample.input, torch.Tensor) - and len(sample.input.shape) == 0, - reason="fixme: IsScalar", - ), skip( "masked.log_softmax", matcher=lambda sample: len(sample.input.shape) == 0, @@ -1473,12 +1469,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), - xfail( - "min", - variant_name="reduction_with_dim", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: https://github.com/onnx/onnx/issues/4986", - ), skip( "mm", matcher=lambda sample: torch.numel(sample.input) == 0, @@ -1570,8 +1560,7 @@ def skip_torchlib_forward_compatibility( xfail( "nn.functional.instance_norm", model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - matcher=lambda sample: sample.kwargs.get("running_mean") is not None - or sample.input.dtype in (torch.float16,), + matcher=lambda sample: sample.kwargs.get("running_mean") is not None, reason="fixme: KeyError: 'self___kwargs__running_mean'", ), xfail( @@ -1580,6 +1569,11 @@ def skip_torchlib_forward_compatibility( and sample.kwargs.get("padding") == 1, reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", ), + xfail( + "nn.functional.pixel_shuffle", + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: ORT does not support empty tensor as input", + ), 
xfail( "nonzero", matcher=lambda sample: len(sample.input.shape) == 0 @@ -1625,12 +1619,6 @@ def skip_torchlib_forward_compatibility( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: LogSoftMax does not support empty tensor as input", ), - xfail( - "t", - matcher=lambda sample: isinstance(sample.input, torch.Tensor) - and len(sample.input.shape) < 2, - reason="fixme: IsScalar", - ), xfail( "unflatten", reason="Logic not implemented for size 0 inputs in op.Reshape", @@ -1992,8 +1980,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "addr": [3e-3, 4e-3], "baddbmm": [3e-2, 1e-3], "cumulative_trapezoid": [3e-2, 1e-3], + "cross": [3e-2, 2e-2], "diff": [1e-2, 5e-2], "gradient": [3e-3, 4e-3], + "linalg.cross": [1e-3, 2e-2], "linalg.multi_dot": [3e-2, 1e-3], "linalg.vecdot": [1e-2, 2e-2], "linspace": [2e-2, 2e-3], @@ -2008,6 +1998,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.hardsigmoid": [1e-3, 5e-3], "nn.functional.hardswish": [1e-3, 5e-3], "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], + "nn.functional.huber_loss": [1e-3, 1e-2], "nn.functional.instance_norm": [1e-2, 1e-3], "nn.functional.interpolate": [1e-2, 1e-3], "nn.functional.kl_div": [2e-3, 2e-4], @@ -2015,7 +2006,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.local_response_norm": [1e-2, 5e-3], "nn.functional.poisson_nll_loss": [3e-2, 1e-3], "nn.functional.nll_loss": [3e-2, 1e-3], + "nn.functional.triplet_margin_loss": [2e-2, 1e-2], + "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], "native_batch_norm": [3e-2, 1e-3], + "norm": [1e-2, 1e-2], "dot": [3e-2, 1e-3], "logit": [3e-2, 1e-3], "rsub": [3e-2, 1e-3], @@ -2023,6 +2017,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "sub": [3e-2, 1e-3], "trapezoid": [1e-3, 7e-3], "trapz": [1e-3, 7e-3], + "vdot": [1e-3, 1e-2], } fp16_low_precision_variant_dict = { diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py index 9ebbf11646dc..8389f3912075 100644 --- a/test/onnx/test_fx_passes.py +++ b/test/onnx/test_fx_passes.py @@ -1,4 +1,6 @@ # Owner(s): ["module: onnx"] +import pytorch_test_common + import torch import torch._dynamo import torch.fx @@ -96,6 +98,10 @@ def func(x, y, z): @common_utils.instantiate_parametrized_tests class TestModularizePass(common_utils.TestCase): + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ @@ -146,6 +152,10 @@ def forward(self, x, y): ) self.assertFalse(any("ReLU" in name for name in function_proto_names)) + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_ReLU_relu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ @@ -187,6 +197,10 @@ def forward(self, x, y): self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names) self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names) + @pytorch_test_common.xfail( + error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found", + reason="optimizer", + ) @common_utils.parametrize( "is_exported_program", [ diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index b660b0525dba..61cb9e807f70 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -17,7 +17,7 @@ from torch._subclasses import fake_tensor from torch.nn import 
functional as F from torch.onnx import dynamo_export, ExportOptions -from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.diagnostics import infra # noqa: TCH001 from torch.onnx._internal.fx import diagnostics, registration from torch.testing._internal import common_utils @@ -171,9 +171,13 @@ def forward(self, input): torch.argmax(input, dim=1, keepdim=True), ) - _ = dynamo_export( - ArgminArgmaxModel(), model_input, export_options=self.export_options - ) + # NOTE: KeyError: dim raised in optimizer + with self.assertWarnsOnceRegex( + UserWarning, "ONNXScript optimizer failed. Skipping optimization." + ): + _ = dynamo_export( + ArgminArgmaxModel(), model_input, export_options=self.export_options + ) def test_multiple_outputs_op_with_evaluator(self): class TopKModel(torch.nn.Module): @@ -182,7 +186,8 @@ def forward(self, x): return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) - onnx_program = dynamo_export(TopKModel(), x, export_options=self.export_options) + + _ = dynamo_export(TopKModel(), x, export_options=self.export_options) def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self): class EmbedModelWithoutPaddingIdx(torch.nn.Module): @@ -364,11 +369,13 @@ def _assert_node_outputs_has_value_info( node: onnx.NodeProto, value_infos: Mapping[str, onnx.ValueInfoProto], local_functions: Mapping[Tuple[str, str], onnx.FunctionProto], + exclude_names_in_value_info, function_id: str = "", ): for output in node.output: name = f"{function_id}/{output}" if function_id else output - self.assertIn(name, value_infos) + if name not in exclude_names_in_value_info: + self.assertIn(name, value_infos) if node.domain.startswith("pkg.onnxscript.torch_lib"): # No shape info available for values inside torchlib functions. 
return @@ -378,13 +385,25 @@ def _assert_node_outputs_has_value_info( for node in function.node: function_id = f"{function.domain}::{function.name}" _assert_node_outputs_has_value_info( - node, value_infos, local_functions, function_id + node, + value_infos, + local_functions, + exclude_names_in_value_info, + function_id, ) type_infos = {vi.name: vi for vi in model_proto.graph.value_info} functions = {(f.domain, f.name): f for f in model_proto.functions} + # NOTE: inputs, outputs, and initializers are not included in value_info spec + exclude_names_in_value_info = ( + [input.name for input in model_proto.graph.input] + + [output.name for output in model_proto.graph.output] + + [init.name for init in model_proto.graph.initializer] + ) for node in model_proto.graph.node: - _assert_node_outputs_has_value_info(node, type_infos, functions) + _assert_node_outputs_has_value_info( + node, type_infos, functions, exclude_names_in_value_info + ) def test_dynamo_export_retains_readable_parameter_and_buffer_names(self): class SubModule(torch.nn.Module): def forward(self, tensor_x: torch.Tensor): @@ -424,10 +443,11 @@ def forward(self, tensor_x: torch.Tensor): model = MNISTModel() onnx_program = torch.onnx.dynamo_export(model, tensor_x) model_proto = onnx_program.model_proto - self.assertEqual( - {initializer.name for initializer in model_proto.graph.initializer}, - {*model.state_dict().keys()}, - ) + + # NOTE: initializers could be optimized away by onnx optimizer + onnx_initializers = {init.name for init in model_proto.graph.initializer} + torch_weights = {*model.state_dict().keys()} + self.assertTrue(onnx_initializers.issubset(torch_weights)) @common_utils.parametrize( "checkpoint_type", @@ -708,7 +728,11 @@ def forward(self, input: torch.Tensor): input = input.to(float8_type) return input + torch.tensor(1.0, dtype=float8_type) - _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) + # NOTE: shape inference error raised in optimizer due to unsupported dtype + with self.assertWarnsOnceRegex( + UserWarning, "ONNXScript optimizer failed. Skipping optimization." + ): + _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4)) def test_export_with_logging_logger(self): logger = logging.getLogger(__name__) diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 149b9dc987bb..0f0e01bc0dc2 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,8 +158,12 @@ def forward(self, x, y): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - torch.tensor([operator.floordiv(x.item(), y.item())]), - torch.tensor([operator.pow(x.item(), y.item())]), + # This requires torch.sym_float, probably easy to lower to + # ONNX but I don't know where to put it + # torch.tensor([operator.floordiv(x.item(), y.item())]), + # NB: abs so that the base and exponent are provably + # non-negative, so we don't generate runtime asserts + torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), @@ -577,9 +581,6 @@ def forward(self, x): x = torch.randn(1, 1, 1, 32, device=torch.device("cuda")) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(func, (x,)) - # NOTE:The test was meant to test the empty bounding box case, but it is not
When we have vision model examples, we will have a better test case - # to demonstrate in FX and FX exporter. def test_view_dynamic_zero_dim(self): class ViewModel(torch.nn.Module): def forward(self, input): @@ -587,12 +588,11 @@ def forward(self, input): return input.view(1, -1) x = torch.ones(2) - # y = torch.empty(0) + y = torch.empty(0) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( ViewModel(), (x,), - # additional_test_inputs=[((y,),)], # TODO: Without `additional_test_inputs` arg, dynamic shape cannot be verified - skip_dynamic_shapes_check=True, # Has static shape for dynamic_shapes=True due to 0/1 specialization + additional_test_inputs=[((y,),)], ) def test_flatten_dynamic_axes(self): @@ -666,6 +666,11 @@ def forward(self, x): @pytorch_test_common.xfail_if_model_type_is_exportedprogram( error_message="Trying to flatten user inputs with exported input tree spec" ) + @pytorch_test_common.xfail_dynamic_fx_test( + error_message="!(it.GetName().empty())", + reason="With after onnx==1.16, constant folding in optimizer causes this error.", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + ) def test_gpt2_tiny_from_config(self): # Model config = transformers.GPT2Config( @@ -1145,6 +1150,11 @@ def create_kwargs(): reason="Dynamic shape check is not expected for exported program in this test suite.", model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, ) + @pytorch_test_common.xfail_dynamic_fx_test( + error_message="!(it.GetName().empty())", + reason="With after onnx==1.16, constant folding in optimizer causes this error.", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, + ) @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( error_message="Expected 4 inputs, got 2", reason="https://github.com/pytorch/pytorch/issues/115745", @@ -1259,16 +1269,17 @@ def create_model(): model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, ) @pytorch_test_common.xfail_dynamic_fx_test( - error_message="NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node", - reason="Need to check Trilu node in the ONNX graph", + error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", + reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, ) - @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( - error_message="NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node", - reason="Need to check Trilu node in the ONNX graph", + @pytorch_test_common.xfail_op_level_debug_test( + error_message="Could not find an implementation for Trilu(14) node", + reason="ORT error during op level dubug", + model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, ) @pytorch_test_common.xfail_if_model_type_is_exportedprogram( - error_message="aot_autograd expected to have an entirely functional graph", + error_message="n=copy_, n.args[0]=zeros_like, placeholders={", reason="aot_autograd doesn't support it.", ) def test_fake_tensor_mode_huggingface_openai_whisper(self): @@ -1477,41 +1488,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) - @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( - 
error_message="Expected 4 inputs, got 2", - reason="https://github.com/pytorch/pytorch/issues/115745", - ) - def test_fake_tensor_mode_huggingface_tiny_gpt2_torch_load(self): - model_name = "sshleifer/tiny-gpt2" - device = "cpu" - - def create_model(): - return transformers.AutoModel.from_pretrained(model_name).to(device).eval() - - def create_args(): - tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) - kwargs = tokenizer("Hello world!", return_tensors="pt") - input_ids = kwargs["input_ids"] - attention_mask = kwargs["attention_mask"] - return input_ids, None, attention_mask - - def create_pytorch_only_extra_kwargs(): - return {"return_dict": False} - - self._test_fake_tensor_mode_exporter( - "huggingface_sshleifer_tiny-gpt2", - create_model, - create_args, - create_pytorch_only_extra_kwargs, - load_checkpoint_during_init=self.load_checkpoint_during_init, - export_within_fake_mode=self.export_within_fake_mode, - model_type=self.model_type, - ) - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 816bcfc3b8df..e49d5d3bceeb 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2585,6 +2585,9 @@ def forward(self, x, update): update = torch.randn(4, 1, 3, 2) self.run_test(IndexPutModel2(), (x, update)) + @unittest.skip( + "regression in 1.18: https://github.com/microsoft/onnxruntime/issues/20855" + ) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_loop(self): @torch.jit.script diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index bca8b3b6b69c..11c6a24bd161 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -1,4 +1,5 @@ # Owner(s): ["module: optimizer", "module: LrScheduler" ] +import copy import math import pickle import tempfile @@ -2403,6 +2404,119 @@ def test_lr_scheduler_state_dict_load(self, LRClass, weights_only): scheduler2.load_state_dict(state_dict_loaded) self.assertEqual(scheduler2.state_dict(), state_dict) + @parametrize( + "LRClass", + [ + partial(LambdaLR, lr_lambda=lambda e: e // 10), + partial(MultiplicativeLR, lr_lambda=lambda e: 0.95), + partial(StepLR, step_size=30), + partial(MultiStepLR, milestones=[30, 80]), + ConstantLR, + LinearLR, + partial(ExponentialLR, gamma=0.9), + PolynomialLR, + partial(CosineAnnealingLR, T_max=10), + partial(CosineAnnealingWarmRestarts, T_0=20), + ], + ) + def test_constant_initial_lr(self, LRClass): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = LRClass(opt) + + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + opt.step() + sch.step(i) + lr.multiply_(0.1) + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(sch.base_lrs, [0.1]) + + def test_constant_initial_params_cyclelr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + max_lr = torch.as_tensor(0.2) + base_momentum = torch.as_tensor(0.8) + max_momentum = torch.as_tensor(0.9) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = CyclicLR( + opt, + base_lr=lr, + max_lr=max_lr, + base_momentum=base_momentum, + max_momentum=max_momentum, + ) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + max_lr.multiply_(0.5) + 
base_momentum.multiply_(0.5) + max_momentum.multiply_(0.5) + opt.step() + sch.step(i) + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) + self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) + self.assertEqual(sch.base_lrs, [0.1]) + self.assertEqual(sch.max_lrs, [0.2]) + self.assertEqual(group["max_momentum"], 0.9) + self.assertEqual(group["base_momentum"], 0.8) + + def test_constant_initial_params_onecyclelr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + base_momentum = torch.as_tensor(0.85) + max_momentum = torch.as_tensor(0.95) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = OneCycleLR( + opt, + max_lr=lr, + total_steps=10, + base_momentum=base_momentum, + max_momentum=max_momentum, + ) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + base_momentum.multiply_(0.5) + max_momentum.multiply_(0.5) + opt.step() + sch.step(i) + + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["max_lr"], ori_group["max_lr"]) + self.assertEqual(group["min_lr"], ori_group["min_lr"]) + self.assertEqual(group["max_momentum"], ori_group["max_momentum"]) + self.assertEqual(group["base_momentum"], ori_group["base_momentum"]) + self.assertEqual(group["max_momentum"], 0.95) + self.assertEqual(group["base_momentum"], 0.85) + + def test_constant_initial_params_swalr(self): + # Test that the initial learning rate is constant + lr = torch.as_tensor(0.1) + swa_lr = torch.as_tensor(0.05) + opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr) + sch = SWALR(opt, swa_lr=swa_lr) + ori_param_groups = copy.deepcopy(opt.param_groups) + + for i in range(2): + lr.multiply_(0.5) + swa_lr.multiply_(0.5) + opt.step() + sch.step() + for group, ori_group in zip(opt.param_groups, ori_param_groups): + self.assertEqual(group["initial_lr"], ori_group["initial_lr"]) + self.assertEqual(group["swa_lr"], ori_group["swa_lr"]) + self.assertEqual(group["swa_lr"], 0.05) + self.assertEqual(sch.base_lrs, [0.1]) + instantiate_parametrized_tests(TestLRScheduler) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 38e83d448fdd..81d158635c0e 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1723,9 +1723,12 @@ def _validate_basic_json(self, traceEvents, cuda_available=False): gpu_value = traceEvent.get("args", {}).get("labels", None) if gpu_value and "GPU" in gpu_value: gpu_dict[gpu_value] += 1 + # Max PID offset is 5M, based from pytorch/kineto include header: + # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35 + kExceedMaxPid = 5000000 self.assertTrue( traceEvents[i + 1]["args"]["sort_index"] - == 0x1000000 + int(gpu_value.split()[1]) + == kExceedMaxPid + int(gpu_value.split()[1]) ) # TODO add checking gpu count if cpuOnly_ is true or not @@ -2408,6 +2411,7 @@ def test_profiler_matmul_dim_fp16_pattern(self): num_matched.append(len(pattern.matched_events())) self.assertEqual(num_matched, [i for i, _ in cases]) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_profiler_pattern_matcher_json_report(self): x = torch.ones((100, 100)) model = nn.Sequential( diff --git 
a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py index 33a16f21bd0f..a0a2f8f8aa03 100644 --- a/test/quantization/core/experimental/test_adaround_eager.py +++ b/test/quantization/core/experimental/test_adaround_eager.py @@ -29,14 +29,20 @@ def feedforawrd_callback( ) -> None: model(data) - def run_adaround(self, model, img_data): + def feedforawrd_callback_with_wrapper(self, model, data, wrapper) -> None: + wrapper(model, data) + + def run_adaround(self, model, img_data, wrapper=None): adaround_optimizer = AdaptiveRoundingOptimizer( model, - self.feedforawrd_callback, + self.feedforawrd_callback + if wrapper is None + else self.feedforawrd_callback_with_wrapper, forward_wrapper, img_data, max_iter=100, batch_size=10, + feed_forward_wrapper=wrapper, ) adarounded_model = adaround_optimizer.run_adaround() return adarounded_model @@ -63,6 +69,17 @@ def get_fake_quant(self, model): module.weight.data.copy_(fake_quant_module) return hard_fake_quant_model + def get_feed_forward_wrapper(self): + class FeedForwardWrapper(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, model, sample): + return model(sample) + + wrapper_module = FeedForwardWrapper() + return wrapper_module + def test_linear_chain(self): class LinearChain(nn.Module): def __init__(self): @@ -79,7 +96,9 @@ def forward(self, x): float_model = LinearChain() img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)] - adarounded_model = self.run_adaround(float_model, img_data) + adarounded_model = self.run_adaround( + float_model, img_data, self.get_feed_forward_wrapper() + ) fq_model = self.get_fake_quant(float_model) rand_input = torch.rand(10, 3) with torch.no_grad(): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d59f1fffd926..5b86693e11c1 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -21,6 +21,7 @@ import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() +from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN @@ -31,10 +32,12 @@ qengine_is_onednn, ) from torch.ao.quantization import PerChannelMinMaxObserver -from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA +from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA from torch.testing._internal.optests import opcheck import torch.backends.xnnpack +from torch.utils.cpp_extension import ROCM_HOME + from typing import Optional np_dtype = { @@ -43,6 +46,8 @@ torch.qint32 : np.int32 } +TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None + class PointwisePostOp(NamedTuple): binary_attr : str = "none" alpha : float = 1.0 @@ -905,9 +910,8 @@ def test_qadd_relu_same_qparams(self): """Tests the correctness of the cudnn add and add_relu op (Similar to test_qadd_relu_different_qparams, will probably merge in the future)""" @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the test_qadd_relu_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not 
SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qadd_relu_cudnn(self): dtype = torch.qint8 add_relu = torch.ops.quantized.add_relu @@ -940,9 +944,8 @@ def test_qadd_relu_cudnn(self): """Tests the correctness of the cudnn add and add_relu op for nhwc format""" @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the test_qadd_relu_cudnn_nhwc op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qadd_relu_cudnn_nhwc(self): dtype = torch.qint8 add_relu = torch.ops.quantized.add_relu @@ -1379,7 +1382,7 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode): self.assertEqual(a_ref, a_hat.dequantize(), msg="ops.quantized.max_pool1d results are off") - # TODO: merge this test with test_max_pool2d when USE_EXPERIMENTAL_CUDNN_V8_API flag is enabled in CI + # TODO: merge this test with test_max_pool2d """Tests 2D cudnn max pool operation on quantized tensors.""" @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, min_side=1, max_side=10), @@ -1394,9 +1397,8 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode): padding=st.integers(0, 2), ceil_mode=st.booleans()) @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(TEST_CUDNN_VERSION <= 90100, "cuDNN maxpool2d mishandles -128 before v90100") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_max_pool2d_cudnn(self, X, kernel, stride, dilation, padding, ceil_mode): X, (scale, zero_point, torch_type) = X assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 
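A note on the gating pattern used throughout this file's hunks: the blanket @unittest.skip markers tied to the old USE_EXPERIMENTAL_CUDNN_V8_API build flag are replaced by conditional skips, so the cuDNN quantized-op tests now run whenever the hardware and build support them. Below is a minimal standalone sketch of that pattern; the class and test names are placeholders, and only the helpers the patch itself imports (SM80OrLater, TEST_CUDA, TEST_CUDNN, ROCM_HOME) are assumed.

import unittest

import torch
from torch.testing._internal.common_cuda import SM80OrLater, TEST_CUDA, TEST_CUDNN
from torch.utils.cpp_extension import ROCM_HOME

# Same ROCm detection the patch adds: a ROCm build looks like CUDA but
# reports a HIP version and a ROCM_HOME path.
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None


class ExampleCudnnQuantizedTest(unittest.TestCase):  # placeholder name, not part of the patch
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    def test_example(self):
        # Placeholder body; the real tests exercise quantized cuDNN conv/linear/add kernels.
        pass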
@@ -4050,9 +4052,9 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qlinear_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") # TODO: check with yang regarding CUDNN flags def test_qlinear_cudnn(self, batch_size, input_channels, output_channels, use_bias, use_relu, use_multi_dim_input, use_channelwise): @@ -5427,9 +5429,8 @@ def test_qconv2d_add_relu(self): use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv2d_cudnn( self, batch_size, @@ -5510,9 +5511,8 @@ def test_qconv2d_cudnn( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv2d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv2d_relu_cudnn( self, batch_size, @@ -6245,9 +6245,8 @@ def test_qconv1d_relu( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv1d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv1d_cudnn( self, batch_size, @@ -6319,9 +6318,8 @@ def test_qconv1d_cudnn( use_channelwise=st.sampled_from([False])) @skipIfNoFBGEMM @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") - @unittest.skip("Local only - currently the qconv1d_cudnn op is bulid " - "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test " - "after it is built by default") + @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") + @unittest.skipIf(TEST_ROCM, "not supported on rocm.") def test_qconv1d_relu_cudnn( self, batch_size, diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index 52f169b1d5b6..31ffa3104b65 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -2,62 +2,63 @@ import copy import math + import torch -import torch.nn as nn -import torch.backends.mkldnn -from torch.nn import Conv2d, BatchNorm2d, ReLU, init -from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d -from torch.nn.modules.utils import _pair -import torch.ao.nn.quantized as nnq -import torch.ao.nn.quantized.dynamic as nnqd -import torch.ao.nn.qat as nnqat import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat 
import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.backends.mkldnn +import torch.nn as nn +import torch.testing._internal.hypothesis_utils as hu + +from hypothesis import given, strategies as st +from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.ao.quantization import ( - prepare, convert, - prepare_qat, - quantize_qat, - QuantStub, - DeQuantStub, - default_qconfig, - default_qat_qconfig, default_embedding_qat_qconfig, + default_qat_qconfig, + default_qconfig, default_symmetric_qnnpack_qat_qconfig, - get_default_qat_qconfig, + DeQuantStub, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, + get_default_qat_qconfig, get_embedding_qat_module_mappings, get_embedding_static_quant_module_mappings, NoopObserver, + prepare, + prepare_qat, + quantize_qat, + QuantStub, ) from torch.ao.quantization.qconfig import qconfig_equals +from torch.nn import BatchNorm2d, Conv2d, init, ReLU +from torch.nn.modules.utils import _pair from torch.testing._internal.common_quantization import ( DeFusedEmbeddingBagLinear, - QuantizationTestCase, - QuantStubModel, - ManualLinearQATModel, - ManualDropoutQATModel, - ManualLinearDynamicQATModel, ManualConvLinearQATModel, ManualConvLinearSymmQATModel, + ManualDropoutQATModel, ManualEmbeddingBagLinear, - TwoLayerLinearModel, + ManualLinearDynamicQATModel, + ManualLinearQATModel, + QuantizationTestCase, + QuantStubModel, test_only_eval_fn, test_only_train_fn, + TwoLayerLinearModel, ) from torch.testing._internal.common_quantized import ( + override_qengines, override_quantized_engine, supported_qengines, - override_qengines, ) from torch.testing._internal.common_utils import skipIfNoXNNPACK -from hypothesis import given -from hypothesis import strategies as st -import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() from functools import reduce @@ -1099,6 +1100,33 @@ def test_linear_bn_workflow(self): self.assertTrue(type(mq[1]) == nnq.Linear) self.assertTrue(type(mq[2]) == nn.Identity) + + @skipIfNoXNNPACK + @override_qengines + def test_linear_precomputed_fake_quant(self): + qengine = torch.backends.quantized.engine + if qengine != "qnnpack": + return # Only qnnpack support symmetric quantization + m_ref = nn.Linear(4, 4) + + m_ref_copy = copy.deepcopy(m_ref) + qconfig = default_qconfig + m_ref_copy.qconfig = qconfig + weight_post_process = copy.deepcopy(qconfig.weight()) + activation = copy.deepcopy(qconfig.activation()) + activation(torch.randn(4, 4)) + m_ref_copy.activation_post_process = activation + m_ref_copy = nnq.Linear.from_float(m_ref_copy) + weight_post_process = qconfig.weight() + weight_post_process.min_val = torch.tensor(-1) + weight_post_process.max_val = torch.tensor(1) + m_ref.weight_post_process = weight_post_process + m_ref.activation_post_process = activation + m_ref.qconfig = qconfig + m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True) + self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale) + + if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 218b30bd9e33..fb8182b21dda 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -538,6 +538,7 @@ def 
_test_quantizer( expected_node_occurrence, expected_node_list=None, is_qat=False, + debug=False, ): m_eager = model.train() if is_qat else model.eval() @@ -556,6 +557,8 @@ def _test_quantizer( prepare_model = copy.deepcopy(m) m = convert_pt2e(m) convert_model = copy.deepcopy(m) + if debug: + convert_model.print_readable(True) pt2_quant_output = m(*example_inputs) node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() @@ -751,9 +754,11 @@ def test_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -1346,9 +1351,11 @@ def test_linear_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1401,9 +1408,11 @@ def test_linear_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1472,9 +1481,11 @@ def test_linear_binary_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] fq_m = self._test_quantizer( m, @@ -1694,9 +1705,11 @@ def test_qat_conv2d_binary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] @@ -1741,9 +1754,11 @@ def test_qat_conv2d_binary2(self): torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.aten.add_.Tensor - if inplace_add - else torch.ops.aten.add.Tensor, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), ] self._test_quantizer( m, @@ -1865,6 +1880,410 @@ def test_qat_dynamic_quant_linear(self): is_qat=True, ) + @skipIfNoX86 + def test_set_module_name_qconfig(self): + """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. + + Expect that all linear layers within the submodule `sub` are quantized. 
+ """ + + class Sub(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=False) + self.linear2 = torch.nn.Linear(10, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to `None` and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of two linear layers from `sub` + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # two Q/DQ pairs for two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_set_module_name_qconfig_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # This module name has underscores, which can be part of a mangled name. + self.foo_bar = torch.nn.Linear(2, 2) + self.baz = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.baz(self.foo_bar(x)) + + # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "foo_bar", xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = capture_pre_autograd_graph(m, example_inputs) + m = prepare_pt2e(m, quantizer) + # Use a linear count instead of names because the names might change, but + # the order should be the same. + count = 0 + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.linear.default: + # Get the weight observer to see the per-channel vs per-tensor. + weight_observer_node = n.args[1] + if count == 0: + # for foo_bar. + self.assertEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + observer_instance = getattr(m, weight_observer_node.target) + self.assertEqual( + observer_instance.qscheme, torch.per_channel_symmetric + ) + else: + # For baz it should have no observer at all. 
+ self.assertNotEqual( + weight_observer_node.op, + "call_module", + f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", + ) + count += 1 + + @skipIfNoX86 + def test_set_module_name_and_module_type_case1(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. + + Expect that all linear layers are not quantized except the last one. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with default config and then `None` for all `Linear`. + # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ).set_module_type_qconfig(torch.nn.Linear, None) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # last linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_set_module_name_and_module_type_case2(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. + + Expect that all linear layers are quantized except the last one. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with None and then default config for a all `Linear`. 
+ quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( + torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input and output of the first and second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of the first and second linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + # Q/DQ for first linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + # Q/DQ for second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + # last linear is not quantized + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_set_module_name_qconfig_for_dynamic_quant(self): + """Test quantizing a specific submodule for dynamic quantization.""" + + with override_quantized_engine("x86"), torch.no_grad(): + for is_qat in [False, True]: + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + # only quantize `q_proj` and `v_proj` + dynamic_config = xiq.get_default_x86_inductor_quantization_config( + is_dynamic=True, is_qat=is_qat + ) + quantizer = ( + X86InductorQuantizer() + .set_module_name_qconfig("q_proj", dynamic_config) + .set_module_name_qconfig("v_proj", dynamic_config) + ) + node_occurrence = { + # quantize and dequantize the input + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # dequantize the weight of q_proj and v_proj + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + # quantize and dequantize the input + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + # q_proj + torch.ops.aten.linear.default, + # k_proj + torch.ops.aten.linear.default, + # v_proj + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + ) + + @skipIfNoX86 + def test_set_module_name_with_mixed_configs(self): + """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations. + + The config for 'v_proj' will always be ignored and raise a warning.
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + with self.assertWarns(UserWarning) as context: + for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( + [False, True], repeat=4 + ): + if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: + continue + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = ( + X86InductorQuantizer() + .set_module_name_qconfig( + "q_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=q_is_qat, is_dynamic=q_is_dynamic + ), + ) + .set_module_name_qconfig( + "v_proj", + xiq.get_default_x86_inductor_quantization_config( + is_qat=v_is_qat, is_dynamic=v_is_dynamic + ), + ) + ) + quant_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequant_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if q_is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + # quantize and dequantize the input + quant_op: 1, + dequant_op: 1, + # only `q_proj` was quantized, dequantize its weight + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # quantize and dequantize the input + quant_op, + dequant_op, + # q_proj + torch.ops.aten.linear.default, + # k_proj/v_proj + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=q_is_qat, + ) + warning_msg = ( + "Mixed QAT and Non-QAT" + if q_is_qat != v_is_qat + else "Mixed dynamic and static" + ) + self.assertTrue( + any( + warning_msg in msg + for msg in [str(w.message) for w in context.warnings] + ) + ) + + @skipIfNoX86 + def test_set_module_name_and_module_type_with_mixed_configs(self): + """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. + + Expect that only the last linear(`sub`) is quantized using static quantization. + """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.linear2 = torch.nn.Linear(10, 5) + self.sub = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). 
+ quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) + ).set_module_type_qconfig( + torch.nn.Linear, + xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # Q/DQ pairs for the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + @skipIfNoX86 def test_filter_conv2d_recipe(self): """ @@ -1994,12 +2413,12 @@ def test_attention_block(self): ) node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_tensor.default: 5 - if annotate_matmul - else 1, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7 - if annotate_matmul - else 3, + torch.ops.quantized_decomposed.quantize_per_tensor.default: ( + 5 if annotate_matmul else 1 + ), + torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( + 7 if annotate_matmul else 3 + ), # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, diff --git a/test/run_test.py b/test/run_test.py index 23160d01281c..57e69c0d979c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -24,7 +24,6 @@ import torch.distributed as dist from torch.multiprocessing import current_process, get_context from torch.testing._internal.common_utils import ( - FILE_SCHEMA, get_report_path, IS_CI, IS_MACOS, @@ -745,14 +744,7 @@ def test_distributed(test_module, test_directory, options): old_environ = dict(os.environ) os.environ["TEMP_DIR"] = tmp_dir os.environ["BACKEND"] = backend - os.environ["INIT_METHOD"] = "env://" os.environ.update(env_vars) - if with_init_file: - if test_module.name == "test_distributed_spawn": - init_method = f"{FILE_SCHEMA}{tmp_dir}/" - else: - init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file" - os.environ["INIT_METHOD"] = init_method try: os.mkdir(os.path.join(tmp_dir, "barrier")) os.mkdir(os.path.join(tmp_dir, "test_dir")) @@ -1188,21 +1180,15 @@ def parse_args(): or (IS_WINDOWS and not TEST_CUDA) or TEST_CONFIG == "nogpu_AVX512" or TEST_CONFIG == "nogpu_NO_AVX2" - or ( - "sm86" not in BUILD_ENVIRONMENT - and TEST_CONFIG == "default" - and TEST_CUDA - ) - or (not TEST_CUDA and TEST_CONFIG == "default") + or TEST_CONFIG == "default" ) and get_pr_number() is not None and not strtobool(os.environ.get("NO_TD", "False")) - and not IS_SLOW and not TEST_WITH_ROCM and not IS_MACOS + and "xpu" not in BUILD_ENVIRONMENT and "onnx" not in BUILD_ENVIRONMENT - and "debug" not in BUILD_ENVIRONMENT - and "parallelnative" not in BUILD_ENVIRONMENT, + and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"), ) parser.add_argument( "--shard", diff --git a/test/test_autograd.py b/test/test_autograd.py index 911762024930..ce5b4234b829 100644 --- a/test/test_autograd.py +++ 
b/test/test_autograd.py @@ -81,7 +81,7 @@ from torch.utils._python_dispatch import TorchDispatchMode from torch.utils.checkpoint import checkpoint, checkpoint_sequential from torch.utils.cpp_extension import load_inline -from torch.utils.hooks import RemovableHandle +from torch.utils.hooks import RemovableHandle # noqa: TCH001 def graph_desc(fn): @@ -1342,6 +1342,23 @@ def prehook(gI): b.backward() + def test_accumulate_grad_posthooks_should_not_execute(self): + def tensor_prehook(g): + raise RuntimeError + + def posthook(gO, gI): + raise RuntimeError + + a = torch.tensor(1.0, requires_grad=True) + a.register_hook(tensor_prehook) + b = torch.tensor(1.0, requires_grad=True) + c = a.clone() + acc = c.grad_fn.next_functions[0][0] + acc.register_hook(posthook) + + out = a + b + c + out.sum().backward(inputs=[b]) + def test_hook_edge_case_when_called_with_grad(self): # grad executes the tensor hooks of the next node but not # grad_fn pre hooks or the post hooks @@ -9508,6 +9525,13 @@ def f(x): memory_with_hooks = torch.cuda.memory_allocated() self.assertEqual(memory_with_hooks, memory_without_grad) + @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + def test_scalar_grad_mixed_device(self): + x = torch.tensor(1.0, requires_grad=True) + y = torch.randn(2, 2, device="cuda") + out = x * y + out.sum().backward() + def test_multi_grad_all_hooks(self): t1 = torch.rand(2, requires_grad=True) t2 = torch.rand(2, requires_grad=True) diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index 73f51d008bb1..c6ed9efe9f64 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Owner(s): ["oncall: mobile"] +# mypy: allow-untyped-defs import io diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 2ba1ee847e8b..007fbd32dde4 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Owner(s): ["oncall: mobile"] +# mypy: allow-untyped-defs import io import textwrap diff --git a/test/test_complex.py b/test/test_complex.py index 04fa566bf94f..67e8732dcbe1 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: complex"] import torch diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 3e5ce5cfcef4..eb6d43e4cf00 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -55,6 +55,10 @@ def test_extension_function(self): y = torch.randn(4, 4) z = cpp_extension.sigmoid_add(x, y) self.assertEqual(z, x.sigmoid() + y.sigmoid()) + # test pybind support torch.dtype cast. 
+ self.assertEqual( + str(torch.float32), str(cpp_extension.get_math_type(torch.half)) + ) def test_extension_module(self): mm = cpp_extension.MatrixMultiplier(4, 8) diff --git a/test/test_cuda.py b/test/test_cuda.py index 6ce7555519d7..7ec86bd6f47b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -29,6 +29,7 @@ ) from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch.testing._internal.common_cuda import ( + _create_scaling_case, _get_torch_cuda_version, TEST_CUDNN, TEST_MULTIGPU, @@ -36,8 +37,9 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, + onlyNativeDeviceTypes, ) -from torch.testing._internal.common_optimizers import optim_db, optims +from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker from torch.testing._internal.common_utils import ( freeze_rng_state, gcIfJetson, @@ -378,10 +380,10 @@ def test_cublas_workspace_explicit_allocation(self): def check_workspace_size(inp): torch._C._cuda_clearCublasWorkspaces() - start = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"] + start = torch.cuda.memory_stats()["active_bytes.all.allocated"] with torch.no_grad(): torch.matmul(inp, inp) - finish = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"] + finish = torch.cuda.memory_stats()["active_bytes.all.allocated"] return finish - start # check default @@ -4741,6 +4743,85 @@ class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. + @onlyNativeDeviceTypes + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=[torch.float32], + ) + def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): + device = device.split(":")[0] + if device not in optim_info.supports_fused_on: + self.skipTest( + f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" + ) + optim_inputs = optim_info.optim_inputs_func(device=device) + optim_cls = optim_info.optim_cls + for optim_input in optim_inputs: + for _separate_unscale in (True, False): + kwargs = optim_input.kwargs + kwargs["fused"] = True + torch.manual_seed(20) + ( + mod_control, + mod_scaling, + opt_control, + opt_scaling, + data, + loss_fn, + _, + ) = _create_scaling_case( + optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device + ) + optimizer_kwargs = deepcopy(kwargs) + optimizer_kwargs["fused"] = False + if "lr" not in kwargs: + # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr + optimizer_kwargs["lr"] = 1.0 + opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs) + scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) + scaler_control = torch.amp.GradScaler(device, init_scale=128.0) + tracker = TensorTracker() + for input, target in data: + opt_control.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output_control = mod_control(input) + loss_control = loss_fn(output_control, target) + scaler_control.scale(loss_control).backward() + scaler_control.step(opt_control) + scaler_control.update() + + opt_scaling.zero_grad() + with torch.autocast(device_type=device, dtype=torch.half): + output_scaling = mod_scaling(input) + loss_scaling = loss_fn(output_scaling, target) + scaler_scaling.scale(loss_scaling).backward() + if _separate_unscale: + scaler_scaling.unscale_(opt_scaling) + scaler_scaling.step(opt_scaling) + scaler_scaling.update() + + tracker.add(loss_control) + 
tracker.pop_check_set(loss_scaling, self) + for param_control, param_scaling in zip( + mod_control.parameters(), mod_scaling.parameters() + ): + tracker.add(param_control.grad) + tracker.pop_check_set(param_scaling.grad, self) + tracker.add(param_control) + tracker.pop_check_set(param_scaling, self) + + state_control, state_scaling = ( + opt_control.state[param_control], + opt_scaling.state[param_scaling], + ) + + for k in state_control: + actual = state_scaling[k] + if k == "step": + actual = actual.squeeze() + tracker.add(state_control[k]) + tracker.pop_check_set(actual, self) + @onlyCUDA @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index db03098a0fec..e2af3efaa98a 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3154,6 +3154,21 @@ def test_opcheck_bad_op(self): }, ) + def test_opcheck_does_not_require_extra_deps(self): + # torch.testing._internal.common_utils comes with a lot of additional + # test-time dependencies. Since opcheck is public API, it should be + # usable only with pytorch install-time dependencies. + cmd = [ + sys.executable, + "-c", + "import torch; import sys; \ + x = torch.randn(3, requires_grad=True); \ + torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \ + assert 'expecttest' not in sys.modules; \ + assert 'torch.testing._internal.common_utils' not in sys.modules", + ] + subprocess.check_output(cmd, shell=False) + only_for = ("cpu", "cuda") instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index b6be7eb76b97..37cf896eda24 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -2423,7 +2423,7 @@ def test_batch_mapdatapipe(self): _generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9) if _generic_namedtuple_allowed: - class InvalidData(Generic[T_co], NamedTuple): + class InvalidData(NamedTuple, Generic[T_co]): name: str data: T_co diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d548e9df0707..60ce1fb764ec 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -41,7 +41,11 @@ ) from torch.utils import _pytree as pytree from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils._sympy.functions import FloorDiv, Mod +from torch.utils._sympy.functions import ( + FloorDiv, + IsNonOverlappingAndDenseIndicator, + Mod, +) aten = torch.ops.aten @@ -205,15 +209,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True): +def create_symint(shape_env, i: int, duck=True) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool): +def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float): +def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @@ -457,14 +461,16 @@ def test_sym_int(self): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" + ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 
6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_sqrt(self): @@ -474,7 +480,7 @@ def test_sym_sqrt(self): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" ) def test_sym_floor(self): @@ -483,11 +489,17 @@ def test_sym_floor(self): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), + """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", + ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_trunc(self): shape_env = ShapeEnv() @@ -495,12 +507,14 @@ def test_sym_trunc(self): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" + ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -510,12 +524,17 @@ def test_sym_ceil(self): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" + str(shape_env.guards[0][0]), + """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", ) - r = math.floor(3.0 * a0) + r1 = 3.0 * a0 + r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_ite(self): shape_env = ShapeEnv() @@ -751,6 +770,70 @@ def test_non_overlapping_and_dense(self): r = torch.empty_strided((a0, 7), (1, a0), device="meta") self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) + def test_non_overlapping_and_dense_unbacked(self): + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + torch._check_is_size(u0) + cf = torch.ops.aten.is_non_overlapping_and_dense.default + + self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1) + self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1) + self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) + self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) + + self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1) + self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1) + self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) + self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) + + Max = 
torch.sym_max + # NB: This only works because we're able to determine this tensor is + # contiguous. transpose(0, 1) makes it stop working + self.assertTrue( + cf( + torch.empty_strided( + ( + 2, + 3, + 1, + u0, + ), + (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), + device="meta", + ) + ) + ) + + def test_debug_has_internal_overlap_unbacked(self): + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + torch._check_is_size(u0) + cf = torch._debug_has_internal_overlap + self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0) + self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0) + Max = torch.sym_max + self.assertEqual( + cf( + torch.empty_strided( + ( + 2, + 3, + 1, + u0, + ), + (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), + device="meta", + ) + ), + 0, + ) + + # Wobbling these to zero is OK too + self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2) + self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2) + def test_specialize_zero_one(self): shape_env = ShapeEnv(specialize_zero_one=True) a0 = create_symint(shape_env, 5) @@ -962,8 +1045,14 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): + return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) + + def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized + # TODO: We don't have to circuitously create the float, can just + # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -976,27 +1065,42 @@ def get_sym_inp(inp): else: return torch.SymFloat(to_node(seed_node, inp)) + if fn == "float_pow": + if inp1 < 0: + return + + if fn == "pow_by_natural": + if isinstance(inp1, float) or isinstance(inp2, float): + return + if inp2 < 0: + return + def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: + elif ( + fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") + and inp2 == 0 + ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn == "pow" and inp1 == 0 and inp2 < 0: + elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - fn == "pow" + # TODO: dear catastrophe waitress, + # this doesn't work + fn in ["float_pow", "pow_by_natural"] and inp1 < 0 - and inp2 in (2.5, -2.5) and ( - type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) + type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) + and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) + return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1080,6 +1184,9 @@ def test_method(self, fn, 
first_type, second_type): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") + if second_type == "float" and fn in ["mod"]: + self.skipTest(f"{fn} only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1251,112 +1358,15 @@ def yield_test_cases(values, negate=True): yield (-x, -y) def test_floordiv_float_int(self): - values = ( - (2.5, 2.1), - (2.1, 2.5), - (2.0, 2.1), - (7, 2.5), - (2.1, 7), - (7, 2), - ) + values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) - def test_floordiv_bool(self): - values = ( - (False, True), - (True, 2.5), - (2.5, True), - (False, 7), - (7, True), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # Compares to int since our FloorDiv has no bool support - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(int(x), int(y)), - ) - # Tests that our impl throws - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_complex(self): - values = ( - (1.5 + 2.5j, 1.3 + 3.5j), - (1.5 + 2.5j, 2.5), - (2.5, 1.5 + 2.5j), - (1.5 + 2.5j, 7), - (7, 1.5 + 2.5j), - ) - - for x, y in TestFloorDiv.yield_test_cases(values): - # We don't test error messages to avoid depending on Python - # interpreter version - self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_div_by_zero(self): - values = ( - (2.5, 0), - (2.1, 0.0), - (2.3, sympy.Symbol("s", zero=True)), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # We don't test error messages to avoid depending on Python - # interpreter version - if type(y) is not sympy.Symbol: - self.assertRaises( - ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) - ) - self.assertRaisesRegex( - ZeroDivisionError, - "division by zero", - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_zero_base(self): - values = ( - (0, 2.5), - (0.0, 2.1), - (sympy.Symbol("s", zero=True), 2.3), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - if type(x) is not sympy.Symbol: - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(x, y), - ) - else: - self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - def test_floordiv_div_by_one(self): - values = ( - (2.5, 1), - (2.1, 1.0), - (2, 1.0), - (2, 1), - ) + values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1367,12 +1377,7 @@ def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = ( - 7 * FloorDiv(6, 2), - 7 * FloorDiv(6.28, 2), - 7 * FloorDiv(6.28, 2.0), - 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), - ) + exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) @@ -1382,33 +1387,10 @@ def 
test_floordiv_simplify(self): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - def test_floordiv_simplify_rational(self): - result = 21 - - a = sympy.Symbol("a", integer=True) - b = sympy.Symbol("b") - - cases = [ - (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), - (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), - ] - - for expr, expected in cases: - self.assertEqual(expr, expected) - def test_floordiv_assumptions(self): - # We define two Symbols (with different names) for each type to make - # sure the behavior is consistent regardless of whether both arguments - # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), - sympy.Symbol("r1", real=True), - sympy.Symbol("r2", real=True), - sympy.Symbol("c1", complex=True, real=False, integer=False), - sympy.Symbol("c2", complex=True, real=False, integer=False), - sympy.Symbol("s1"), - sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 7456feb45d82..e5b36c47048b 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -619,6 +619,15 @@ def test_data_dependent_operator(self): self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) + def test_parameter_view(self): + x = torch.nn.Parameter(torch.randn(4)) + x_view = x.view(4) + mode = FakeTensorMode() + fake_x_view = mode.from_tensor(x_view) + fake_x = mode.from_tensor(x) + self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter)) + self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) + def test_tolist(self): shape_env = ShapeEnv() with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env): diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 4f9c7020c0e6..9f3a8ce223e5 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -9,11 +9,13 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + PLATFORM_SUPPORTS_CUDNN_ATTENTION ) from torch.testing._internal.common_utils import ( run_tests, TEST_WITH_TORCHDYNAMO, TestCase, + skipIfRocm, ) try: @@ -299,7 +301,8 @@ def test_noop(self): @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION - or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + or not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Does not support all SDPA backends (pre-SM80 hardware on CUDA)", ) def test_sdpa(self): @@ -354,15 +357,31 @@ def get_flops( if backend == "math": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=True, enable_mem_efficient=False + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "flash": backend = torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=False, enable_mem_efficient=False + enable_flash=True, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "mem_efficient": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=False, enable_mem_efficient=True + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + enable_cudnn=False, + ) + elif backend == "cudnn": + backend = torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=True, ) mode = FlopCounterMode() @@ -388,22 
+407,24 @@ def get_flops( flops = [ run_uniform_flops(backend, with_backward=False) - for backend in ["math", "flash", "mem_efficient"] + for backend in ["math", "flash", "mem_efficient", "cudnn"] ] - flops_fw_math, flops_fw_flash, flops_fw_efficient = flops + flops_fw_math, flops_fw_flash, flops_fw_efficient, flops_fw_cudnn = flops self.assertEqual(flops_fw_math, flops_fw_flash) self.assertEqual(flops_fw_math, flops_fw_efficient) + self.assertEqual(flops_fw_math, flops_fw_cudnn) self.assertExpectedInline(str(flops_fw_math), """134217728""") flops = [ run_uniform_flops(backend, with_backward=True) - for backend in ["math", "flash", "mem_efficient"] + for backend in ["math", "flash", "mem_efficient", "cudnn"] ] - flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops + flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient, flops_fw_bw_cudnn = flops self.assertEqual(flops_fw_math * 3, flops_fw_bw_math) self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash) self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient) + self.assertEqual(flops_fw_bw_flash, flops_fw_bw_cudnn) run_nonuniform_flops = functools.partial( get_flops, @@ -434,6 +455,7 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") + @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -446,15 +468,24 @@ def get_flops(q, k, v, backend, with_backward=False): if backend == "math": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=True, enable_mem_efficient=False + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "flash": backend = torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=False, enable_mem_efficient=False + enable_flash=True, + enable_math=False, + enable_mem_efficient=False, + enable_cudnn=False, ) elif backend == "mem_efficient": backend = torch.backends.cuda.sdp_kernel( - enable_flash=False, enable_math=False, enable_mem_efficient=True + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + enable_cudnn=False, ) with backend, mode: diff --git a/test/test_foreach.py b/test/test_foreach.py index c46ff8ae21b6..61d81d18db7b 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -88,7 +88,9 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): actual = self.func(*inputs, **kwargs) keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) - assert mta_called == (expect_fastpath and (not zero_size)) + assert mta_called == ( + expect_fastpath and (not zero_size) + ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}" else: actual = self.func(*inputs, **kwargs) if self.is_inplace: @@ -595,7 +597,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op): # Empty lists for fop in ops_to_test: with self.assertRaisesRegex( - RuntimeError, "There were no tensor arguments to this function" + RuntimeError, "Tensor list must have at least one tensor." ): fop(tensors1, tensors2) @@ -922,7 +924,10 @@ def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # note: BFloat16 has the same number of exponent bits as FP32 # so if squared L2 norm overflows in BF16, then it also overflows in FP32. 
@onlyCUDA - @ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16)) + @ops( + [o for o in foreach_reduce_op_db if "norm" in o.name], + allowed_dtypes=(torch.half, torch.bfloat16), + ) def test_foreach_l2_large_value_input(self, device, dtype, op): ord, N = 2, 10 max_value = torch.finfo(dtype).max @@ -976,14 +981,20 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): import math - for ord in (1, 2, math.inf): + if op.name == "_foreach_norm": + ords = (1, 2, math.inf) + else: + ords = (None,) + + for ord in ords: + kwargs = {"ord": ord} if ord else {} if not use_cuda_graph: actual = fn( inputs=[tensorlist], is_cuda=True, expect_fastpath=True, - ord=ord, zero_size=False, + **kwargs, ) else: # When using CUDA graphs and the tensor metadata doesn't fit in @@ -993,9 +1004,9 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): # test verifies multi_tensor_apply's behavior in the scenario. g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - actual = fn.func(tensorlist, ord=ord) + actual = fn.func(tensorlist, **kwargs) g.replay() - expect = ref_fn(inputs=[tensorlist], ord=ord) + expect = ref_fn(inputs=[tensorlist], **kwargs) self.assertEqual(expect, actual, equal_nan=True) @@ -1003,16 +1014,23 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph): @ops(foreach_reduce_op_db) def test_foreach_reduce_large_input(self, device, dtype, op): # test inputs larger than kChunkSize = 65536 - ord, N = 2, 65536 * 2 - disable_fastpath = True - if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16): - disable_fastpath = False + N = 65536 * 2 + disable_fastpath = False + kwargs = {} + if op.name == "_foreach_norm": + ord = 2 + disable_fastpath = not ( + ord in (1, 2) + and dtype in floating_types_and(torch.half, torch.bfloat16) + ) + kwargs["ord"] = ord + inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],) wrapped_op, ref, _, _ = self._get_funcs(op) self.assertEqual( - ref(inputs, ord=ord), + ref(inputs, **kwargs), wrapped_op( - inputs, self.is_cuda, not disable_fastpath, ord=ord, zero_size=False + inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs ), ) @@ -1230,6 +1248,28 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): copy_(t, s, non_blocking) self.assertEqual(ref_input, sample.input) + @onlyCUDA + @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) + def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): + # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_ + foreach_copy_ = ForeachFuncWrapper(op.inplace_variant) + for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for src_dtype in floating_types_and(torch.half, torch.bfloat16): + if src_dtype == dtype: + continue + self_tensors = [t.clone() for t in sample.input] + src_tensors = [t.to(src_dtype) for t in self_tensors] + out = foreach_copy_( + (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True + ) + self.assertEqual( + out, + [ + torch.empty_like(t).copy_(s) + for t, s in zip(self_tensors, src_tensors) + ], + ) + # Test reverse-mode & forward-mode AD if supported. 
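Illustrative aside (not part of the patch): the new test_foreach_copy_with_multi_dtypes above checks that the fused `_foreach_copy_` fast path matches a per-tensor `Tensor.copy_` loop even when source and destination dtypes differ. A minimal eager-mode sketch of that property, assuming a CUDA build; tensor sizes and names are illustrative only:

import torch

dst = [torch.randn(4, device="cuda") for _ in range(3)]         # fp32 destinations
src = [t.to(torch.bfloat16) for t in dst]                       # bf16 sources (dtype mismatch)

ref = [torch.empty_like(t).copy_(s) for t, s in zip(dst, src)]  # per-tensor reference
torch._foreach_copy_(dst, src)                                  # fused multi_tensor_apply path

for out, expect in zip(dst, ref):
    torch.testing.assert_close(out, expect)                     # fast path matches the loop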
@onlyCUDA @ops( diff --git a/test/test_futures.py b/test/test_futures.py index 33814eda41ea..dd1e79ff83b3 100644 --- a/test/test_futures.py +++ b/test/test_futures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import threading diff --git a/test/test_jit.py b/test/test_jit.py index bb6f4e255888..0e99c3602cd6 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4329,7 +4329,7 @@ def foobar(xyz): return torch.blargh(xyz) _, lineno = inspect.getsourcelines(foobar) - with self.assertRaisesRegex(RuntimeError, f"test_jit.py\", line {lineno + 1}"): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'): scripted = torch.jit.script(foobar) def test_file_line_error_class_defn(self): @@ -4338,7 +4338,7 @@ def baz(self, xyz): return torch.blargh(xyz) _, lineno = inspect.getsourcelines(FooBar) - with self.assertRaisesRegex(RuntimeError, f"test_jit.py\", line {lineno + 2}"): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'): torch.jit.script(FooBar) def test_file_line_graph(self): @@ -4405,7 +4405,7 @@ def forward(self, x, w): loaded = self.getExportImportCopy(ft) _, lineno = inspect.getsourcelines(FooTest) - with self.assertRaisesRegex(RuntimeError, f'test_jit.py\", line {lineno + 3}'): + with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'): loaded(torch.rand(3, 4), torch.rand(30, 40)) def test_serialized_source_ranges_graph(self): @@ -4431,7 +4431,7 @@ def forward(self): _, lineno = inspect.getsourcelines(FooTest2) - with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py\", line {lineno + 3}'): + with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'): ft = FooTest2() loaded = self.getExportImportCopy(ft) loaded() @@ -10260,7 +10260,7 @@ def fn(x): n = next(graph.inputs()) self.assertTrue(n.type() == torch._C.TensorType.getInferred()) - with self.assertRaisesRegex(RuntimeError, "Inferred \'x\' to be of type \'Tensor"): + with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"): fn("1") def test_script_define_order(self): @@ -12309,7 +12309,7 @@ def forward(self, x): tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) FileCheck().check_not("value=").check("aten::mm")\ - .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \ + .check('prim::CallMethod[name="forward"]').check("aten::add") \ .run(str(tm.graph)) FileCheck().check("aten::mm").run(str(tm.mod.graph)) @@ -14743,7 +14743,7 @@ def forward(self): return self.hello("hi"), self.hello(.5) w = CompileOverloadError() - with self.assertRaisesRegex(Exception, "but instead found type \'str\'"): + with self.assertRaisesRegex(Exception, "but instead found type 'str'"): torch.jit.script(w) # testing overload declared first, then non-overload diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 071249192ec6..7b087d361d8b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1,15 +1,16 @@ # Owner(s): ["NNC"] +import contextlib +import math import operator import os import unittest -import contextlib -import math +import warnings +from typing import List + import torch import torch.nn.functional as F from torch.testing import FileCheck -from typing import List -import warnings # these needs to be set before `common_utils` # infers `GRAPH_EXECUTOR`. 
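Illustrative aside (not part of the patch): the test_flop_counter.py hunks above add the cuDNN attention backend to the SDPA FLOP parity checks by passing the new enable_cudnn flag to torch.backends.cuda.sdp_kernel. A minimal sketch of that pattern, assuming a CUDA build where the cuDNN backend is available; the shapes and the display/get_total_flops usage here are illustrative:

import torch
from torch.utils.flop_counter import FlopCounterMode

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the cuDNN backend only, then count the forward FLOPs.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=False, enable_cudnn=True
):
    with FlopCounterMode(display=False) as counter:
        torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(counter.get_total_flops())  # expected to agree with the math/flash/mem_efficient backends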
@@ -20,42 +21,79 @@ torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) -from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ - enable_profiling_mode_for_profiling_tests, slowTest, skipIfTorchDynamo, TEST_WITH_ASAN, \ - TEST_WITH_ROCM, IS_FBCODE -from torch.testing._internal.jit_utils import JitTestCase, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \ - clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager - -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \ - OpDTypes -from torch.testing._internal.common_jit import JitCommonTestCase -from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn +from itertools import combinations, permutations, product from textwrap import dedent -from itertools import product, permutations, combinations - -from test_jit import backward_graph, get_lstm_inputs, get_milstm_inputs, \ - LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell from jit.test_fuser_common import TestFuserCommon # noqa: F401 -FUSION_GROUP = 'prim::TensorExprGroup' +from test_jit import ( + backward_graph, + get_lstm_inputs, + get_milstm_inputs, + LSTMCellC, + LSTMCellF, + LSTMCellS, + MiLSTMCell, +) + +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + OpDTypes, + ops, +) +from torch.testing._internal.common_jit import JitCommonTestCase + +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import ( + enable_profiling_mode_for_profiling_tests, + GRAPH_EXECUTOR, + IS_FBCODE, + ProfilingMode, + run_tests, + skipIfTorchDynamo, + slowTest, + TEST_WITH_ASAN, + TEST_WITH_ROCM, +) +from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn +from torch.testing._internal.jit_utils import ( + clone_inputs, + get_traced_sample_variant_pairs, + JitTestCase, + NoTracerWarnContextManager, + RUN_CUDA, + RUN_CUDA_HALF, + RUN_CUDA_MULTI_GPU, + set_fusion_group_inlining, + TensorExprTestOptions, + warmup_backward, +) + +FUSION_GROUP = "prim::TensorExprGroup" LLVM_ENABLED = torch._C._llvm_enabled() -autograd_check_set = {'aten::__is__', 'prim::AutogradAllNonZero', 'prim::AutogradAllZero', 'prim::ListConstruct'} +autograd_check_set = { + "aten::__is__", + "prim::AutogradAllNonZero", + "prim::AutogradAllZero", + "prim::ListConstruct", +} + def strip_profiling_nodes(nodes): - profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'} + profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"} return [n for n in nodes if n.kind() not in profiling_opcodes] + def warmup_forward(f, *args, profiling_count=2): for i in range(profiling_count): results = f(*args) return results + @contextlib.contextmanager def texpr_reductions_enabled(): old = torch._C._jit_set_texpr_reductions_enabled(True) @@ -64,6 +102,7 @@ def texpr_reductions_enabled(): finally: torch._C._jit_set_texpr_reductions_enabled(old) + @contextlib.contextmanager def texpr_enable_strategy(strategy): old = torch._C._jit_set_fusion_strategy(strategy) @@ -72,6 +111,7 @@ def texpr_enable_strategy(strategy): finally: torch._C._jit_set_fusion_strategy(old) + @contextlib.contextmanager def inline_fusion_groups(): old_inlining = torch._C._debug_get_fusion_group_inlining() @@ -93,7 +133,7 @@ def setUp(self): fusion_strategy = 
[("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy) - self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] self.int_dtypes = [ torch.int8, torch.int16, @@ -117,7 +157,11 @@ def tearDown(self): def assertAllFused(self, graph, except_for=None): except_for = except_for if except_for is not None else set() # TODO - upstream - guards = "prim::TypeCheck", "prim::RequiresGradCheck", "prim::TensorExprDynamicGuard" + guards = ( + "prim::TypeCheck", + "prim::RequiresGradCheck", + "prim::TensorExprDynamicGuard", + ) guard_found = False def autodiff_guard(node): @@ -128,7 +172,10 @@ def autodiff_guard(node): return False li_inps = list(inps[0].node().inputs()) for li_inp in li_inps: - if li_inp.node().kind() in ("prim::AutogradAllNonZero", "prim::AutogradAllZero"): + if li_inp.node().kind() in ( + "prim::AutogradAllNonZero", + "prim::AutogradAllZero", + ): return True return False @@ -151,7 +198,6 @@ def is_guard(node): self.assertTrue(guard_found) - def assertLastGraphAllFused(self): self.assertAllFused(torch.jit.last_executed_optimized_graph()) @@ -159,7 +205,7 @@ def findFusionGroups(self, graph): result = [] for n in graph.nodes(): if n.kind() == FUSION_GROUP: - result.append(n.g('Subgraph')) + result.append(n.g("Subgraph")) continue for block in n.blocks(): result += self.findFusionGroups(block) @@ -169,7 +215,7 @@ def test_typecheck(self): a = torch.ones(1) def fused_kernel(a, b): - return (a + b) * 2. + return (a + b) * 2.0 scripted = self.checkScript(fused_kernel, (a, a)) graph = scripted.graph_for(a, a) @@ -191,7 +237,7 @@ def func(x): return x2.sum() with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -201,13 +247,13 @@ def test_nop(self): def test_sum_dim(self): def func(x): - return x.sum((0, )) * 2 + return x.sum((0,)) * 2 def func_neg(x): - return x.sum((-2, )) * 2 + return x.sum((-2,)) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -216,10 +262,10 @@ def func_neg(x): def test_sum_keepdim_cast(self): def func(x): - return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 + return x.sum((0,), keepdim=True, dtype=torch.double) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") a = a.reshape(5, 3) self.checkScript(func, (a,)) @@ -227,6 +273,7 @@ def func(x): def test_abs(self): for device in self.devices: + def func(x): return x.abs() * 2 @@ -236,19 +283,24 @@ def func(x): def test_unsqueeze_size_calculation(self): for device in self.devices: + def foo(b, d): x = d.unsqueeze(1) - y = x * 42. + y = x * 42.0 z = b + y - r = z / 42. 
+ r = z / 42.0 return r - inputs = (torch.rand(20, 28, device=device, requires_grad=True), torch.rand(20, device=device)) + inputs = ( + torch.rand(20, 28, device=device, requires_grad=True), + torch.rand(20, device=device), + ) scripted = self.checkScript(foo, inputs) self.assertAllFused(scripted.graph_for(*inputs)) def test_zero_element_tensors(self): for device in self.devices: + def decode(sin_t, cos_t): theta = torch.atan2(sin_t.float(), cos_t.float()) return theta @@ -267,17 +319,25 @@ def test_arg_configurations_smoke(self): # TODO: add optionally enabled debug counters to the fuser to verify # that we really can tell the difference between configurations for device in self.devices: + def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) - traced_f = torch.jit.trace(f, (x, y,)) + traced_f = torch.jit.trace( + f, + ( + x, + y, + ), + ) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) def test_broadcast(self): for device in self.devices: + def scaleshift(x, scale, shift): return x * scale + shift @@ -290,16 +350,14 @@ def scaleshift(x, scale, shift): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") + @unittest.skipIf( + GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on" + ) def test_cuda_half(self): - x = torch.randn(4, 4, dtype=torch.half, device='cuda') - y = torch.randn(4, 4, dtype=torch.half, device='cuda') + x = torch.randn(4, 4, dtype=torch.half, device="cuda") + y = torch.randn(4, 4, dtype=torch.half, device="cuda") - funcs = [ - self.fn_test_comparison_gt_lt, - self.fn_test_relu, - self.fn_test_exp - ] + funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp] # Note: Non fused inputs must be float to prevent loss of precision inputs = (x.float(), y.float()) @@ -318,9 +376,17 @@ def test_cuda_half(self): # Verifies gradients for output, fusion_output in zip(outputs_half, fusion_outputs): grads = torch.autograd.grad( - output.float().sum(), local_inputs, allow_unused=True, retain_graph=True) + output.float().sum(), + local_inputs, + allow_unused=True, + retain_graph=True, + ) fusion_grads = torch.autograd.grad( - fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True) + fusion_output.sum(), + local_fusion_inputs, + allow_unused=True, + retain_graph=True, + ) grads_half = [t.half() for t in grads] self.assertEqual(grads_half, fusion_grads) @@ -332,7 +398,7 @@ def test_checks_cat_inputs(self): # need to be checked for having the same map size, before we can # run the kernel. def f(x, y): - return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) + return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0) # NOTE: y is broadcastable to x, but output of f(x, y) should have # shape 3x4, and not 4x4. 
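Illustrative aside (not part of the patch): the NOTE above is the point of test_checks_cat_inputs -- torch.cat never broadcasts its operands, so even though y would broadcast against x in an elementwise op, the concatenated result keeps each operand's own rows (3x4, not 4x4). A quick eager-mode check of that shape reasoning, with sizes chosen to match the NOTE:

import torch

x = torch.randn(2, 4)   # first operand: the elementwise ops preserve (2, 4)
y = torch.randn(1, 4)   # broadcastable against x, but cat does not broadcast

out = torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0)
assert out.shape == (3, 4)   # 2 rows from x's operand + 1 row from y's, not 4x4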
@@ -348,6 +414,7 @@ def test_chunk(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def fn(x): a, b, c = x.chunk(3, 1) return a * b + c @@ -362,6 +429,7 @@ def test_chunk_correctness(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def chunk_4_0(x): x0, x1, x2, x3 = x.chunk(4, 0) return x0 + x1 + x2 + x3 @@ -378,12 +446,12 @@ def chunk_4_last(x): tensors = [ # splitSize = 1 torch.randn(4, 4, 4, dtype=torch.float, device=device), - # contiguous case torch.randn(12, 8, 16, dtype=torch.float, device=device), - # non-contiguous case - torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2), + torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose( + 1, 2 + ), ] for tensor in tensors: @@ -399,6 +467,7 @@ def test_chunk_distributes(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 @@ -420,6 +489,7 @@ def test_chunk_motion_deduplicates_inputs(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def func1(x): z = x * x z0, z1 = z.chunk(2) @@ -462,6 +532,7 @@ def fn(s, x, y, z): def test_minmax(self): for device in self.devices: + def tmax(a, b): return torch.max(2 * a, b) @@ -470,26 +541,26 @@ def tmin(a, b): a = torch.randn(4, 4, dtype=torch.float) b = torch.randn(4, 4, dtype=torch.float) - nan = torch.tensor(float('nan'), dtype=torch.float) + nan = torch.tensor(float("nan"), dtype=torch.float) for f, inputs, device in product( - (tmax, tmin), - ([a, b], [a, nan], [b, nan]), - self.devices): + (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices + ): inputs = [t.to(device) for t in inputs] s = self.checkScript(f, inputs) self.assertAllFused(s.graph_for(*inputs)) def test_clamp(self): for device in self.devices: + def func2(a, b): return torch.clamp(a + b, min=0, max=2) def funcInf(a, b): - return torch.clamp(a + b, min=0, max=float('inf')) + return torch.clamp(a + b, min=0, max=float("inf")) def funcNegInf(a, b): - return torch.clamp(a + b, min=float('-inf'), max=0) + return torch.clamp(a + b, min=float("-inf"), max=0) def funcOptMin(a, b): return torch.clamp(a + b, max=2) @@ -499,31 +570,47 @@ def funcOptMax(a, b): a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) b = torch.randn(4, 4, dtype=torch.float, device=device) - nan = torch.tensor(float('nan'), dtype=torch.float, device=device) + nan = torch.tensor(float("nan"), dtype=torch.float, device=device) funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) for f, inputs in product(funcs, [[a, b], [a, nan]]): inp1, inp2 = inputs s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) - self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) + self.assertAllFused( + s.graph_for(inp1, inp2), + except_for={"aten::size", "aten::_size_if_not_equal"}, + ) c = s(inp1, inp2) with enable_profiling_mode_for_profiling_tests(): warmup_backward(c.sum()) graph = backward_graph(s) - self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}.union(autograd_check_set)) + self.assertAllFused( + graph, + except_for={"aten::Float", "aten::_grad_sum_to_size"}.union( + autograd_check_set + ), + ) def test_clamp_double(self): for device in self.devices: + def clamp_double(x, eta: float): return 1 - x.clamp(eta, 1 - eta) x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device) eta = 1e-9 - s = self.checkScript(clamp_double, (x, eta), 
profiling=ProfilingMode.PROFILING, atol=1e-10, rtol=1e-5) - self.assertAllFused(s.graph_for(x, eta), except_for={'aten::sub'}) + s = self.checkScript( + clamp_double, + (x, eta), + profiling=ProfilingMode.PROFILING, + atol=1e-10, + rtol=1e-5, + ) + self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"}) def test_clamp_int(self): for device in self.devices: + def clamp_int(x, eta: int): return x.clamp(0, eta) @@ -535,6 +622,7 @@ def clamp_int(x, eta: int): def test_add_bool(self): sizes = [(1,), (2,), (4, 4)] for device, size in product(self.devices, sizes): + def f(x, y, z): return x + y + z @@ -546,6 +634,7 @@ def f(x, y, z): def test_mul_bool(self): for device in self.devices: + def f(x, y, z): return x * y * z @@ -558,6 +647,7 @@ def f(x, y, z): def test_div_bool(self): for device in self.devices: + def f(x, y, z): return (x + y) / z @@ -605,10 +695,7 @@ def test_minmax_int_ops(self): def apply(fn): return lambda x, y, z: fn(fn(x, y), z) - binary_ops = [ - torch.min, - torch.max - ] + binary_ops = [torch.min, torch.max] devices = self.devices for dtype, op, device in product(self.int_dtypes, binary_ops, devices): try: @@ -633,6 +720,7 @@ def apply(fn): def test_comparison_eq_ne(self): for device in self.devices: + def f(x, y): mask = (x == 0).type_as(x) z = x * mask + y @@ -664,6 +752,7 @@ def test_comparison_gt_lt(self): def test_comparison_ge_le(self): for device in self.devices: + def f(x, y): mask = (x >= 0).type_as(x) z = x * mask + y @@ -678,8 +767,14 @@ def f(x, y): self.assertAllFused(ge.graph_for(x, y)) x.requires_grad_(True) y.requires_grad_(True) - self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + ge.graph_for(x, y), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) def test_addcmul(self): for device in self.devices: @@ -694,7 +789,9 @@ def foo(t, t1, t2): graph = ge.graph_for(t, t1, t2) fusion_groups = self.findFusionGroups(graph) self.assertEqual(len(fusion_groups), 1) - FileCheck().check("aten::add(").check("aten::addcmul(").run(str(fusion_groups[0])) + FileCheck().check("aten::add(").check("aten::addcmul(").run( + str(fusion_groups[0]) + ) # TODO: We leak CUDA memory here because the traced graph holds onto a # constant-ified tensor. Since the Python-global CompilationUnit is alive @@ -743,6 +840,7 @@ def foo(hx, cx): def test_remove_output_used_only_in_size(self): for device in self.devices: + def test_fuse(a, b): c = a + b d = c + b @@ -753,10 +851,10 @@ def test_fuse(a, b): y = torch.ones(1, requires_grad=True, device=device) warmup_forward(scripted_f, x, y, profiling_count=3) g = scripted_f.graph_for(x, y) - diff_nodes = g.findAllNodes('prim::DifferentiableGraph') + diff_nodes = g.findAllNodes("prim::DifferentiableGraph") self.assertEqual(len(diff_nodes), 1) - g = diff_nodes[0].g('Subgraph') - if_nodes = [n for n in g.nodes() if n.kind() == 'prim::If'] + g = diff_nodes[0].g("Subgraph") + if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"] self.assertEqual(len(if_nodes), 1) # the if node and the fusion group inside it should only have one output @@ -777,13 +875,13 @@ def fn(x, y, z): z = torch.randn(4, 2, dtype=torch.float, device=device) ge = self.checkTrace(fn, (x, y, z)) graph = ge.graph_for(x, y, z) - self.assertAllFused(graph, except_for={'aten::add'}) + self.assertAllFused(graph, except_for={"aten::add"}) # XXX: TE fuser can handle concats inside a fusion group. 
# FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @staticmethod def fn_test_exp(x, y): - return (x + .5 * y).exp() + return (x + 0.5 * y).exp() def test_exp(self): for device in self.devices: @@ -795,6 +893,7 @@ def test_exp(self): def test_threshold(self): for device in self.devices: + def f(x): return torch.threshold(x, 0, -10) + x + x + x @@ -804,6 +903,7 @@ def f(x): def test_scalar_arg(self): for device in self.devices: + def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: return p * (x * x + x) @@ -816,15 +916,23 @@ def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: # use another function otherwise we will bailout # and won't be able to do fused checks - def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: + def fn_test_scalar_arg_requires_grad( + x: torch.Tensor, p: float + ) -> torch.Tensor: return p * (x * x + x) scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) out = scripted(x, p) out = scripted(x, p) out = scripted(x, p) - self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + scripted.graph_for(x, p), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") @@ -861,8 +969,8 @@ def fn(x, y, z): inputs = [ torch.randn(4, 4, dtype=torch.float), - torch.randn(4, 4, dtype=torch.float, device='cuda:0'), - torch.randn(4, 4, dtype=torch.float, device='cuda:1'), + torch.randn(4, 4, dtype=torch.float, device="cuda:0"), + torch.randn(4, 4, dtype=torch.float, device="cuda:1"), ] prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() @@ -870,8 +978,7 @@ def fn(x, y, z): # There are 3 FusionGroups. Because they have the same graph, they # should reuse the same KernelSpec in the KernelSpec cache. ge = self.checkScript(fn, inputs) - self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 3, True) + self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True) new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() # XXX: This assumes that the same kernel isn't already used by another test # FIXME: Use the TE fuser's way of querying the cache. @@ -879,7 +986,7 @@ def fn(x, y, z): @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") def test_nonzero_device_cuda(self): - device = 'cuda:' + str(1) + device = "cuda:" + str(1) x = torch.tensor([0.4], dtype=torch.float, device=device) y = torch.tensor([0.7], dtype=torch.float, device=device) @@ -893,7 +1000,9 @@ def test_lstm(self): for device in self.devices: inputs = get_lstm_inputs(device, training=True) module = self.checkScript(LSTMCellS, inputs) - self.assertAllFused(module.graph_for(inputs), except_for={"prim::TupleConstruct"}) + self.assertAllFused( + module.graph_for(inputs), except_for={"prim::TupleConstruct"} + ) def test_lstm_concat(self): # single fusion node causes error @@ -905,7 +1014,9 @@ def test_lstm_concat(self): except_nodes = {"prim::TupleConstruct", "aten::linear"} # TODO... Chunk if self.dynamic_shapes: - except_nodes = except_nodes.union({"aten::add", "prim::ConstantChunk"}) + except_nodes = except_nodes.union( + {"aten::add", "prim::ConstantChunk"} + ) self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes) # XXX: TE fuser can handle concats inside a fusion group. 
# FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @@ -914,13 +1025,15 @@ def test_lstm_gates_permutations(self): for device in self.devices: # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. # Test that any permutation of this will still result in one FusionGroup. - choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] - template = dedent(''' + choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"] + template = dedent( + """ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): gates = {} + {} + {} + {} ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) return ingate * forgetgate * cellgate * outgate - ''') + """ + ) for permutation in permutations(choices, len(choices)): code = template.format(*permutation) scope = {} @@ -928,9 +1041,11 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): cu = torch.jit.CompilationUnit(code) fusion_group_len = 2 if self.dynamic_shapes else 1 inputs = get_lstm_inputs(device, training=False) - self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) + self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs)) forward_graph = cu.cell.graph_for(*inputs) - self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, fusion_group_len) + self.assertGraphContainsExactly( + forward_graph, FUSION_GROUP, fusion_group_len + ) # TODO: Fuser doesn't work at all when inputs require grad. Fix that def test_lstm_traced(self): @@ -945,7 +1060,9 @@ def test_lstm_traced(self): f = FileCheck() if not self.dynamic_shapes: f.check("Chunk") - f.check("aten::sigmoid").check("aten::tanh").run(str(fusion_groups[0 if not self.dynamic_shapes else 1])) + f.check("aten::sigmoid").check("aten::tanh").run( + str(fusion_groups[0 if not self.dynamic_shapes else 1]) + ) def test_milstm(self): if self.dynamic_shapes: @@ -958,9 +1075,11 @@ def test_milstm(self): # TODO: chunk fusion_group_len = 2 if self.dynamic_shapes else 1 self.assertGraphContainsExactly( - forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True) - FileCheck().check("DifferentiableGraph").check("TupleConstruct") \ - .check_next("return").check(FUSION_GROUP).run(str(forward_graph)) + forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True + ) + FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next( + "return" + ).check(FUSION_GROUP).run(str(forward_graph)) hy, cy = module(*inputs) warmup_backward((hy + cy).sum()) @@ -968,17 +1087,17 @@ def test_milstm(self): @unittest.skip("rand_like is not supported yet") def test_rand_cuda(self): class M(torch.jit.ScriptModule): - __constants__ = ['d'] + __constants__ = ["d"] def __init__(self): super().__init__() - self.d = torch.device('cuda') + self.d = torch.device("cuda") @torch.jit.script_method def create(self, x): return x * x + x + torch.rand_like(x) - x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda') + x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda") m = M() out1 = m.create(x) out2 = m.create(x) @@ -991,7 +1110,7 @@ def create(self, x): @staticmethod def fn_test_relu(x, y): - return F.relu(x + .5 * y) + return F.relu(x + 0.5 * y) def test_relu(self): for device in self.devices: @@ -1004,7 +1123,7 @@ def test_relu(self): def test_erf(self): for device in self.devices: # only enabled on gpu - if device == 'cpu': + if device == "cpu": continue def fn_test_erf(x): @@ -1015,8 +1134,14 @@ def fn_test_erf(x): self.assertAllFused(ge.graph_for(x)) x.requires_grad_(True) ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) - 
self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + ge.graph_for(x), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1031,24 +1156,30 @@ def fn_test_rand2(x, y): r = torch.rand_like(y) return r * x * x - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device="cuda") + y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand) warmup_forward(script_f, x, y) out = script_f(x, y) self.assertAllFused(script_f.graph_for(x, y)) x.requires_grad_(True) out = script_f(x, y) - self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", - "aten::_size_if_not_equal")) + self.assertAllFused( + script_f.graph_for(x, y), + except_for=( + "aten::size", + "prim::BroadcastSizes", + "aten::_size_if_not_equal", + ), + ) # test that broadcasting random produces correct results - x = torch.ones(4, 4, dtype=torch.float, device='cuda') - y = torch.ones(4, dtype=torch.float, device='cuda') + x = torch.ones(4, 4, dtype=torch.float, device="cuda") + y = torch.ones(4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_rand2) warmup_forward(script_f, x, y) out = script_f(x, y) - self.assertEqual(out[0, :] + torch.zeros(4, 4, device='cuda'), out) + self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1059,8 +1190,8 @@ def fn_test_diamond(x, y): b = y - r return a + b - x = torch.randn(4, 4, dtype=torch.float, device='cuda') - y = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4, 4, dtype=torch.float, device="cuda") + y = torch.randn(4, 4, dtype=torch.float, device="cuda") script_f = torch.jit.script(fn_test_diamond) warmup_forward(script_f, x, y) out = script_f(x, y) @@ -1070,8 +1201,8 @@ def test_scalar(self): def fn(x, y): return 2 * x + y - x = torch.tensor(0.1, dtype=torch.float, device='cpu') - y = torch.tensor(1, dtype=torch.float, device='cpu') + x = torch.tensor(0.1, dtype=torch.float, device="cpu") + y = torch.tensor(1, dtype=torch.float, device="cpu") ge = self.checkScript(fn, (x, y)) self.assertAllFused(ge.graph_for(x, y)) @@ -1091,7 +1222,9 @@ def foo(x): g = torch.jit.last_executed_optimized_graph() - FileCheck().check_count("prim::If", 1, exactly=True).check("prim::TensorExpr").run(g) + FileCheck().check_count("prim::If", 1, exactly=True).check( + "prim::TensorExpr" + ).run(g) torch._C._jit_pass_inline(g) f = FileCheck() for _ in range(3): @@ -1100,8 +1233,10 @@ def foo(x): def test_small_constant(self): for device in self.devices: + def fn_test_small_constant(x, y): return (1e-8 * x + 5e-9 * y) * 1e8 + x = torch.randn(4, 4, dtype=torch.float, device=device) y = torch.randn(4, 4, dtype=torch.float, device=device) @@ -1116,8 +1251,9 @@ def fn_test_small_constant(x, y): # TODO: fix that and reenable the test. def test_tensor_scalar_ops(self): for device in self.devices: + def should_fuse(x): - z = 3. 
+ z = 3.0 y = x + z return x * y @@ -1134,22 +1270,24 @@ def should_fuse_scalar(x, z): inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), - torch.tensor(3., dtype=torch.float, device=device), + torch.tensor(3.0, dtype=torch.float, device=device), ] ge = self.checkScript(should_fuse_scalar, inputs) # Check that the fused graph computes correct results when the scalar # input changes. inputs = [ torch.randn(2, 2, dtype=torch.float, device=device), - torch.tensor(7., dtype=torch.float, device=device), + torch.tensor(7.0, dtype=torch.float, device=device), ] self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) # The TE fuser supports fusion of non-constant scalars self.assertGraphContainsExactly( - ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True) + ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True + ) def test_where_and_typing(self): for device in self.devices: + def f(x, y): mask = x > y res = torch.where(mask, x, y) @@ -1159,14 +1297,16 @@ def f(x, y): y = torch.randn(4, 4, dtype=torch.double, device=device) script_f = self.checkScript(f, (x, y)) - self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) + self.assertAllFused( + script_f.graph_for(x, y), except_for={"prim::TupleConstruct"} + ) def test_disabled(self): old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() torch._C._jit_override_can_fuse_on_cpu(False) def fn(a): - return a ** 2 + a + return a**2 + a x = torch.randn(4, dtype=torch.float, device="cpu") s = self.checkScript(fn, (x,)) @@ -1193,38 +1333,46 @@ def test_torch_to(self): def foo(x): return x.to(torch.float) - foo(torch.tensor([3.], dtype=torch.float)) - foo(torch.tensor([3.], dtype=torch.float)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float)) + foo(torch.tensor([3.0], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test not fusing non-const inputs @torch.jit.script def foo(x, dtype: int): return x.to(dtype) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test not fusing to_pinned inputs @torch.jit.script def foo(x, dtype: int): return x.to(pin_memory=True) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - foo(torch.tensor([3.], dtype=torch.float), torch.int) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) - + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + foo(torch.tensor([3.0], dtype=torch.float), torch.int) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) # test across-device not supported if torch.cuda.is_available(): + @torch.jit.script def foo(x): return x.to(device="cuda") - foo(torch.tensor([3.], dtype=torch.float)) - foo(torch.tensor([3.], dtype=torch.float)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + foo(torch.tensor([3.0], dtype=torch.float)) + foo(torch.tensor([3.0], dtype=torch.float)) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) sizes = [(1, 4), (4, 4)] # reuses cast impl, smaller dtype set for faster 
test @@ -1245,7 +1393,9 @@ def forward(self, x): return x.to(self.dtype) bad_dtypes = [] - for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes): + for dtype, output_dtype, device, size in product( + dtypes, dtypes, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1275,12 +1425,15 @@ def test_masked_fill(self): torch.bool, ] sizes = [(2,), (4, 4)] - for self_dtype, device, scalar_val, size in product(dtypes, self.devices, [0.4, 3], sizes): + for self_dtype, device, scalar_val, size in product( + dtypes, self.devices, [0.4, 3], sizes + ): input_v = self.data_for(self_dtype, device, size=size) mask = self.data_for(torch.bool, device, size=size) def fn(input_v, mask): return torch.masked_fill(input_v, mask, scalar_val) + ref = fn(input_v, mask) try: t = torch.jit.trace(fn, (input_v, mask)) @@ -1288,16 +1441,21 @@ def fn(input_v, mask): self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)]) # noqa: F821 + " ".join( + [ + "Failed:", + str(self_dtype), + op.__name__, # noqa: F821 + device, + str(size), + ] + ) ) from e def test_isnan(self): x = torch.rand([4]) - x[0] = float('nan') - inputs = [ - x, - torch.tensor([float('nan'), .5]) - ] + x[0] = float("nan") + inputs = [x, torch.tensor([float("nan"), 0.5])] dtypes = [ torch.int8, torch.int16, @@ -1321,7 +1479,7 @@ def test_isnan(self): self.assertLastGraphAllFused() except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(dtype), 'isnan', device]) + " ".join(["Failed:", str(dtype), "isnan", device]) ) from e def test_gelu(self): @@ -1332,7 +1490,9 @@ def apply(fn): F.gelu, ] sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product( + self.dtypes, unary_ops, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1357,6 +1517,7 @@ def apply(fn): def test_unary_ops(self): with torch._jit_internal._disable_emit_hooks(): + def apply(fn): return lambda x: fn(x) @@ -1411,7 +1572,9 @@ def apply(fn): ] gpu_only = {torch.erf, torch.erfc} sizes = [(1,), (2,), (4, 4)] - for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): + for dtype, op, device, size in product( + self.dtypes, unary_ops, self.devices, sizes + ): # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1435,7 +1598,9 @@ def apply(fn): self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( - " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) + " ".join( + ["Failed:", str(dtype), op.__name__, device, str(size)] + ) ) from e def test_binary_ops(self): @@ -1494,6 +1659,7 @@ def apply(fn): def test_binary_scalar_ops(self): def apply(fn): return lambda x, y: fn(x, y) + ir_template = """ graph(%x : {dtype_x}, %y : {dtype_y}): %z = {op}(%x, %y) @@ -1516,10 +1682,12 @@ def apply(fn): "aten::__lshift__", "aten::__rshift__", ] - dtypes = ['int', 'float', 'bool'] - values = {'int' : [10, 3], 'float' : [12.34, 2.78], 'bool' : [True, False]} + dtypes = ["int", "float", "bool"] + values = {"int": [10, 3], "float": 
[12.34, 2.78], "bool": [True, False]} devices = self.devices - for dtype_x, dtype_y, op, device in product(dtypes, dtypes, binary_ops, devices): + for dtype_x, dtype_y, op, device in product( + dtypes, dtypes, binary_ops, devices + ): code = ir_template.format(**locals()) # Interpret the graph @@ -1535,7 +1703,9 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e + raise RuntimeError( + " ".join(["Compilation failed:", device, str(code)]) + ) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1544,7 +1714,11 @@ def apply(fn): res = k.run((x, y)) self.assertEqual(ref, res) except Exception as e: - raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e + raise RuntimeError( + " ".join( + ["Failed at runtime:", device, str(x), str(y), str(code)] + ) + ) from e def test_matmul(self): if self.dynamic_shapes: @@ -1553,31 +1727,33 @@ def test_matmul(self): def fn(x, y): return torch.matmul(x, y) - devices = ['cpu'] # No cuda support for ext calls yet - sizes = [[[128, 128], [128, 128]], - [[10, 10], [10, 10]], - [[1, 16], [16, 128]], - [[128], [128]], - [[128], [128, 128]], - [[3], [3]], - [[3, 4], [4]], - [[10, 3, 4], [4]], - [[10, 3, 4], [10, 4, 5]], - [[10, 3, 4], [4, 5]], - ] + devices = ["cpu"] # No cuda support for ext calls yet + sizes = [ + [[128, 128], [128, 128]], + [[10, 10], [10, 10]], + [[1, 16], [16, 128]], + [[128], [128]], + [[128], [128, 128]], + [[3], [3]], + [[3, 4], [4]], + [[10, 3, 4], [4]], + [[10, 3, 4], [10, 4, 5]], + [[10, 3, 4], [4, 5]], + ] # Only 2D x 2D matrix multiply is supported. For non-supported sizes we # still want to run results verification to test that we didn't # accidentally fuse it, but we skip the 'is-fused' check. 
# TODO: add support for other shape combinations and make this set empty: - skip_is_fused_check_sizes = ["[[128], [128]]", - "[[128], [128, 128]]", - "[[3], [3]]", - "[[3, 4], [4]]", - "[[10, 3, 4], [4]]", - "[[10, 3, 4], [10, 4, 5]]", - "[[10, 3, 4], [4, 5]]", - ] + skip_is_fused_check_sizes = [ + "[[128], [128]]", + "[[128], [128, 128]]", + "[[3], [3]]", + "[[3, 4], [4]]", + "[[10, 3, 4], [4]]", + "[[10, 3, 4], [10, 4, 5]]", + "[[10, 3, 4], [4, 5]]", + ] for dtype, size, device in product(self.dtypes, sizes, devices): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue @@ -1598,12 +1774,11 @@ def fn(x, y): if str(size) not in skip_is_fused_check_sizes: self.assertAllFused(t.graph_for(x, y)) except Exception as e: - raise RuntimeError( - " ".join(["Failed:", str(dtype), device]) - ) from e + raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e def test_binary_tensor_scalar_ops(self): with torch._jit_internal._disable_emit_hooks(): + def apply_with_scalar(fn, scalar): return lambda x: fn(x, scalar) @@ -1625,7 +1800,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product( + self.dtypes, binary_ops, devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1659,7 +1836,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 - for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + for dtype, op, device, scalar in product( + self.dtypes, binary_ops, devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1696,7 +1875,9 @@ def apply_with_scalar(fn, scalar): # Maybe we should split this into separate tests to speed it up by # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] - for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars): + for dtype, op, device, scalar in product( + dtypes, binary_ops, self.devices, scalars + ): if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue try: @@ -1780,8 +1961,9 @@ def apply(fn): " ".join(["Failed:", str(dtype), op.__name__, device]) ) from e - - @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure") + @unittest.skip( + "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure" + ) def test_list_ops(self): def apply(fn): return lambda x, y, z: fn([x * x, y * y, z * z]) @@ -1848,6 +2030,7 @@ def apply(fn): def test_unsupported_dtypes(self): for device in self.devices: + def fn(x): return x * x + x @@ -1904,10 +2087,13 @@ def eager(t0, t1, t2, t3, t4): for pair in zip(script(*inputs), eager(*inputs)): test, ref = pair torch.testing.assert_close(test, ref) - self.assertAllFused(script.graph_for(*inputs), except_for={"prim::TupleConstruct"}) + self.assertAllFused( + script.graph_for(*inputs), except_for={"prim::TupleConstruct"} + ) def test_sub_gt_and(self): for device in self.devices: + def eager(t1, t2, t3, t4, t: float): w = t1 - t2 h = t3 - t4 @@ -1920,6 +2106,7 @@ def eager(t1, t2, t3, t4, t: float): # careful not to create a fusion group 
containing it. return k + 1 return w + t = torch.rand(8, dtype=torch.float, device=device) scripted = self.checkScript(eager, (t, t, t, t, 0.1)) @@ -1929,20 +2116,24 @@ def test_chunk_mul_one(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def eager(x): z, y, w = torch.chunk(x, 3, -1) return z * 3, y, w + x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) z, y, w = eager(x) script = self.checkScript(eager, (x,)) def test_eq_unsqueeze_type_as(self): for device in self.devices: + def eager(a, b): mask = b == 1 mask = torch.unsqueeze(mask, -1) x = mask.type_as(a) return x, mask + a = torch.rand(1, 64, 1024, device=device, dtype=torch.float) b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long) script = self.checkScript(eager, (a, b)) @@ -1995,33 +2186,40 @@ def eager(input, weight, bias): bias = torch.rand((64), dtype=torch.float) script = self.checkScript(eager, (input, weight, bias)) - FileCheck().check_not("TensorExpr").run(torch.jit.last_executed_optimized_graph()) + FileCheck().check_not("TensorExpr").run( + torch.jit.last_executed_optimized_graph() + ) def test_type_as_cat(self): with inline_fusion_groups(): + def eager(x, y): return torch.cat((x, y.type_as(x)), dim=1) + dtypes = self.dtypes.copy() # CPU fuser doesn't support float16. dtypes.remove(torch.float16) dtypes.remove(torch.bfloat16) for dtype1, dtype2 in product(dtypes, dtypes): - x = torch.randint(2, (1, 13,)).to(dtype1) + x = torch.randint( + 2, + ( + 1, + 13, + ), + ).to(dtype1) zero = torch.tensor([[0]]).to(dtype2) one = torch.tensor([[1]]).to(dtype2) script = torch.jit.trace(eager, (x, zero)) for _ in range(3): - torch.testing.assert_close( - script(x, zero), - eager(x, zero)) - torch.testing.assert_close( - script(x, one), - eager(x, one)) + torch.testing.assert_close(script(x, zero), eager(x, zero)) + torch.testing.assert_close(script(x, one), eager(x, one)) self.assertAllFused(script.graph_for(x, one)) def test_to_device(self): def eager(x): return x.to(device="cpu").relu() + x = torch.rand(8) script = self.checkScript(eager, (x,)) self.assertAllFused(script.graph_for(x)) @@ -2029,7 +2227,10 @@ def eager(x): def test_dims(self): def eager(x, y): return x / (y + 0.0001) - x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided((1, 1, 768), (768, 1, 1)) + + x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided( + (1, 1, 768), (768, 1, 1) + ) y = torch.tensor([[[2.0]]], dtype=torch.float32) script = self.checkScript(eager, (x, y)) self.assertAllFused(script.graph_for(x, y)) @@ -2062,6 +2263,7 @@ def eager(x, y): def test_exhaust_specializations(self): with texpr_enable_strategy([("STATIC", 1)]): + @torch.jit.script def foo(x): return x + x + x @@ -2080,6 +2282,7 @@ def foo(x): def test_unsqueeze_var_dim(self): def eager(x, y, z: int): return x * torch.unsqueeze(y, dim=z) + x = torch.rand(4, 4, 64).permute(1, 0, 2) y = torch.rand(4, 4) z = 2 @@ -2107,34 +2310,43 @@ def _test_fwd_bwd(self, fn): def test_relu_fwd_bwd(self): def eager(x): return torch.relu(x * 1.01) + self._test_fwd_bwd(eager) def test_hardswish_fwd_bwd(self): def eager(x): return F.hardswish(x) * 1.01 + self._test_fwd_bwd(eager) def test_hardsigmoid_fwd_bwd(self): def eager(x): return F.hardsigmoid(x) * 1.01 + self._test_fwd_bwd(eager) def test_cat_graph_opt(self): def foo(x, y, z): return torch.log(torch.cat([x, y, z])) - self.checkScript(foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))) + self.checkScript( + foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 
5])) + ) # TODO: not sure why not updated graph isn't reflected in last_optimized_graph self.assertLastGraphAllFused() def test_dynamic_cat(self): with inline_fusion_groups(): + @torch.jit.script - def repro(xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]): + def repro( + xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor] + ): return [ torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) for x, y, z in zip(xs, ys, zs) ] + for _ in range(3): N = 3 xs = [torch.ones(21) for _ in range(N)] @@ -2153,8 +2365,10 @@ def eager(b: float): def test_cat_2k_args(self): with inline_fusion_groups(): + def eager(x): return torch.relu(torch.cat([x for _ in range(2000)])) + x = torch.randn(1) trace = self.checkTrace(eager, (x,)) fusion_groups = self.findFusionGroups(trace.graph_for(x)) @@ -2164,6 +2378,7 @@ def test_adaptive_avg_pool2d(self): # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this # test should be moved there with inline_fusion_groups(): + def foo1(x): return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)) @@ -2179,11 +2394,13 @@ def foo2(x): def test_unrolled_cat(self): with inline_fusion_groups(): + def eager(x): ret = torch.empty(0) for i in range(x.shape[0]): ret = torch.cat([ret, x[i].relu()]) return ret + script = torch.jit.script(eager) # Warm up with size=1 tensor; since the loop iterates once the @@ -2260,6 +2477,7 @@ def foo(x): def test_dynamic_shapes(self): from functools import partial + n = 10 gen_tensor = ( @@ -2272,6 +2490,7 @@ def test_dynamic_shapes(self): ) with texpr_enable_strategy([("DYNAMIC", 20)]): + def foo(x, y, z): return torch.sigmoid(torch.tanh(x)) @@ -2311,7 +2530,9 @@ def fum(x, y, z): torch._C._jit_pass_dce(g) # We should see only one optimized kernel - FileCheck().check_count("TensorExprDynamicGuard", 1, exactly=True).run(g) + FileCheck().check_count( + "TensorExprDynamicGuard", 1, exactly=True + ).run(g) self.assertEqual(func(*inps), func_s(*inps)) gen = gen_tensor[0] @@ -2327,7 +2548,9 @@ def fum(x, y, z): g = torch.jit.last_executed_optimized_graph() torch._C._jit_pass_inline(g) torch._C._jit_pass_dce(g) - FileCheck().check_count("TensorExprDynamicGuard", len(gen_tensor), exactly=True).run(g) + FileCheck().check_count( + "TensorExprDynamicGuard", len(gen_tensor), exactly=True + ).run(g) @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") def test_autocast_up(self): @@ -2382,7 +2605,6 @@ def f(x): self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3) def test_with_strict_fusion(self): - def success(x): with torch.jit.strict_fusion(): return x + x + x @@ -2445,6 +2667,7 @@ def test_constant_chunk_shapes(self): self.skipTest("TODO: chunk dynamic shapes") for device in self.devices: + def f(x, y): r = torch.tensor(4) z1, z2 = (x + y + r).chunk(2, dim=1) @@ -2474,10 +2697,10 @@ def test_pow_multiple_dtype(self): # https://github.com/pytorch/pytorch/issues/75476 def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: p = torch.sigmoid(p) - result = p ** gamma + result = p**gamma return result - x = torch.rand((2, 2), dtype=torch.half, device='cuda') + x = torch.rand((2, 2), dtype=torch.half, device="cuda") ref = fn(x) @@ -2491,138 +2714,140 @@ def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: class TestTEFuserStatic(TestTEFuser): dynamic_shapes = False + class TestTEFuserDynamic(TestTEFuser): dynamic_shapes = True + del TestTEFuser works_list = [ - '__radd__', - '__rdiv__', - '__rmul__', - '__rmod__', - 'abs', - 'acos', - 'add', - 'addcmul', - 
'addmm.decomposed', - 'asin', - 'atan', - 'atan2', - 'ceil', - 'clamp', - 'clamp.scalar', - 'contiguous', - 'cos', - 'cosh', - 'div.no_rounding_mode', - 'div.true_rounding', - 'div.floor_rounding', - 'div.trunc_rounding', - 'eq', - 'erf', - 'erfc', - 'exp', - 'expand', - 'expand_as', - 'expm1', - 'floor', - 'fmod', - 'fmod.autodiffed', - 'ge', - 'gt', - 'isnan', - 'le', - 'lerp', - 'lgamma', - 'log', - 'log10', - 'log1p', - 'log2', - 'lt', - 'masked_fill', - 'max.binary', - 'mean', - 'min.binary', - 'mm', - 'mul', - 'ne', - 'neg', - 'nn.functional.hardshrink', - 'nn.functional.hardsigmoid', - 'nn.functional.hardswish', - 'nn.functional.softplus', - 'nn.functional.hardtanh', - 'nn.functional.leaky_relu', - 'nn.functional.relu', - 'nn.functional.relu6', - 'nn.functional.softsign', - 'nn.functional.tanhshrink', - 'nn.functional.threshold', - 'permute', - 'pow', - 'reciprocal', - 'remainder', - 'remainder.autodiffed', - 'reshape', - 'reshape_as', - 'round', - 'rsub', - 'rsub.rsub_tensor', - 'rsqrt', - 'sigmoid', - 'sign', - 'sin', - 'sinh', - 'sqrt', - 'sub', - 'sum', - 't', - 'tan', - 'tanh', - 'transpose', - 'true_divide', - 'trunc', - 'unsqueeze', - 'view', - 'view_as', - 'where', - 'bool', - 'byte', - 'char', - 'double', - 'float', - 'half', - 'int', - 'long', - 'short', - 'bool.channels_last', - 'byte.channels_last', - 'char.channels_last', - 'double.channels_last', - 'float.channels_last', - 'half.channels_last', - 'int.channels_last', - 'long.channels_last', - 'short.channels_last', + "__radd__", + "__rdiv__", + "__rmul__", + "__rmod__", + "abs", + "acos", + "add", + "addcmul", + "addmm.decomposed", + "asin", + "atan", + "atan2", + "ceil", + "clamp", + "clamp.scalar", + "contiguous", + "cos", + "cosh", + "div.no_rounding_mode", + "div.true_rounding", + "div.floor_rounding", + "div.trunc_rounding", + "eq", + "erf", + "erfc", + "exp", + "expand", + "expand_as", + "expm1", + "floor", + "fmod", + "fmod.autodiffed", + "ge", + "gt", + "isnan", + "le", + "lerp", + "lgamma", + "log", + "log10", + "log1p", + "log2", + "lt", + "masked_fill", + "max.binary", + "mean", + "min.binary", + "mm", + "mul", + "ne", + "neg", + "nn.functional.hardshrink", + "nn.functional.hardsigmoid", + "nn.functional.hardswish", + "nn.functional.softplus", + "nn.functional.hardtanh", + "nn.functional.leaky_relu", + "nn.functional.relu", + "nn.functional.relu6", + "nn.functional.softsign", + "nn.functional.tanhshrink", + "nn.functional.threshold", + "permute", + "pow", + "reciprocal", + "remainder", + "remainder.autodiffed", + "reshape", + "reshape_as", + "round", + "rsub", + "rsub.rsub_tensor", + "rsqrt", + "sigmoid", + "sign", + "sin", + "sinh", + "sqrt", + "sub", + "sum", + "t", + "tan", + "tanh", + "transpose", + "true_divide", + "trunc", + "unsqueeze", + "view", + "view_as", + "where", + "bool", + "byte", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + "bool.channels_last", + "byte.channels_last", + "char.channels_last", + "double.channels_last", + "float.channels_last", + "half.channels_last", + "int.channels_last", + "long.channels_last", + "short.channels_last", ] known_failures = [ - '__rmatmul__', - 'frac', - 'matmul', + "__rmatmul__", + "frac", + "matmul", ] # If your OpInfo test causes this test to fail, add it here -skip_ops = [ - 'conj' -] +skip_ops = ["conj"] + def get_name(op): l = [op.name] - if op.variant_test_name != '': + if op.variant_test_name != "": l.append(op.variant_test_name) - return '.'.join(l) + return ".".join(l) + # Purpose of this class is to allow super() calls. 
# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works. @@ -2631,6 +2856,7 @@ def get_name(op): class TestNNCOpInfoParent(JitCommonTestCase): pass + class TestNNCOpInfo(TestNNCOpInfoParent): def setUp(self): super(TestNNCOpInfoParent, self).setUp() @@ -2656,23 +2882,23 @@ def te_compile(self, device, dtype, op): param_values.append(v) fx_args.append(param_names[-1]) else: - fx_args.append(f'{repr(v)}') + fx_args.append(f"{repr(v)}") for k, v in kwarg_values.items(): if isinstance(v, torch.Tensor): param_names.append(k) param_values.append(v) - fx_args.append(f'{k} = {k}') + fx_args.append(f"{k} = {k}") else: - fx_args.append(f'{k} = {repr(v)}') + fx_args.append(f"{k} = {repr(v)}") code = f""" def f({', '.join(param_names)}): return op.op({', '.join(fx_args)})""" - g = {'torch': torch, 'inf' : math.inf, 'op': op} + g = {"torch": torch, "inf": math.inf, "op": op} exec(code, g) - f = g['f'] - f.__module__ = 'test' + f = g["f"] + f.__module__ = "test" out = f(*param_values) ts_g = torch.jit.trace(f, param_values) @@ -2683,35 +2909,48 @@ def f({', '.join(param_names)}): @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) in works_list], + allowed_dtypes=(torch.float,), + ) def test_working(self, device, dtype, op): self.te_compile(device, dtype, op) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) in known_failures], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) in known_failures], + allowed_dtypes=(torch.float,), + ) def test_failures(self, device, dtype, op): try: self.te_compile(device, dtype, op) except Exception as e: pass else: - raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + raise RuntimeError( + "Expected test to fail. If it now works, move op into works_list" + ) @onlyCPU @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") - @ops([op for op in op_db if get_name(op) not in works_list + known_failures], allowed_dtypes=(torch.float,)) + @ops( + [op for op in op_db if get_name(op) not in works_list + known_failures], + allowed_dtypes=(torch.float,), + ) def test_unsupported(self, device, dtype, op): if get_name(op) in skip_ops: return try: with warnings.catch_warnings(): - warnings.simplefilter('ignore', TracerWarning) # noqa: F821 + warnings.simplefilter("ignore", TracerWarning) # noqa: F821 self.te_compile(device, dtype, op) except Exception as e: pass else: - raise RuntimeError("Expected test to fail. If it now works, move op into works_list") + raise RuntimeError( + "Expected test to fail. 
If it now works, move op into works_list" + ) @slowTest @onlyCPU @@ -2725,10 +2964,14 @@ def test_nnc_correctness(self, device, dtype, op): for variant, sample in variant_sample_pairs: trace = create_traced_fn(self, variant, cache_traced_fn=True) - ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + ref = variant( + *clone_inputs((sample.input, *sample.args)), **sample.kwargs + ) trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) - val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + val = trace( + *clone_inputs((sample.input, *sample.args)), **sample.kwargs + ) atol = 2e-1 if dtype == torch.bfloat16 else 1e-5 rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5 @@ -2740,14 +2983,17 @@ def test_nnc_correctness(self, device, dtype, op): # if the CU is not cleared. torch.jit._state._python_cu.drop_all_functions() + # CPU fuser not currently used in fbcode only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda") instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) + # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent) class TestLoopnestRandomizationParent(JitTestCase): pass + class TestLoopnestRandomization(TestLoopnestRandomizationParent): def setUp(self): super(TestLoopnestRandomizationParent, self).setUp() @@ -2812,5 +3058,5 @@ def fn_test_relu(x, y): instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu")) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_linalg.py b/test/test_linalg.py index e6fb5fdca250..207290f5a6a8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -963,9 +963,9 @@ def test_eigh_errors_and_warnings(self, device, dtype): # eigh requires 'uplo' parameter to be 'U' or 'L' t = torch.randn(3, 3, device=device, dtype=dtype) for uplo in ["a", "wrong"]: - with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"): torch.linalg.eigh(t, UPLO=uplo) - with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"): np.linalg.eigh(t.cpu().numpy(), UPLO=uplo) # if non-empty out tensor with wrong shape is passed a warning is given @@ -1062,9 +1062,9 @@ def test_eigvalsh_errors_and_warnings(self, device, dtype): # eigvalsh requires 'uplo' parameter to be 'U' or 'L' t = torch.randn(3, 3, device=device, dtype=dtype) for uplo in ["a", "wrong"]: - with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"): torch.linalg.eigvalsh(t, UPLO=uplo) - with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"): np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo) # if non-empty out tensor with wrong shape is passed a warning is given @@ -2416,7 +2416,7 @@ def test_nuclear_norm_exceptions_old(self, device): @skipCUDAIfNoCusolver @skipCPUIfNoLapack - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) def test_svd_lowrank(self, device, dtype): from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix @@ -2439,14 +2439,12 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option # check if u, s, v is a SVD u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.mT) + A = (u * s.unsqueeze(-2)).matmul(v.mH) self.assertEqual(A, a, rtol=1e-7, atol=2e-7) - # check if svd_lowrank 
produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) + # check if svd_lowrank produces same singular values as linalg.svdvals + U, S, Vh = torch.linalg.svd(a, full_matrices=False) + V = Vh.mH self.assertEqual(s, S) if density == 1: @@ -2454,10 +2452,11 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option # # check if pairs (u, U) and (v, V) span the same # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.mT.matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.mT.matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + u, v = u[..., :actual_rank], v[..., :actual_rank] + U, V = U[..., :actual_rank], V[..., :actual_rank] + expected_ones = u.mH.matmul(U).det().abs() + self.assertEqual(expected_ones, torch.ones_like(expected_ones)) + self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones)) all_batches = [(), (1,), (3,), (2, 3)] for actual_rank, size, all_batches in [ # noqa: B020 @@ -4480,6 +4479,72 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype): y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) + @onlyCUDA + @dtypes(*floating_types_and(torch.half)) + def test_matmul_small_brute_force_tunableop(self, device, dtype): + # disable tunableop buffer rotation for all tests everywhere, it can be slow + import os + os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0" + assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off by default" + assert torch.cuda.tunable.tuning_is_enabled(), "TunableOp's tuning should be enabled by default" + torch.cuda.tunable.tuning_enable(False) + assert torch.cuda.tunable.tuning_is_enabled() is False + torch.cuda.tunable.tuning_enable(True) + assert torch.cuda.tunable.tuning_is_enabled() + assert torch.cuda.tunable.get_max_tuning_duration() == 30 + assert torch.cuda.tunable.get_max_tuning_iterations() == 100 + + torch.cuda.tunable.enable() + # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_max_tuning_duration(1) + torch.cuda.tunable.set_max_tuning_iterations(1) + + make_arg = partial(make_tensor, device=device, dtype=dtype) + + for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + filename1 = torch.cuda.tunable.get_filename() + filename2 = "tunableop_results_tmp1.csv" + filename3 = "tunableop_results_tmp2.csv" + ordinal = torch.cuda.current_device() + assert filename1 == f"tunableop_results{ordinal}.csv" + assert len(torch.cuda.tunable.get_validators()) > 0 + assert len(torch.cuda.tunable.get_results()) > 0 + + assert torch.cuda.tunable.write_file() # use default filename + assert torch.cuda.tunable.write_file(filename2) # use custom, one-time filename + torch.cuda.tunable.set_filename(filename3) + assert torch.cuda.tunable.write_file() # use previously set filename + assert torch.cuda.tunable.read_file() # use previously set filename, will ignore duplicates and return True + + with open(filename1) as file1: + file1_contents = file1.read() + with open(filename2) as file2: + file2_contents = file2.read() + 
with open(filename3) as file3: + file3_contents = file3.read() + assert file1_contents == file2_contents + assert file1_contents == file3_contents + + # remove the files created above to avoid the error 'Build left local git repository checkout dirty'; ignore errors + for filename in [filename1, filename2, filename3]: + try: + import os + os.remove(filename) + except OSError: + pass + + # disables TunableOp, no file will be written, restore to default values + torch.cuda.tunable.enable(False) + torch.cuda.tunable.set_filename(filename1) # reset back to default filename for next unit test + torch.cuda.tunable.set_max_tuning_duration(30) + torch.cuda.tunable.set_max_tuning_iterations(100) + assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting" + assert torch.cuda.tunable.get_max_tuning_iterations() == 100 + @dtypes(torch.float, torch.complex64) def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0) diff --git a/test/test_meta.py b/test/test_meta.py index ebd91e71c29f..a5368fbfaee7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1724,6 +1724,11 @@ def f(): out = f() self.assertEqual(out.shape, [10, 16]) + def test_local_scalar_dense_call(self): + with self.assertRaisesRegex(RuntimeError, "cannot be called on meta tensors"): + meta_tensor = torch.randn(1, device='meta') + meta_tensor.item() + instantiate_device_type_tests(TestMeta, globals()) def print_op_str_if_not_supported(op_str): diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index e672d69ab5dd..28113d0bdf08 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -149,7 +149,7 @@ def forward(self, x): bn_scripted_module.eval() self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11) - FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ + FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ .run(str(get_forward(bn_scripted_module._c).graph)) optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} @@ -250,7 +250,7 @@ def foo(self, x): bn_no_forward_scripted_module.eval() self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11) - FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ + FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ .run(bn_no_forward_scripted_module.foo.graph) bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo']) @@ -471,7 +471,7 @@ def _quant_script_and_optimize(model): # basic case m, m_optim = _quant_script_and_optimize(Standalone()) - FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ + FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) @@ -485,7 +485,7 @@ def _quant_script_and_optimize(model): # generic case m, m_optim = _quant_script_and_optimize(Parent()) - FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \ + FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ .run(m_optim.graph) self.assertFalse(hasattr(m_optim, "conv1")) diff --git a/test/test_modules.py
b/test/test_modules.py index ab05e9df4355..e854eec8add7 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -863,7 +863,8 @@ def test_errors(self, device, dtype, module_info, training): else: raise NotImplementedError(f"Unknown error type {error_input.error_on}") - @modules([module for module in module_db if not module.is_lazy]) + # Only run this test for float32 because the test loops over all the dtypes + @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) @parametrize('swap', [True, False]) @parametrize('set_grad', [True, False]) @wrapSwapTensorsTest() @@ -879,6 +880,7 @@ def test_to(self, device, dtype, module_info, training, swap, set_grad): for module_input in module_inputs: c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs m = module_cls(*c_args, **c_kwargs) @@ -896,6 +898,17 @@ def _to(m, set_grad=False): setattr(m, n, new_b) _to(m, set_grad=set_grad) + # Check .to() can be run after forward and backward with swap + has_params = len(list(m.parameters())) > 0 + if swap and not set_grad and has_params: + out = m(*args, **kwargs) + if isinstance(out, tuple): + out = out[0] + out.sum().backward() + m.to(dtype=torch.half) + # reset + m.to(dtype=torch.float32) + prev_device, prev_dtype = device, dtype for device_, dtype_ in product(devices, dtypes): # if device/dtype do not change, grad.to(device, dtype) is a no-op so @@ -903,6 +916,7 @@ def _to(m, set_grad=False): # parameters will be wrapped in an nn.Parameter before swapping # which will cause the ._cdata to change g_no_swap = device_ == prev_device and dtype_ == prev_dtype + prev_prev_device, prev_prev_dtype = prev_device, prev_dtype prev_device, prev_dtype = device_, dtype_ p_ids_before = [id(p) for p in m.parameters()] @@ -940,7 +954,6 @@ def _to(m, set_grad=False): self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after))) self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after))) - @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) @parametrize('swap', [True, False]) @wrapSwapTensorsTest() diff --git a/test/test_mps.py b/test/test_mps.py index 8c3bbf4b7bcf..c59a598facc4 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -193,6 +193,9 @@ def mps_ops_grad_modifier(ops): # Failures due to lack of implementation of downstream functions on MPS backend # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented 'linalg.matrix_rank': None, + + # Exception: Caused by sample input at index 3 on MPS + 'nn.functional.conv3d': [torch.float32], } def addDecorator(op, d) -> None: @@ -240,6 +243,7 @@ def mps_ops_modifier(ops): '__getitem__', 'abs', 'add', + 'alias_copy', 'argwhere', 'atleast_1d', 'atleast_2d', @@ -667,6 +671,11 @@ def mps_ops_modifier(ops): 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], } + MACOS_BEFORE_14_4_XFAILLIST = { + # These ops work fine in 14.4 but fail in 14.2 or 13.x + 'fft.hfft2': [torch.complex64], + } + # Those ops are not expected to work UNIMPLEMENTED_XFAILLIST = { # Failures due to lack of op implementation on MPS backend @@ -1020,6 +1029,9 @@ def mps_ops_modifier(ops): # Unsupported # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible 'nn.functional.avg_pool2d': [torch.float16], + + # This doesn't work on M1, but is partially working on M2 with the 
exception of torch.float16 + 'nn.functional.conv3d': None, } def addDecorator(op, d) -> None: @@ -1040,6 +1052,11 @@ def addDecorator(op, d) -> None: unittest.expectedFailure, dtypes=xfaillist[key])) + if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4): + addDecorator(op, DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_14_4_XFAILLIST[key])) + if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3): addDecorator(op, DecorateInfo( unittest.expectedFailure, @@ -3261,6 +3278,36 @@ def helper(shape, value): helper((2, 8, 4, 5), 0.2) helper((2, 3, 4, 5), 1.0) # value of 1 should be ignored internally + def test_addcdiv_transpose(self): + # Regression test for issue https://github.com/pytorch/pytorch/issues/118115 + # Testing contiguity of all input tensors + + def helper(shape, value): + shape_t = shape[::-1] + for i in range(2): + for j in range(2): + for k in range(2): + x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t() + y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t() + z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t() + + x_mps = x.detach().clone().to(device="mps") + y_mps = y.detach().clone().to(device="mps") + z_mps = z.detach().clone().to(device="mps") + + result_cpu = x.addcdiv_(y, z, value=value) + result_mps = x_mps.addcdiv(y_mps, z_mps, value=value) + result_mps_out = result_cpu.detach().clone().to('mps') + torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value) + + self.assertEqual(result_cpu, result_mps) + self.assertEqual(result_cpu, result_mps_out) + + helper((2, 3), 1.0) + helper((2, 3), 0.2) + helper((100, 300), 1.0) + helper((100, 300), 0.2) + def test_buffer_size_match(self): # this test shouldn't cause any crash size = 16 @@ -7846,6 +7893,11 @@ def test_mps_allocator_module(self): self.assertTrue(current_alloc_after > current_alloc_before) self.assertTrue(driver_alloc_after > driver_alloc_before) + def test_mps_allocator_stats(self): + max_memory = torch.mps.recommended_max_memory() + print(f"Recommended Max Memory : {max_memory/ 1024 ** 3} GB") + self.assertTrue(max_memory > 0) + # to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool, # press record, then run this python test, and press stop.
Next expand # the os_signposts->PyTorchMPS and check if events or intervals are logged diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 597180129f72..5524658b0123 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -2,18 +2,31 @@ import io import itertools +import math import sys -from typing import Optional, Tuple import unittest from functools import partial -import math +from typing import Optional, Tuple import numpy as np + import torch +import torch._dynamo +import torch._dynamo.testing import torch.nn import torch.nn.functional as F + +from torch.nested._internal.nested_tensor import ( + buffer_from_jagged, + jagged_from_list, + nested_view_from_values_offsets, + NestedTensor, + ViewNestedFromBuffer, +) from torch.testing._internal.common_cuda import ( - SM70OrLater, SM80OrLater, PLATFORM_SUPPORTS_FUSED_ATTENTION, + PLATFORM_SUPPORTS_FUSED_ATTENTION, + SM70OrLater, + SM80OrLater, ) from torch.testing._internal.common_device_type import ( dtypes, @@ -21,10 +34,10 @@ instantiate_device_type_tests, onlyCPU, onlyCUDA, + PYTORCH_CUDA_MEMCHECK, skipCUDAIf, skipCUDAIfRocm, skipMeta, - PYTORCH_CUDA_MEMCHECK, ) from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_utils import ( @@ -34,23 +47,15 @@ instantiate_parametrized_tests, IS_FBCODE, IS_WINDOWS, + markDynamoStrictTest, parametrize, run_tests, skipIfSlowGradcheckEnv, skipIfTorchDynamo, - markDynamoStrictTest, - xfailIfTorchDynamo, subtest, TEST_WITH_ROCM, TestCase, -) - -from torch.nested._internal.nested_tensor import ( - buffer_from_jagged, - jagged_from_list, - NestedTensor, - nested_view_from_values_offsets, - ViewNestedFromBuffer, + xfailIfTorchDynamo, ) # Tests are ported from pytorch/nestedtensor. @@ -61,6 +66,7 @@ def _iter_constructors(): # yield as_nested_tensor yield torch.nested.nested_tensor + # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -82,6 +88,7 @@ def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16 nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2) return nt_contiguous, nt_noncontiguous + # Helper functions to pad a noncontiguous nested tensor # can be replaced once to_padded_tensor supports noncontiguous memory @@ -108,10 +115,19 @@ def noncontiguous_to_padded_tensor(input, shape=None): view.copy_(tensor) return result + # Helper function to generate a random nested tensor -def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, require_non_empty=True): +def random_nt( + device, + dtype, + num_tensors, + max_dims, + min_dims=None, + layout=torch.strided, + require_non_empty=True, +): if min_dims is None: min_dims = tuple([0] * len(max_dims)) @@ -120,9 +136,9 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. 
assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" assert min_dim >= 0, "random_nt: min_dim must be non-negative" if require_non_empty: - assert not (min_dim == 0 and max_dim == 1), ( - "random_nt: zero cannot be the only possible value if require_non_empty is True" - ) + assert not ( + min_dim == 0 and max_dim == 1 + ), "random_nt: zero cannot be the only possible value if require_non_empty is True" if require_non_empty: # Select a random idx that will be required to be non-empty @@ -135,7 +151,9 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. new_min_dim = min_dim if require_non_empty and i == non_zero_idx and min_dim == 0: new_min_dim = 1 - tensor_dims.append(torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()) + tensor_dims.append( + torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item() + ) t1 = torch.randn(tensor_dims, device=device, dtype=dtype) ts1.append(t1) @@ -145,14 +163,23 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch. # Alternate approach to generating a random NT. # dims should be something like [5, None, 10], with None indicating that a # random ragged structure should be used -def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False): +def random_nt_from_dims( + dims, device=None, dtype=None, layout=torch.strided, requires_grad=False +): sizes = [ - [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]] + [ + d if d is not None else torch.randint(2, 10, size=(1,)).item() + for d in dims[1:] + ] for d in range(dims[0]) ] - return torch.nested.nested_tensor([ - torch.randn(*size) for size in sizes - ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad) + return torch.nested.nested_tensor( + [torch.randn(*size) for size in sizes], + device=device, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + ) # Creates an NT matching another NT's number of components and @@ -174,9 +201,9 @@ def random_nt_from_similar(other, dims=None): ret_size.append(d) ret_sizes.append(ret_size) - return torch.nested.nested_tensor([ - torch.randn(*size) for size in ret_sizes - ], device=other.device) + return torch.nested.nested_tensor( + [torch.randn(*size) for size in ret_sizes], device=other.device + ) # makes naming nice for tests that parametrize over layout. 
@@ -234,8 +261,7 @@ def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.int64) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) ) @parametrize("batch_size", [2, 4]) @@ -257,8 +283,7 @@ def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.int64) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) ) @parametrize("batch_size", [2, 4]) @@ -282,11 +307,9 @@ def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): nested_tensor_list = nested_tensor.unbind() for id in range(batch_size): self.assertEqual( - nested_tensor_list[id], - nested_tensor_ref_list[id].type(torch.float) + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float) ) - @torch.inference_mode() def _test_unbind_case(self, a, b): nt = torch.nested.nested_tensor([a, b]) @@ -306,25 +329,29 @@ def _test_unbind_case(self, a, b): @torch.inference_mode() def test_unbind_0(self): self._test_unbind_case( - torch.tensor([1, 2]), torch.tensor([7, 8]), + torch.tensor([1, 2]), + torch.tensor([7, 8]), ) @torch.inference_mode() def test_unbind_1(self): self._test_unbind_case( - torch.tensor([1]), torch.tensor([7]), + torch.tensor([1]), + torch.tensor([7]), ) @torch.inference_mode() def test_unbind_3(self): self._test_unbind_case( - torch.tensor([1.0]), torch.tensor([]), + torch.tensor([1.0]), + torch.tensor([]), ) @torch.inference_mode() def test_unbind_4(self): self._test_unbind_case( - torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) @torch.inference_mode() @@ -343,7 +370,9 @@ def _test_fn(unbind_fn): @torch.inference_mode() def test_nested_tensor(self): - self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))) + self.assertRaises( + TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])) + ) self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) @torch.inference_mode() @@ -432,18 +461,22 @@ def test_size_dim(self): a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)]) self.assertEqual(a.size(0), 2) - a = torch.nested.nested_tensor([torch.rand(1, 2), - torch.rand(1, 8)]) + a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)]) self.assertEqual(a.size(0), 2) self.assertEqual(a.size(1), 1) self.assertRaisesRegex( - RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2)) + RuntimeError, + "Given dimension 2 is irregular and does not have a size", + lambda: a.size(2), + ) - a = torch.nested.nested_tensor([torch.rand(3, 4), - torch.rand(5, 4)]) + a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)]) self.assertEqual(a.size(0), 2) self.assertRaisesRegex( - RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1)) + RuntimeError, + "Given dimension 1 is irregular and does not have a size", + lambda: a.size(1), + ) self.assertEqual(a.size(2), 4) @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.") @@ -476,8 +509,12 @@ def test_is_contiguous(self): self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) # Test querying by memory_format - self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) - 
self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) @torch.inference_mode() def test_repr_string(self): @@ -497,7 +534,6 @@ def test_repr_string(self): self.assertEqual(repr(a), expected) def test_to_padded_tensor_on_empty_tensor(self): - nt = torch.nested.nested_tensor([]) empty = torch.nested.to_padded_tensor(nt, 4) self.assertEqual(empty, torch.tensor([])) @@ -510,7 +546,7 @@ def test_nested_namespace(self): def test_to(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking)) @@ -518,113 +554,141 @@ def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) + self.assertIsNot( + t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True) + ) devices = [t.device] - if t.device.type == 'cuda': + if t.device.type == "cuda": if t.device.index == -1: - devices.append(f'cuda:{torch.cuda.current_device()}') + devices.append(f"cuda:{torch.cuda.current_device()}") elif t.device.index == torch.cuda.current_device(): - devices.append('cuda') + devices.append("cuda") for device in devices: self.assertIs(t, t.to(device, non_blocking=non_blocking)) self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) + self.assertIsNot( + t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True) + ) test_copy_behavior(nt) - self.assertEqual(nt.device, nt.to('cpu').device) - self.assertEqual(nt.device, nt.to('cpu', dtype=torch.float32).device) - self.assertIs(torch.float32, nt.to('cpu', dtype=torch.float32).dtype) + self.assertEqual(nt.device, nt.to("cpu").device) + self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device) + self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype) self.assertEqual(nt.device, nt.to(torch.float32).device) self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype) def test_data_ptr(getter): - self.assertEqual(getter(nt), getter(nt.to('cpu'))) - self.assertEqual(getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))) - self.assertEqual(getter(nt), getter(nt.to('cpu', copy=False))) - self.assertNotEqual(getter(nt), getter(nt.to('cpu', copy=True))) + self.assertEqual(getter(nt), getter(nt.to("cpu"))) + self.assertEqual( + getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False)) + ) + self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False))) + self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True))) test_data_ptr(lambda nt: nt.data_ptr()) if torch.cuda.is_available(): for non_blocking in [True, False]: - for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: + for cuda in [ + "cuda", + "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1", + ]: nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4)) 
test_copy_behavior(nt2, non_blocking) - self.assertEqual(nt2.device, nt2.to(cuda, non_blocking=non_blocking).device) - self.assertEqual(nt.device, nt2.to('cpu', non_blocking=non_blocking).device) - self.assertEqual(nt2.device, nt.to(cuda, non_blocking=non_blocking).device) - self.assertIs(torch.int32, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) - self.assertEqual(nt.device, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) + self.assertEqual( + nt2.device, nt2.to(cuda, non_blocking=non_blocking).device + ) + self.assertEqual( + nt.device, nt2.to("cpu", non_blocking=non_blocking).device + ) + self.assertEqual( + nt2.device, nt.to(cuda, non_blocking=non_blocking).device + ) + self.assertIs( + torch.int32, + nt2.to( + "cpu", dtype=torch.int32, non_blocking=non_blocking + ).dtype, + ) + self.assertEqual( + nt.device, + nt2.to( + "cpu", dtype=torch.int32, non_blocking=non_blocking + ).device, + ) self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) def test_copy_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) nt_copy = torch.empty_like(nt) nt_copy.copy_(nt) - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])]) self.assertRaisesRegex( RuntimeError, "copy_ only supports tensors that are the same size for Nested implementations", - lambda: nt_error.copy_(nt) + lambda: nt_error.copy_(nt), ) if torch.cuda.is_available(): - nt = random_nt(torch.device('cuda'), torch.float32, ntensors, (4, 4)) - nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4)) + nt_copy = torch.empty_like(nt, device=torch.device("cpu")) nt_copy.copy_(nt, non_blocking=True) torch.cuda.current_stream(torch.cuda.current_device()).synchronize() - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) - nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt_copy = torch.empty_like(nt, device=torch.device("cpu")) nt_copy.copy_(nt, non_blocking=False) - for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) def test_fill_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) - nt.fill_(10.) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) + nt.fill_(10.0) for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(10.) + t.fill_(10.0) self.assertEqual(nt_ub, t) - fill_tensor = torch.tensor([11.]) + fill_tensor = torch.tensor([11.0]) self.assertRaisesRegex( RuntimeError, "fill_ only supports 0-dimension value tensor", - lambda: nt.fill_(fill_tensor) + lambda: nt.fill_(fill_tensor), ) nt.fill_(fill_tensor[0]) for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(11.) + t.fill_(11.0) self.assertEqual(nt_ub, t) def test_zero_(self): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) nt.zero_() for nt_ub in nt.unbind(): t = torch.empty_like(nt_ub) - t.fill_(0.) 
+ t.fill_(0.0) self.assertEqual(nt_ub, t) - @parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", + [torch.ones_like, torch.zeros_like, torch.randn_like], + name_fn=lambda f: f.__name__, + ) def test_like_functions(self, func): ntensors = 4 - nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) torch.manual_seed(1) nt_like = func(nt) @@ -640,7 +704,8 @@ def test_cat(self): y = random_nt_from_dims([3, 4, None]) output = torch.cat([x, y], dim=0) for out_component, xy_component in zip( - output.unbind(), itertools.chain(x.unbind(), y.unbind())): + output.unbind(), itertools.chain(x.unbind(), y.unbind()) + ): self.assertEqual(out_component, xy_component) # dim=-1 success case @@ -650,29 +715,40 @@ def test_cat(self): y = random_nt_from_similar(x, dims=[-1, -1, 8]) # should be shape (B, *, D + D') when supported output = torch.cat([x, y], dim=-1) - for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): - self.assertEqual(out_component, torch.cat([x_component, y_component], dim=-1)) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=-1) + ) # dim between 0 and -1 success case x = random_nt_from_dims([5, None, 2, 3]) # same structure as x but dim=2 differs y = random_nt_from_similar(x, dims=[-1, -1, 4, -1]) output = torch.cat([x, y], dim=2) - for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): - self.assertEqual(out_component, torch.cat([x_component, y_component], dim=1)) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=1) + ) # error case: mixed NT / dense inputs x = random_nt_from_dims([5, None, 2]) y = torch.randn(5, 3, 2) with self.assertRaisesRegex( - RuntimeError, "expected each tensor in given list to be nested"): + RuntimeError, "expected each tensor in given list to be nested" + ): torch.cat([x, y], dim=-1) # error case: NTs with different dims x = random_nt_from_dims([5, None, 2]) y = random_nt_from_dims([5, None, 2, 3]) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) # error case: non-contiguous NT @@ -680,43 +756,56 @@ def test_cat(self): # transpose to put ragged dim next to batch dim x, y = x.transpose(-2, -1), y.transpose(-2, -1) with self.assertRaisesRegex( - RuntimeError, "only contiguous nested tensors are supported"): + RuntimeError, "only contiguous nested tensors are supported" + ): torch.cat([x, y], dim=-1) # error case: multiple ragged dims in inputs x = random_nt_from_dims([5, None, None, 2]) y = random_nt_from_similar(x) with self.assertRaisesRegex( - RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"): + RuntimeError, + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): torch.cat([x, y], dim=-1) # error case: ragged dim not next to batch dim x = random_nt_from_dims([5, 2, None]) y = random_nt_from_similar(x) with self.assertRaisesRegex( - RuntimeError, 
"only nested tensors with a single ragged dim next to the batch dim are supported"): + RuntimeError, + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): torch.cat([x, y], dim=1) # error case: NTs with different batch sizes x = random_nt_from_dims([5, None, 2]) y = random_nt_from_dims([3, None, 2]) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) # error case: NTs with different ragged structures - x = torch.nested.nested_tensor([ - torch.randn(2, 6), - torch.randn(4, 6), - torch.randn(5, 6), - ]) - y = torch.nested.nested_tensor([ - torch.randn(5, 6), - torch.randn(4, 6), - torch.randn(2, 6), - ]) + x = torch.nested.nested_tensor( + [ + torch.randn(2, 6), + torch.randn(4, 6), + torch.randn(5, 6), + ] + ) + y = torch.nested.nested_tensor( + [ + torch.randn(5, 6), + torch.randn(4, 6), + torch.randn(2, 6), + ] + ) with self.assertRaisesRegex( - RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): torch.cat([x, y], dim=-1) @@ -728,13 +817,20 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims): ts1 = [] ts2 = [] for _ in range(num_tensors): - tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims]) + tensor_dims = tuple( + [ + torch.randint(low=0, high=max_dim, size=(1,)).item() + for max_dim in max_dims + ] + ) t1 = torch.randn(tensor_dims, device=device, dtype=dtype) t2 = torch.randn(tensor_dims, device=device, dtype=dtype) ts1.append(t1) ts2.append(t2) - return (torch.nested.nested_tensor(ts1, device=device, dtype=dtype), - torch.nested.nested_tensor(ts2, device=device, dtype=dtype)) + return ( + torch.nested.nested_tensor(ts1, device=device, dtype=dtype), + torch.nested.nested_tensor(ts2, device=device, dtype=dtype), + ) @dtypes(*floating_types_and_half()) def test_detach(self, device, dtype): @@ -766,7 +862,9 @@ def test_detach(self, device, dtype): @dtypes(torch.float, torch.float16, torch.double) def test_unbind_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) ub_contiguous = nt_contiguous.unbind() ub_noncontiguous = nt_noncontiguous.unbind() self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) @@ -785,7 +883,7 @@ def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): nt_to = torch._nested_from_padded_and_nested_example(padded, nt) - for (t1, t2) in zip(nt.unbind(), nt_to.unbind()): + for t1, t2 in zip(nt.unbind(), nt_to.unbind()): self.assertEqual(t1, t2) self.assertEqual(nt.device, nt_to.device) @@ -802,7 +900,7 @@ def _test(size): nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) self.assertEqual(nt_subresult, t_result) @@ -814,28 +912,36 @@ def _test(size): nt = torch.nested.nested_tensor(ts, 
device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) self.assertEqual(nt_subresult, t_result) if size <= 128: # Test with multidimensional tensors after irregular dim # (run only with smaller dimensions to ensure fast execution) - t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t2 = torch.randn( + 7, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) ts = [t0, t1, t2, t0, t2] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) - layer_norm = torch.nn.LayerNorm((size, size, 4), device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm( + (size, size, 4), device=device, dtype=dtype + ) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) self.assertEqual(nt_subresult, t_result) # Test where the normalizing dimensions are not all layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype) nt_result = layer_norm(nt) - for (nt_subresult, t) in zip(nt_result.unbind(), ts): + for nt_subresult, t in zip(nt_result.unbind(), ts): t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) self.assertEqual(nt_subresult, t_result) @@ -848,9 +954,15 @@ def _test(size): @torch.inference_mode() def test_layer_norm_breaking(self, device, dtype): size = 128 - t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) - t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t2 = torch.randn( + 7, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) ts = [t0, t1, t2, t0, t2] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype) @@ -869,7 +981,7 @@ def test_layer_norm_breaking(self, device, dtype): @decorateIf( xfailIfTorchDynamo, # only fails in python 3.11. TODO: Ensure this is fixed once views work! 
- lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) + lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11), ) @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) def test_embedding(self, device, layout): @@ -877,14 +989,15 @@ def test_embedding(self, device, layout): torch.randint(100, (L,), device=device, dtype=torch.int64) for L in torch.randint(5, 50, (8,)) ] - x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64, layout=layout) + x = torch.nested.nested_tensor( + inputs, device=device, dtype=torch.int64, layout=layout + ) emb = torch.nn.Embedding(100, 8, device=device) y = emb(x) ys = y.unbind() for i, inp in enumerate(inputs): self.assertEqual(emb(inp), ys[i]) - @skipMeta @torch.inference_mode() @dtypes(*floating_types_and_half()) @@ -892,11 +1005,12 @@ def test_masked_fill(self, device, dtype): # nested tensor * nested tensor (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4)) mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()]) - ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]) + ref = torch.nested.nested_tensor( + [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())] + ) out = nt.masked_fill(mask, 0) self.assertEqual(ref, out) - @dtypes(torch.float, torch.float16) def test_to_padded_tensor_simple(self, device, dtype): t = torch.randn(4, 4, 4, device=device, dtype=dtype) @@ -924,8 +1038,12 @@ def test_to_padded_tensor_output_size(self, device, dtype): ts[0] = ts[0][:-1] nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) for padding_value in (0, 1): - padded = torch.nested.to_padded_tensor(nt, padding_value, output_size=output_size) - correct_output = torch.ones(output_size, device=device, dtype=dtype) * padding_value + padded = torch.nested.to_padded_tensor( + nt, padding_value, output_size=output_size + ) + correct_output = ( + torch.ones(output_size, device=device, dtype=dtype) * padding_value + ) correct_output[:4:, :4, :4] = t.clone() if padding_value == 0: correct_output[0][3] = torch.zeros_like(correct_output[0][3]) @@ -949,7 +1067,7 @@ def test_to_padded_tensor_dim2(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0)].copy_(t) + next_output[: t.size(0)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -967,7 +1085,7 @@ def test_to_padded_tensor_dim3(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0), :t.size(1)].copy_(t) + next_output[: t.size(0), : t.size(1)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -985,7 +1103,7 @@ def test_to_padded_tensor_dim4(self, device, dtype): for t in ts: next_output = torch.ones_like(ts[2]) * pad correct_output.append(next_output) - next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t) + next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t) correct_output = torch.stack(correct_output) padded = torch.nested.to_padded_tensor(nt, pad) self.assertEqual(padded, correct_output) @@ -997,22 +1115,25 @@ def test_to_padded_tensor_dim4(self, device, dtype): @dtypes(torch.float, torch.float16, torch.double) @torch.inference_mode() def test_to_padded_tensor_noncontiguous(self, device, 
dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) # test noncontiguous_to_padded_tensor functionality self.assertEqual( torch.nested.to_padded_tensor(nt_contiguous, 0.0), - noncontiguous_to_padded_tensor(nt_noncontiguous)) + noncontiguous_to_padded_tensor(nt_noncontiguous), + ) # test to_padded_tensor error message self.assertRaisesRegex( RuntimeError, r"for now to_padded_tensor only supports contiguous nested tensor", - lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0) + lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0), ) @skipMeta def test_device_checks(self, device): nt = torch.nested.nested_tensor([], device=device) - is_cuda = 'cuda' in str(device) + is_cuda = "cuda" in str(device) self.assertEqual(nt.is_cuda, is_cuda) @dtypes(torch.float, torch.float16, torch.double) @@ -1060,26 +1181,35 @@ def test_nested_tensor_indexing(self, device, dtype): self.assertEqual(nt[-1], x1) grad_x0 = torch.randn((2, 5), device=device, dtype=dtype) nt[0].backward(grad_x0) - expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)] + ) self.assertEqual(nt.grad, expected_grad) - @parametrize("func", [subtest(torch.nn.functional.relu, name='relu'), - subtest(torch.nn.functional.relu_, name='relu_'), - subtest(torch.nn.functional.gelu, name='gelu'), - subtest(torch._C._nn.gelu_, name='gelu_'), - subtest(torch.tanh, name='tanh'), - subtest(torch.tanh_, name='tanh_'), - subtest(torch.neg, name='neg'), - subtest(torch.nn.functional.silu, name='silu'), - subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'), - subtest(torch.abs, name="abs"), - subtest(torch.abs_, name="abs_"), - subtest(torch.sgn, name="sgn"), - subtest(torch.logical_not, name='logical_not'), - subtest(torch.sin, name='sin'), - subtest(torch.cos, name='cos')]) + @parametrize( + "func", + [ + subtest(torch.nn.functional.relu, name="relu"), + subtest(torch.nn.functional.relu_, name="relu_"), + subtest(torch.nn.functional.gelu, name="gelu"), + subtest(torch._C._nn.gelu_, name="gelu_"), + subtest(torch.tanh, name="tanh"), + subtest(torch.tanh_, name="tanh_"), + subtest(torch.neg, name="neg"), + subtest(torch.nn.functional.silu, name="silu"), + subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"), + subtest(torch.abs, name="abs"), + subtest(torch.abs_, name="abs_"), + subtest(torch.sgn, name="sgn"), + subtest(torch.logical_not, name="logical_not"), + subtest(torch.sin, name="sin"), + subtest(torch.cos, name="cos"), + ], + ) def test_activations(self, device, func): - nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32) + nt, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) nested_result = func(nt) self.assertTrue(nested_result.is_nested) for t, t_res in zip(nt.unbind(), nested_result.unbind()): @@ -1087,13 +1217,14 @@ def test_activations(self, device, func): self.assertRaisesRegex( RuntimeError, "NestedTensor must be contiguous to get buffer.", - lambda: func(nt_noncontiguous)) + lambda: func(nt_noncontiguous), + ) - @parametrize("func", [subtest(torch.ge, name='ge'), - subtest(torch.eq, name='eq')]) + @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")]) def 
test_binary_ops_with_scalar(self, device, func): nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( - (2, 3, 6, 7), device=device, dtype=torch.float32) + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) scalar = 0.0 # should work regardless of contiguity @@ -1129,30 +1260,42 @@ def test_nested_tensor_chunk(self, device, dtype): # Failure chunking on ragged dimensions self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", - lambda: torch.chunk(nt, 5, dim=1)) + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=1), + ) self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", - lambda: torch.chunk(nt, 5, dim=0)) + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=0), + ) # Failure on non-contiguous nt _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) self.assertRaisesRegex( - RuntimeError, "chunk expects `self` to be contiguous.", lambda: torch.chunk(nt_noncontiguous, 5, dim=-1)) + RuntimeError, + "chunk expects `self` to be contiguous.", + lambda: torch.chunk(nt_noncontiguous, 5, dim=-1), + ) # Failure when calling non divisible n_chunks self.assertRaisesRegex( - RuntimeError, "Chunk for nested tensors is only supported for " + RuntimeError, + "Chunk for nested tensors is only supported for " "nested tensors with trailing dimension divisible by chunks.", - lambda: torch.chunk(nt, 5, dim=-1)) + lambda: torch.chunk(nt, 5, dim=-1), + ) # Failure when calling backward on a chunk a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True) b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True) nt_grad = torch.nested.as_nested_tensor([a, b]) chunked = torch.chunk(nt_grad, 2, dim=-1) - self.assertRaisesRegex(RuntimeError, "derivative for aten::chunk is not implemented", - lambda: chunked[0].backward(chunked[0].clone())) + self.assertRaisesRegex( + RuntimeError, + "derivative for aten::chunk is not implemented", + lambda: chunked[0].backward(chunked[0].clone()), + ) @dtypes(*floating_types_and_half()) def test_nested_tensor_split_with_sizes(self, device, dtype): @@ -1169,42 +1312,56 @@ def test_nested_tensor_split_with_sizes(self, device, dtype): nt_splits = nt.split_with_sizes(split_sizes, dim=-1) for i, nt_split in enumerate(nt_splits): - self.assertEqual(nt_split, torch.nested.nested_tensor( - [a_splits[i], b_splits[i], c_splits[i]])) - dense_strides = torch.stack([ - torch.tensor(a_splits[i].stride()), - torch.tensor(b_splits[i].stride()), - torch.tensor(c_splits[i].stride()) - ]) + self.assertEqual( + nt_split, + torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]), + ) + dense_strides = torch.stack( + [ + torch.tensor(a_splits[i].stride()), + torch.tensor(b_splits[i].stride()), + torch.tensor(c_splits[i].stride()), + ] + ) self.assertEqual(nt_split._nested_tensor_strides(), dense_strides) self.assertFalse(nt_split.is_contiguous()) # Failure calling on ragged dimensions self.assertRaisesRegex( - RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", - lambda: torch.split_with_sizes(nt, split_sizes, dim=1)) + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=1), + ) # 
Failure calling on non-last dimension self.assertRaisesRegex( - RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", - lambda: torch.split_with_sizes(nt, split_sizes, dim=0)) + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=0), + ) # Failure on non-contiguous nt _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) self.assertRaisesRegex( - RuntimeError, "split_with_sizes expects `self` to be contiguous.", - lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1)) + RuntimeError, + "split_with_sizes expects `self` to be contiguous.", + lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1), + ) # Failure when calling with split_sizes that don't cover the full dim size bad_split_sizes = [4, 6, 9] # don't add up to 20 self.assertRaisesRegex( - RuntimeError, "split_with_sizes expects split_sizes to sum exactly to 20", - lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1)) + RuntimeError, + "split_with_sizes expects split_sizes to sum exactly to 20", + lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1), + ) @dtypes(torch.float, torch.float16, torch.double) @torch.inference_mode() def test_nested_tensor_indexing_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) n = nt_contiguous.size(0) for i in range(n): @@ -1224,7 +1381,9 @@ def test_nested_tensor_add(self, device, dtype, transpose): nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) else: (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) out = nt1 + nt2 self.assertEqual(ref, out) @@ -1242,7 +1401,9 @@ def test_nested_tensor_sub(self, device, dtype, transpose): nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) else: (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) out = nt1 - nt2 self.assertEqual(ref, out) @@ -1253,9 +1414,11 @@ def test_nested_tensor_sub(self, device, dtype, transpose): def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): def _test_add_mul(nt, t): ref_add = torch.nested.nested_tensor( - [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) ref_mul = torch.nested.nested_tensor( - [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) self.assertEqual(nt.add(t), ref_add) self.assertEqual(nt.mul(t), ref_mul) @@ -1280,7 +1443,9 @@ def _test_add_mul(nt, t): def test_nested_tensor_mul(self, device, dtype): # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) 
out = nt1 * nt2 self.assertEqual(ref, out) # nested tensor * scalar @@ -1300,12 +1465,12 @@ def test_nested_tensor_mul(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a nested self and non-nested other", - lambda: nt1.mul(vector) + lambda: nt1.mul(vector), ) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a non-nested self and nested other", - lambda: vector.mul(nt1) + lambda: vector.mul(nt1), ) @dtypes(torch.float, torch.float16) @@ -1321,31 +1486,43 @@ def test_nested_tensor_div(self, device, dtype): out = nt.transpose(1, 2) / 4.0 self.assertEqual(ref_transposed, out) - ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())] + ) out = nt / nt2 self.assertEqual(ref, out) out = nt.transpose(1, 2) / nt2.transpose(1, 2) self.assertEqual(ref.transpose(1, 2), out) - nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()]) + nt_transpose_copy = torch.nested.nested_tensor( + [t.transpose(0, 1) for t in nt.unbind()] + ) self.assertRaisesRegex( - RuntimeError, "div requires strides to match when given NestedTensors", - lambda: nt_transpose_copy.transpose(1, 2) / nt2) + RuntimeError, + "div requires strides to match when given NestedTensors", + lambda: nt_transpose_copy.transpose(1, 2) / nt2, + ) - nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype + ) nt_chunks = nt.chunk(2, -1) self.assertRaisesRegex( - RuntimeError, "div requires offsets to match when given NestedTensors", - lambda: nt_chunks[0] / nt_chunks[1]) + RuntimeError, + "div requires offsets to match when given NestedTensors", + lambda: nt_chunks[0] / nt_chunks[1], + ) @dtypes(torch.float, torch.float16) @skipMeta @torch.inference_mode() def test_nested_tensor_add_in_place(self, device, dtype): (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) nt1 += nt2 self.assertEqual(ref, nt1) @@ -1355,7 +1532,9 @@ def test_nested_tensor_add_in_place(self, device, dtype): def test_nested_tensor_mul_in_place(self, device, dtype): # nested tensor * nested tensor (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) - ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) nt1 *= nt2 self.assertEqual(ref, nt1) # nested tensor * scalar @@ -1371,19 +1550,19 @@ def test_nested_tensor_mul_in_place(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", - lambda: scalar.mul_(nt1) + lambda: scalar.mul_(nt1), ) # error case: numel == 1 but dim > 0 vector = torch.tensor([number]).to(dtype).to(device) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a nested self and non-nested other", - lambda: nt1.mul_(vector) + lambda: nt1.mul_(vector), ) self.assertRaisesRegex( RuntimeError, "Expected both self and other to be nested, but got a non-nested self and nested other", - lambda: vector.mul_(nt1) + lambda: 
vector.mul_(nt1), ) @onlyCPU @@ -1419,14 +1598,26 @@ def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True): test_sum(device, dtype, ntensors, max_sizes, len(max_sizes)) # Test error inputs - with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"): - torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor can only be reduced across the last" + ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(0, keepdim=True) - with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"): - torch.nested.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor only allows reduction of a single" + ): + torch.nested.nested_tensor( + [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])] + ).sum([0, 1], keepdim=True) - with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."): - torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1) + with self.assertRaisesRegex( + RuntimeError, "NestedTensor always requires keepdim=True for now." + ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(-1) @dtypes(torch.float, torch.float16) def test_contiguous(self, device, dtype): @@ -1436,8 +1627,12 @@ def test_contiguous(self, device, dtype): # whose numels is now less than the size of the buffer. Clone was # previously creating a new NT with a buffer that was the same size as the # original. - nt_contiguous = torch.nested.nested_tensor([torch.randn(2, 20, device=device, dtype=dtype), - torch.randn(4, 20, device=device, dtype=dtype)]) + nt_contiguous = torch.nested.nested_tensor( + [ + torch.randn(2, 20, device=device, dtype=dtype), + torch.randn(4, 20, device=device, dtype=dtype), + ] + ) # Split up the last dimension which has a consistent size of 20 into 5 chunks chunks = nt_contiguous.chunk(5, dim=-1) @@ -1549,12 +1744,12 @@ def test_softmax(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Cannot apply softmax across nested dimension 0", - lambda: torch.nn.functional.softmax(nt, 0) + lambda: torch.nn.functional.softmax(nt, 0), ) self.assertRaisesRegex( RuntimeError, "Cannot apply softmax across nested dimension 0", - lambda: torch.nn.functional.softmax(nt, -3) + lambda: torch.nn.functional.softmax(nt, -3), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) @@ -1581,91 +1776,95 @@ def test_softmax(self, device, dtype): @dtypes(torch.float, torch.double) @torch.inference_mode() def test_softmax_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) self.assertEqual( torch.nn.functional.softmax(nt_contiguous, -1), - torch.nn.functional.softmax(nt_noncontiguous, -1)) + torch.nn.functional.softmax(nt_noncontiguous, -1), + ) def _test_bmm(self, device, dtype): # error case: not 3D tensors nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) - nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = 
torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt0) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt1) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt0.bmm(nt2) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt0) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt1) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1) ) self.assertRaisesRegex( - RuntimeError, - "batch1 must be a 3D tensor", - lambda: nt1.bmm(nt2) + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2) ) self.assertRaisesRegex( - RuntimeError, - "batch2 must be a 3D tensor", - lambda: nt2.bmm(nt0) + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0) ) self.assertRaisesRegex( - RuntimeError, - "batch2 must be a 3D tensor", - lambda: nt2.bmm(nt1) + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1) ) # error case: incompatible batch size - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), - torch.randn((4, 5)), - torch.randn((4, 7))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) self.assertRaisesRegex( RuntimeError, "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", - lambda: nt0.bmm(nt1) + lambda: nt0.bmm(nt1), ) self.assertRaisesRegex( RuntimeError, "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", - lambda: nt1.bmm(nt0) + lambda: nt1.bmm(nt0), ) # error case: underlying matrices cannot be multiplied - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", - lambda: nt0.bmm(nt0) + lambda: nt0.bmm(nt0), ) # normal nested tensor - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) - expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) if dtype == 
torch.float16: self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) else: self.assertEqual(actual, expect) # nested tensor bmm normal tensor - nt0 = torch.nested.nested_tensor([torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype + ) nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) @@ -1684,10 +1883,11 @@ def _test_bmm(self, device, dtype): else: self.assertEqual(actual, expect) - # normal tensor bmm nested tensor nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0)) if dtype == torch.float16: @@ -1696,10 +1896,16 @@ def _test_bmm(self, device, dtype): self.assertEqual(actual, expect) # test tensorcore path - nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) - expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) if dtype == torch.float16: self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) else: @@ -1719,11 +1925,16 @@ def test_bmm_cpu(self, device, dtype): # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_bmm_noncontiguous(self, device, dtype): - nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) - nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) self.assertEqual( nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), - nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) + nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous), + ) @dtypes(torch.float, torch.double) def test_matmul_with_bmm_path(self, device, dtype): @@ -1756,142 +1967,176 @@ def unbind_rebind_matmul(nt1, nt2): seq_len = np.random.randint(2, 5) t3s.append(torch.randn(seq_len, n_heads, head_dim)) t4s.append(torch.randn(seq_len, n_heads, head_dim)) - nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2) - nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3) + nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose( + 1, 2 + ) + nt4 = ( + torch.nested.nested_tensor(t4s, device=device, dtype=dtype) + .transpose(1, 2) + .transpose(2, 3) + ) self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) # 
cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_matmul(self, device, dtype): # error case: one is nested but the other is not - nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) t = torch.randn(4, device=device, dtype=dtype) self.assertRaisesRegex( RuntimeError, "Expected both to be nested, but got a nested self and non-nested other", - lambda: torch.matmul(nt, t) + lambda: torch.matmul(nt, t), ) self.assertRaisesRegex( RuntimeError, "Expected both to be nested, but got a non-nested self and nested other", - lambda: torch.matmul(t, nt) + lambda: torch.matmul(t, nt), ) # error case: not 3+D tensors nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) - nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt0) + lambda: torch.matmul(nt0, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt0, nt2) + lambda: torch.matmul(nt0, nt2), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt0) + lambda: torch.matmul(nt1, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt1) + lambda: torch.matmul(nt1, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", - lambda: torch.matmul(nt1, nt2) + lambda: torch.matmul(nt1, nt2), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", - lambda: torch.matmul(nt2, nt0) + lambda: torch.matmul(nt2, nt0), ) self.assertRaisesRegex( RuntimeError, r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 
2nd input has rank: [0-9]+", - lambda: torch.matmul(nt2, nt1) + lambda: torch.matmul(nt2, nt1), ) # error case: incompatible batch size - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), - torch.randn((4, 5)), - torch.randn((4, 7))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) self.assertRaisesRegex( RuntimeError, r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) self.assertRaisesRegex( RuntimeError, r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", - lambda: torch.matmul(nt1, nt0) + lambda: torch.matmul(nt1, nt0), ) # error case: incompatible (wrong) batch sizes that shouldn't even broadcast? - nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), - torch.randn((2, 3, 4))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((3, 4, 6)), - torch.randn((3, 4, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): For nested tensors, batch dimensions must have the same sizes,", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) # error case: incompatible batch sizes that should technically broadcast - nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), - torch.randn((1, 3, 4))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), - torch.randn((3, 4, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): For nested tensors, batch dimensions must have the same sizes,", - lambda: torch.matmul(nt0, nt1) + lambda: torch.matmul(nt0, nt1), ) # error case: underlying matrices cannot be multiplied - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) self.assertRaisesRegex( RuntimeError, "matmul(): Nested tensors cannot be matrix multiplied", - lambda: torch.matmul(nt0, nt0) + lambda: torch.matmul(nt0, nt0), ) # normal nested tensor: 3D - nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), 
torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # normal nested tensor: 4D (with testing for batch_size=1) - nt0 = torch.nested.nested_tensor([torch.randn((1, 2, 4)), - torch.randn((8, 3, 7))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), - torch.randn((8, 7, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # normal nested tensor: 5D - nt0 = torch.nested.nested_tensor([torch.randn((8, 9, 2, 4)), - torch.randn((8, 9, 3, 7))], - device=device, dtype=dtype) - nt1 = torch.nested.nested_tensor([torch.randn((8, 9, 4, 6)), - torch.randn((8, 9, 7, 5))], - device=device, dtype=dtype) + nt0 = torch.nested.nested_tensor( + [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))], + device=device, + dtype=dtype, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))], + device=device, + dtype=dtype, + ) actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) - expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) self.assertEqual(actual, expect) # only supported on CUDA for now @@ -1910,11 +2155,16 @@ def test_matmul_nt_with_broadcasted_t(self, device, dtype): # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' @dtypes(torch.float, torch.double) def test_matmul_noncontiguous(self, device, dtype): - nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) - nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) self.assertEqual( torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), - torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous)) + torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous), + ) @dtypes(torch.float, torch.double) def test_linear(self, device, dtype): @@ -1929,29 +2179,39 @@ def test_linear(self, device, dtype): torch.functional.F.linear(nt, weight, bias) # invalid nested tensor dimension - msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2' - nt1 = torch.nested.nested_tensor([torch.randn(1, device=device, dtype=dtype), - torch.randn(2, device=device, dtype=dtype)]) + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. 
Dense tensor dim: 2" + nt1 = torch.nested.nested_tensor( + [ + torch.randn(1, device=device, dtype=dtype), + torch.randn(2, device=device, dtype=dtype), + ] + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt1, weight, bias) # invalid weight shape - msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3' + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3" weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt, weight1, bias) # inconsistent last dim of nested tensor msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:" - nt2 = torch.nested.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype), - torch.randn(2, 3, device=device, dtype=dtype)]) + nt2 = torch.nested.nested_tensor( + [ + torch.randn(1, 2, device=device, dtype=dtype), + torch.randn(2, 3, device=device, dtype=dtype), + ] + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt2, weight, bias) # Mismatch of nested tensor last dim and weight dimension weight2 = torch.randn(2, 4, device=device, dtype=dtype) - msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \ + msg = ( + r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" + ) with self.assertRaisesRegex(RuntimeError, msg): torch.functional.F.linear(nt, weight2, bias) @@ -1966,22 +2226,26 @@ def test_linear(self, device, dtype): # since linear does not support noncontiguous buffer yet @dtypes(torch.float, torch.double) def test_linear_noncontiguous(self, device, dtype): - nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) weight = torch.randn((8, 5), device=device, dtype=dtype) self.assertRaisesRegex( RuntimeError, r"for now linear only supports contiguous nested tensor", - lambda: torch.nn.functional.linear(nt_noncontiguous, weight) + lambda: torch.nn.functional.linear(nt_noncontiguous, weight), ) @dtypes(torch.float, torch.float16, torch.double) def test_to_padded_tensor_zero_numel_errors(self, device, dtype): ts = [torch.ones(1, 0), torch.ones(0, 0)] - nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided) + nt = torch.nested.nested_tensor( + ts, device=device, dtype=dtype, layout=torch.strided + ) self.assertRaisesRegex( RuntimeError, r"at least one constituent tensor should have non-zero numel", - lambda: torch.nested.to_padded_tensor(nt, 0.0) + lambda: torch.nested.to_padded_tensor(nt, 0.0), ) @dtypes(torch.float, torch.float16, torch.double) @@ -1991,12 +2255,12 @@ def test_transpose(self, device, dtype): self.assertRaisesRegex( RuntimeError, "Nested tensor dimension 0 cannot be transposed", - lambda: nt.transpose(0, 1) + lambda: nt.transpose(0, 1), ) self.assertRaisesRegex( RuntimeError, "Nested tensor dimension 0 cannot be transposed", - lambda: nt.transpose(1, -3) + lambda: nt.transpose(1, -3), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) @@ -2017,13 +2281,13 @@ def test_squeeze_unsqueeze(self, device, dtype): self.assertRaisesRegex( RuntimeError, "For nested 
tensors, squeeze without the dim argument", - lambda: nt.squeeze() + lambda: nt.squeeze(), ) # error case: squeeze nested dimension self.assertRaisesRegex( RuntimeError, "For nested tensors, squeezing dimension 0", - lambda: nt.squeeze(0) + lambda: nt.squeeze(0), ) # error case: dimension out of range self.assertRaises(IndexError, lambda: nt.squeeze(3)) @@ -2033,7 +2297,7 @@ def test_squeeze_unsqueeze(self, device, dtype): self.assertRaisesRegex( RuntimeError, "For nested tensors, squeezing a nested tensor of singleton", - lambda: nt_singleton.squeeze(1) + lambda: nt_singleton.squeeze(1), ) # squeezing a dim which does not have size 1 should be a no-op @@ -2044,7 +2308,7 @@ def test_squeeze_unsqueeze(self, device, dtype): nt_sizes = nt._nested_tensor_size() nt_strides = nt._nested_tensor_strides() for i in range(-2, 4): - if (i == 0): + if i == 0: # cannot unsqueeze batch dim continue nt_unsqueezed = nt.unsqueeze(i) @@ -2052,9 +2316,12 @@ def test_squeeze_unsqueeze(self, device, dtype): wrapped_i = i + nt.dim() + 1 if i < 0 else i # col_index into nt size tensor is requires subtraction of 1 to ignore batch dim size_idx = wrapped_i - 1 - self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long)) + self.assertEqual( + nt_unsqueezed._nested_tensor_size()[:, size_idx], + torch.ones(2, dtype=torch.long), + ) unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx] - if (i == nt.ndim or i == -1): + if i == nt.ndim or i == -1: self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long)) else: stride_col_after = nt_strides[:, size_idx] @@ -2092,25 +2359,25 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"shape '\[\]' is invalid for a nested tensor", - lambda: nt.view(()) + lambda: nt.view(()), ) # error case: empty nested tensor nt_empty = torch.nested.nested_tensor([]) self.assertRaisesRegex( RuntimeError, "empty nested tensor cannot be reshaped", - lambda: nt_empty.view(-1) + lambda: nt_empty.view(-1), ) # error case: -1 for batch size self.assertRaisesRegex( RuntimeError, r"view: For now nested view cannot change or infer the implicit batch dimension", - lambda: nt.view(-1, 2, 3) + lambda: nt.view(-1, 2, 3), ) self.assertRaisesRegex( RuntimeError, r"shape '\[.*\]' is invalid for input of size [0-9]+", - lambda: nt.view(4, 2, 3) + lambda: nt.view(4, 2, 3), ) # normal case x0 = torch.randn((2, 20), device=device, dtype=dtype) @@ -2121,7 +2388,7 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"For now nested view cannot change or infer the implicit batch dimension", - lambda: nt.transpose(-1, -2).view(40, -1) + lambda: nt.transpose(-1, -2).view(40, -1), ) # inherit only the ragged dimension # (2, 20) -> (2, 5, 4) @@ -2137,13 +2404,15 @@ def test_view(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"only one dimension can be inferred", - lambda: nt1.view(2, -1, -1, 2, 2) + lambda: nt1.view(2, -1, -1, 2, 2), ) @dtypes(torch.float, torch.float16, torch.double) def test_view_inference_mode_interaction(self, device, dtype): # Construct in default mode and view while in inference mode - nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) with torch.inference_mode(): ntT = nt.view(2, -1, 4, 5) ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) @@ -2152,7 +2421,9 @@ def 
test_view_inference_mode_interaction(self, device, dtype): self.assertEqual(ptT, ptT_from_ntT) # Construct and view while in inference mode with torch.inference_mode(): - nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) ntT = nt.view(2, -1, 4, 5) ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) pt = torch.nested.to_padded_tensor(nt, 0.0) @@ -2166,25 +2437,25 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"shape '\[\]' is invalid for a nested tensor", - lambda: nt.reshape(()) + lambda: nt.reshape(()), ) # error case: empty nested tensor nt_empty = torch.nested.nested_tensor([]) self.assertRaisesRegex( RuntimeError, "empty nested tensor cannot be reshaped", - lambda: nt_empty.reshape(-1) + lambda: nt_empty.reshape(-1), ) # error case: -1 for batch size self.assertRaisesRegex( RuntimeError, r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", - lambda: nt.reshape(-1, 2, 3) + lambda: nt.reshape(-1, 2, 3), ) self.assertRaisesRegex( RuntimeError, r"shape '\[.*\]' is invalid for input of size [0-9]+", - lambda: nt.reshape(4, 2, 3) + lambda: nt.reshape(4, 2, 3), ) # normal case x0 = torch.randn((2, 20), device=device, dtype=dtype) @@ -2195,7 +2466,7 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", - lambda: nt.transpose(-1, -2).reshape(40, -1) + lambda: nt.transpose(-1, -2).reshape(40, -1), ) # inherit only the ragged dimension # (2, 20) -> (2, 5, 4) @@ -2211,7 +2482,7 @@ def test_reshape(self, device, dtype): self.assertRaisesRegex( RuntimeError, r"only one dimension can be inferred", - lambda: nt1.reshape(2, -1, -1, 2, 2) + lambda: nt1.reshape(2, -1, -1, 2, 2), ) @dtypes(torch.float, torch.float16, torch.double) @@ -2230,35 +2501,50 @@ def test_narrow(self, device, dtype): # dim != 0 is not supported for dim in range(1, nt.dim()): - with self.assertRaisesRegex(RuntimeError, "only dim=0 supported for nested tensors"): + with self.assertRaisesRegex( + RuntimeError, "only dim=0 supported for nested tensors" + ): nt.narrow(dim=dim, start=0, length=1) # error case: non-contiguous NT _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4)) - with self.assertRaisesRegex(RuntimeError, "only contiguous nested tensors supported"): + with self.assertRaisesRegex( + RuntimeError, "only contiguous nested tensors supported" + ): nt_noncont.narrow(dim=0, start=0, length=1) @parametrize("input_dim", [3, 4]) def test_scaled_dot_product_attention(self, device, input_dim): - def rand_tensor(*shape): return torch.randn(shape, device=device) E = 8 if input_dim == 3: # Shape: (N, L, E); ragged L - query = torch.nested.nested_tensor([rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]) + query = torch.nested.nested_tensor( + [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)] + ) # Shape: (N, S, E); ragged S - key = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) - value = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) + key = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) elif input_dim == 4: # In the 4D case the L and S is ragged # 
Shape: (N, N', L, E); ragged N' and L - query = torch.nested.nested_tensor([rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]) + query = torch.nested.nested_tensor( + [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)] + ) # Shape: (N, N', S, E); ragged N' and S - key = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) - value = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) + key = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) else: self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") @@ -2266,31 +2552,43 @@ def rand_mask(size): return torch.randint(0, 2, size=size, dtype=torch.bool, device=device) # Shape: (N, L, S); ragged L and S matching above - attn_mask = torch.nested.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]) + attn_mask = torch.nested.nested_tensor( + [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))] + ) dropout_p = 0.0 # no dropout for reproducibility # Success case: no attn_mask set and is_causal=False. actual = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p) + query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p + ) expected_outputs = [] for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): output = torch.nn.functional.scaled_dot_product_attention( - q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p) + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + attn_mask=None, + dropout_p=dropout_p, + ) expected_outputs.append(output.squeeze(0)) expected_output_nested = torch.nested.nested_tensor(expected_outputs) self.assertEqual(actual, expected_output_nested) # Error case: explicit attn_mask set. - with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"): + with self.assertRaisesRegex( + RuntimeError, "not supported when an explicit attn_mask is set" + ): torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p) + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p + ) # Error case: is_causal=True. 
with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): torch.nn.functional.scaled_dot_product_attention( - query, key, value, dropout_p=dropout_p, is_causal=True) + query, key, value, dropout_p=dropout_p, is_causal=True + ) @dtypes(torch.float, torch.float16, torch.double) def test_empty_like(self, device, dtype): @@ -2306,10 +2604,10 @@ def test_empty_like(self, device, dtype): if torch.cuda.is_available(): if device == "cpu": - nt_cuda = torch.empty_like(nt, device='cuda') + nt_cuda = torch.empty_like(nt, device="cuda") self.assertEqual(torch.device("cuda").type, nt_cuda.device.type) else: - nt_cpu = torch.empty_like(nt, device='cpu') + nt_cpu = torch.empty_like(nt, device="cpu") self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) # Check changing dtype of empty_like nested tensor output @@ -2333,19 +2631,36 @@ def test_empty_like(self, device, dtype): assert nt_noncont.is_same_size(nt_empty_non_contig) # Test the contiguous memory format option - nt_empty_contig = torch.empty_like(nt_cont, memory_format=torch.contiguous_format) + nt_empty_contig = torch.empty_like( + nt_cont, memory_format=torch.contiguous_format + ) assert nt_cont.is_same_size(nt_empty_contig) assert nt_empty_contig.is_contiguous() - nt_empty_non_contig = torch.empty_like(nt_noncont, memory_format=torch.contiguous_format) + nt_empty_non_contig = torch.empty_like( + nt_noncont, memory_format=torch.contiguous_format + ) assert nt_noncont.is_same_size(nt_empty_non_contig) assert nt_empty_non_contig.is_contiguous() # Test other memory formats fail - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d)) - self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d)) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d), + ) + @markDynamoStrictTest class TestNestedTensorAutograd(TestCase): @@ -2353,12 +2668,26 @@ class TestNestedTensorAutograd(TestCase): # includes the default parameters used for testing ops with gradcheck. 
However nested tensor # does not support the stack op therefore we turn it off for these tests def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False): - return torch.nested.nested_tensor([torch.randn(1, 2,), - torch.randn(7, 8)], requires_grad=requires_grad, device=tensor_device) + return torch.nested.nested_tensor( + [ + torch.randn( + 1, + 2, + ), + torch.randn(7, 8), + ], + requires_grad=requires_grad, + device=tensor_device, + ) def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False): - return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad), - torch.randn(7, 8, requires_grad=requires_grad)], device=tensor_device) + return torch.nested.as_nested_tensor( + [ + torch.randn(1, 2, requires_grad=requires_grad), + torch.randn(7, 8, requires_grad=requires_grad), + ], + device=tensor_device, + ) def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False): data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device) @@ -2376,7 +2705,9 @@ def test_as_nested_tensor_propagates_gradients(self, device): a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) nt2 = torch.nested.as_nested_tensor([a, b]) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) nt2.backward(fake_grad) self.assertEqual(a.grad, fake_grad[0]) self.assertEqual(b.grad, fake_grad[1]) @@ -2393,7 +2724,9 @@ def test_nested_tensor_generates_leaf(self, device): self.assertTrue(nt2.is_leaf) self.assertTrue(nt2.requires_grad) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) nt2.backward(fake_grad) self.assertEqual(nt2.grad, fake_grad) self.assertEqual(a.grad, None) @@ -2443,17 +2776,33 @@ def test_backward_for_sub_op(self, device): self.assertEqual(nt_2.grad, -1 * grad_output) def test_backward_sub_strided(self, device): - a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) - b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) - c = a - b.transpose(-1, -2) - grad_output = c.clone() + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) + c = a - b.transpose(-1, -2) + grad_output = c.clone() c.backward(grad_output) self.assertEqual(a.grad, grad_output) self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2)) def test_backward_add_strided(self, device): - a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) - b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) c = a + b.transpose(-1, -2) grad_output = c.clone() c.backward(grad_output) @@ -2463,13 
+2812,20 @@ def test_backward_add_strided(self, device): # Test Factory Functions def test_nested_tensor_to_padded_tensor(self, device): for padding_val in [0, 1]: - nt = self._create_leaf_nested_tensor_from_list(tensor_device=device, requires_grad=True) + nt = self._create_leaf_nested_tensor_from_list( + tensor_device=device, requires_grad=True + ) out = torch.nested.to_padded_tensor(nt, padding_val) grad_output = torch.ones(out.shape, device=device) out.backward(grad_output) - self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)], device=device)) + self.assertEqual( + nt.grad, + torch.nested.nested_tensor( + [torch.ones(1, 2), torch.ones(7, 8)], device=device + ), + ) def test_nested_tensor_from_mask_and_to_padded(self, device): N, L, D = 2, 4, 4 @@ -2481,12 +2837,15 @@ def test_nested_tensor_from_mask_and_to_padded(self, device): mask[0, :] = 1 mask = mask.bool() - data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64, device=device) + data = torch.randn( + N, L, D, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(inpt): nt = torch._nested_tensor_from_mask(inpt, mask) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) def test_nested_tensor_from_padded(self, device): @@ -2496,7 +2855,9 @@ def test_nested_tensor_from_padded(self, device): padded_tensor.requires_grad_() def grad_test_func(tensor, nested_size): - nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False) + nt = torch._nested_from_padded( + tensor, nested_size, fuse_transform_0213=False + ) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) @@ -2510,14 +2871,16 @@ def test_nested_tensor_from_padded_fused(self, device): padded_tensor.requires_grad_() def grad_test_func(tensor, nested_size): - nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True) + nt = torch._nested_from_padded( + tensor, nested_size, fuse_transform_0213=True + ) # This implicitly tests to_padded_tensor grads return torch.nested.to_padded_tensor(nt, 0) + data = (padded_tensor, nested_size) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) def test_nested_tensor_from_list(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device) @@ -2526,20 +2889,29 @@ def grad_test_func(a, b, c): c = torch.nested.as_nested_tensor([a, b, c]) # This implictily tests to_padded_tensor grads return torch.nested.to_padded_tensor(c, 0) + data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @decorateIf( xfailIfTorchDynamo, # only fails in python 3.11. TODO: Debug this! 
- lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) + lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11), ) @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) def test_dropout_backward(self, layout): if layout == torch.jagged: - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 5))], requires_grad=True, layout=layout) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 5))], + requires_grad=True, + layout=layout, + ) else: - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, layout=layout) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 4))], + requires_grad=True, + layout=layout, + ) p = 0.2 y = torch.nn.functional.dropout(nt, p) y.backward(nt.clone().detach()) @@ -2561,8 +2933,16 @@ def grad_test_func(a, b, c, d): assert torch.autograd.gradcheck(grad_test_func, inputs=data) def test_nested_tensor_bmm_backward(self, device): - nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True, device=device) + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], + requires_grad=True, + device=device, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((6, 4)), torch.randn((6, 5))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -2591,8 +2971,16 @@ def grad_test_func(a, b, c, d): assert torch.autograd.gradcheck(grad_test_func, inputs=data) def test_nested_tensor_matmul_backward(self, device): - nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True, device=device) - nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True, device=device) + nt0 = torch.nested.nested_tensor( + [torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], + requires_grad=True, + device=device, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -2618,7 +3006,11 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) def test_nested_tensor_transpose_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 4))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2642,7 +3034,9 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) def test_nested_tensor_reshape_backward(self): - nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2654,7 +3048,11 @@ def test_nested_tensor_reshape_backward(self): 
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) def test_nested_tensor_squeeze_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2666,8 +3064,12 @@ def test_nested_tensor_squeeze_backward(self, device): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) def test_nested_tensor_squeeze_gradcheck(self, device): - a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True, device=device) - b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True, device=device) + a = torch.randn( + (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) + b = torch.randn( + (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -2677,7 +3079,11 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) def test_nested_tensor_unsqueeze_backward(self, device): - nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], + requires_grad=True, + device=device, + ) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -2700,12 +3106,13 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) def test_nested_tensor_linear(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, weight, bias=None): @@ -2713,6 +3120,7 @@ def grad_test_func(a, b, c, weight, bias=None): # This implicitly tests to_padded_tensor grads d = torch.functional.F.linear(nt, weight, bias) return torch.nested.to_padded_tensor(d, 0) + data = (a, b, c, weight, bias) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2725,7 +3133,9 @@ def test_nested_tensor_linear_plus_transpose(self, device): b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, weight, bias=None): @@ -2734,6 +3144,7 @@ def grad_test_func(a, b, c, weight, bias=None): d = torch.functional.F.linear(nt, weight, bias) d = d.transpose(-1, -2).contiguous() return torch.nested.to_padded_tensor(d, 0) + data = (a, b, c, weight, bias) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2843,7 +3254,9 @@ def test_indexing_backward(self, device): self.assertEqual(nt[-1], x1) grad_x0 
= torch.randn((2, 5), device=device) nt[0].backward(grad_x0) - expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)]) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device)] + ) self.assertEqual(nt.grad, expected_grad) def test_masked_fill_backward(self, device): @@ -2857,6 +3270,7 @@ def grad_test_func(a, b, c): out = nt.masked_fill(mask, 0) out = torch.nested.to_padded_tensor(out, 0) return out + data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) @@ -2916,9 +3330,13 @@ def grad_test_func(a, b, c): # NotImplementedError: Cannot access storage of UndefinedTensorImpl def test_layer_norm_backward_edge_case(self, device): size = 4 - a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device) + a = torch.randn( + 1, 2, size, requires_grad=False, dtype=torch.float64, device=device + ) nt = torch.nested.nested_tensor([a]) - nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) + nt_layer_norm = torch.nn.LayerNorm( + nt.size(-1), device=device, dtype=torch.float64 + ) out = nt_layer_norm(nt) out.backward(out.clone()) @@ -2939,13 +3357,21 @@ def grad_test_func(a, b): @skipIfSlowGradcheckEnv @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2]) def test_layer_norm_backward(self, device, size): - a = torch.randn(1, 2, size, requires_grad=True, dtype=torch.float64, device=device) - b = torch.randn(2, 2, size, requires_grad=True, dtype=torch.float64, device=device) - c = torch.randn(3, 2, size, requires_grad=True, dtype=torch.float64, device=device) + a = torch.randn( + 1, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) + b = torch.randn( + 2, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) + c = torch.randn( + 3, 2, size, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) - layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) + layer_norm = torch.nn.LayerNorm( + nt.size(-1), device=device, dtype=torch.float64 + ) nt_layer_norm = layer_norm(nt) return torch.nested.to_padded_tensor(nt_layer_norm, 0) @@ -2957,23 +3383,33 @@ def grad_test_func(a, b, c): # Could either mark slow or reduce size @parametrize("size", [128, 32, 4, 2]) def test_layer_norm_backward_5d(self, device, size): - a = torch.randn(4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) - b = torch.randn(7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) - c = torch.randn(10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) + a = torch.randn( + 4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) + b = torch.randn( + 7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) + c = torch.randn( + 10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) - layer_norm = torch.nn.LayerNorm((size, size, nt.size(-1)), device=device, dtype=torch.float64) + layer_norm = torch.nn.LayerNorm( + (size, size, nt.size(-1)), device=device, dtype=torch.float64 + ) nt_layer_norm = layer_norm(nt) return torch.nested.to_padded_tensor(nt_layer_norm, 0) data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + # Found in torch/testing/_comparison.py default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, 
torch.float32: 1e-5} default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} + def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: deviation = true_value - computed_value deviation = torch.abs(deviation / true_value) @@ -3006,6 +3442,7 @@ def get_tolerances( rtol = default_rtol[computed_value.dtype] return atol, rtol + # We can probably parametrizing existing tests instead of having a separate # test class as we begin to support more ops. Also maybe rewrite with OpInfos. @markDynamoStrictTest @@ -3016,16 +3453,25 @@ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): out = [] for s in nested_size[0]: out.append( - torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64) + torch.randn( + s, + *Ds, + requires_grad=requires_grad, + device=device, + dtype=torch.float64, + ) ) return out - def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True): - - def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True): + def _get_example_tensor_lists( + self, include_list_of_lists=True, include_requires_grad=True + ): + def _make_tensor( + *shape, include_requires_grad=include_requires_grad, requires_grad=True + ): return torch.randn( *shape, - requires_grad=(requires_grad if include_requires_grad else False) + requires_grad=(requires_grad if include_requires_grad else False), ) # Purposefully introduce mixed requires_grad settings for the components @@ -3036,7 +3482,7 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(2, 5), _make_tensor(3, 5, requires_grad=False), _make_tensor(4, 5, requires_grad=False), - _make_tensor(6, 5) + _make_tensor(6, 5), ], # (B, *, D_0, D_1) with B=5 [ @@ -3046,6 +3492,15 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(5, 5, 6), _make_tensor(6, 5, 6), ], + # (B, *, D_0, D_1, D_2) with B=6 + [ + _make_tensor(2, 5, 6, 7), + _make_tensor(3, 5, 6, 7), + _make_tensor(4, 5, 6, 7, requires_grad=False), + _make_tensor(5, 5, 6, 7), + _make_tensor(6, 5, 6, 7), + _make_tensor(7, 5, 6, 7), + ], ] if include_list_of_lists: @@ -3055,7 +3510,8 @@ def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_g _make_tensor(2, 5, requires_grad=False).tolist(), _make_tensor(3, 5).tolist(), _make_tensor(4, 5).tolist(), - ]) + ] + ) return example_lists @@ -3077,11 +3533,14 @@ def test_tensor_attributes(self, device): ): op(nt) - with self.assertRaisesRegex(RuntimeError, - "directly calling torch.ops.aten.size"): + with self.assertRaisesRegex( + RuntimeError, "directly calling torch.ops.aten.size" + ): torch.ops.aten.size.default(nt) - nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1) + nested_int = torch.nested._internal.nested_tensor.get_tensor_symint( + _offsets, coeff=1 + ) self.assertEqual(nt.size(), (3, nested_int, 3)) self.assertEqual(nt.shape, (3, nested_int, 3)) self.assertEqual(nt.dim(), 3) @@ -3091,7 +3550,9 @@ def test_linear(self, device): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + weight = torch.randn( + 4, 3, requires_grad=True, dtype=torch.float64, device=device + ) def grad_test_func(a, 
b, c, weight): nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) @@ -3114,19 +3575,30 @@ def grad_test_func(a, b, c): def test_unary_pointwise_transposed_inputs(self, device): a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) - nt = torch.nested.nested_tensor([a.detach(), b.detach(), c.detach()], layout=torch.jagged) + nt = torch.nested.nested_tensor( + [a.detach(), b.detach(), c.detach()], layout=torch.jagged + ) nt_t = nt.transpose(1, 2) self.assertFalse(nt_t.is_contiguous()) out = torch.nn.functional.silu(nt_t.sin().cos()) - self.assertEqual(out.is_contiguous(), torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous()) + self.assertEqual( + out.is_contiguous(), + torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(), + ) self.assertEqual(nt_t.shape, out.shape) a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) def grad_test_func(a, b, c): @@ -3137,7 +3609,6 @@ def grad_test_func(a, b, c): gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) - def test_binary_pointwise(self, device): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) @@ -3151,7 +3622,8 @@ def test_binary_pointwise(self, device): self.assertRaisesRegex( RuntimeError, "cannot call binary pointwise function .* with inputs of shapes", - lambda: nt1 * nt2) + lambda: nt1 * nt2, + ) # Correct usage: chain the calls using the same offsets tensor object def grad_test_func(a, b, c): @@ -3186,7 +3658,10 @@ def test_binary_pointwise_transposed(self, device): ) a, b, c = ( - torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) ) # Correct usage: chain the calls using the same offsets tensor object @@ -3210,11 +3685,15 @@ def test_split(self, device): self.assertEqual(len(out), 2) self.assertEqual( out[0], - torch.nested.as_nested_tensor([a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged + ), ) self.assertEqual( out[1], - torch.nested.as_nested_tensor([a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged + ), ) with self.assertRaisesRegex( @@ -3233,11 +3712,15 @@ def test_split_with_sizes(self, device): self.assertEqual(len(out), 2) self.assertEqual( out[0], - torch.nested.as_nested_tensor([a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged + ), ) self.assertEqual( out[1], - torch.nested.as_nested_tensor([a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged) + torch.nested.as_nested_tensor( + [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged + ), ) with self.assertRaisesRegex( RuntimeError, @@ -3248,7 +3731,8 @@ def test_split_with_sizes(self, device): def test_views_inherit_ragged_dim(self, device): # view nt = random_nt_from_dims( - [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) + [4, 
None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) # inherit ragged dim via -1 view = nt.view(4, -1, 80) self.assertEqual(nt.shape[1], view.shape[1]) @@ -3258,20 +3742,25 @@ def test_views_inherit_ragged_dim(self, device): # expand nt = random_nt_from_dims( - [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) # inherit batch and ragged dims via -1 view = nt.expand(-1, -1, 5) self.assertEqual(nt.shape[:2], view.shape[:2]) def test_view_ragged_idx_not_one(self, device): - nt = random_nt_from_dims([2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged + ) view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1)) self.assertEqual((2, 20, nt.size(1)), (view_transposed.size())) self.assertEqual(view_transposed._base, nt._base) def test_unsafe_view(self, device): - nt = random_nt_from_dims([4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) # basic view view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80)) self.assertEqual((4, nt.size(1), 80), tuple(view1.size())) @@ -3288,12 +3777,16 @@ def test_unsafe_view(self, device): @parametrize("requires_grad", [False, True]) def test_reshape_decomp(self, device, requires_grad): # contiguous NT should result in view. - nt = random_nt_from_dims( - [3, None, 10], - device=device, - dtype=torch.float32, - layout=torch.jagged, - ).detach().requires_grad_(requires_grad) + nt = ( + random_nt_from_dims( + [3, None, 10], + device=device, + dtype=torch.float32, + layout=torch.jagged, + ) + .detach() + .requires_grad_(requires_grad) + ) view = nt.reshape(-1, -1, 5, 2) self.assertEqual(view.shape[:2], nt.shape[:2]) self.assertTrue(view._is_view() and view._base is nt) @@ -3308,7 +3801,7 @@ def test_reshape_decomp(self, device, requires_grad): device=device, dtype=torch.float32, layout=torch.jagged, - requires_grad=requires_grad + requires_grad=requires_grad, ) nt_noncontig = nt.transpose(-1, -2) self.assertFalse(nt_noncontig.is_contiguous()) @@ -3322,12 +3815,14 @@ def test_reshape_decomp(self, device, requires_grad): def test_flatten_decomp(self, device): nt = random_nt_from_dims( - [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged + ) flattened = nt.flatten(-2, -1) self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape) nt = random_nt_from_dims( - [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged) + [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged + ) flattened = nt.flatten(-3, -2) self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape) @@ -3335,7 +3830,9 @@ def test_chunk(self, device): # normal case D = 30 B = 8 - nt = random_nt_from_dims([B, None, D], device=device, dtype=torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [B, None, D], device=device, dtype=torch.float32, layout=torch.jagged + ) NUM_CHUNKS = 3 chunks = nt.chunk(NUM_CHUNKS, dim=-1) self.assertEqual(len(chunks), NUM_CHUNKS) @@ -3351,12 +3848,17 @@ def test_chunk(self, device): self.assertEqual(chunks[i].shape[0], chunk_size) else: self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1)) - offsets_expected = nt._offsets[i * chunk_size + 1 : (i + 
1) * chunk_size + 1] - nt._offsets[i * chunk_size] + offsets_expected = ( + nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1] + - nt._offsets[i * chunk_size] + ) self.assertEqual(chunks[i]._offsets[1:], offsets_expected) self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0)) # chunk on ragged dim not supported - with self.assertRaisesRegex(RuntimeError, "chunk.* not supported for NestedTensor on dim=1"): + with self.assertRaisesRegex( + RuntimeError, "chunk.* not supported for NestedTensor on dim=1" + ): nt.chunk(2, dim=1) def test_squeeze(self, device): @@ -3364,7 +3866,8 @@ def test_squeeze(self, device): D = 6 # squeeze middle dim nt = random_nt_from_dims( - [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged) + [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged + ) j0 = nt.shape[1] for dim_arg in [-2, 2]: @@ -3374,7 +3877,8 @@ def test_squeeze(self, device): # squeeze last dim nt = random_nt_from_dims( - [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) + [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) j1 = nt.shape[1] for dim_arg in [-1, 2]: @@ -3384,17 +3888,21 @@ def test_squeeze(self, device): # squeeze on batch dim not supported with self.assertRaisesRegex( - RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"): + RuntimeError, "squeeze.* not supported for NestedTensor on dim=0" + ): nt.squeeze(0) # squeeze on ragged dim not supported with self.assertRaisesRegex( - RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"): + RuntimeError, "squeeze.* not supported for NestedTensor on dim=1" + ): nt.squeeze(1) def test_binary_pointwise_broadcasting(self, device): # (B, j0, 3, 4) - ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True) + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device, requires_grad=True + ) # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?) 
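A minimal standalone sketch of the jagged-layout pointwise broadcasting behavior these hunks exercise, assuming only the public torch.nested API; the shapes and the bias tensor below are illustrative and not taken from the diff:

import torch

# Jagged nested tensor of logical shape (B=2, j0, D=3); a dense (D,) tensor
# broadcasts over the batch and ragged dimensions in pointwise ops.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
bias = torch.randn(3)

out = nt + bias
# The result unbinds to the same per-component sums.
for out_i, nt_i in zip(out.unbind(), nt.unbind()):
    torch.testing.assert_close(out_i, nt_i + bias)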
@@ -3414,12 +3922,18 @@ def grad_test_func(t, *ts): return out.values() for t_size in t_sizes: - t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64) + t = torch.rand( + t_size, requires_grad=True, device=device, dtype=torch.float64 + ) gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) def test_threshold_backward(self, device): - ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) - ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) + ts1 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) + ts2 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) nt1, offsets = jagged_from_list(ts1, None) nt2, offsets = jagged_from_list(ts2, offsets) @@ -3431,11 +3945,12 @@ def test_threshold_backward(self, device): self.assertEqual(res_dense, res_nt.values()) - @parametrize("keepdim", [False, True]) def test_sum_int_DimList(self, device, keepdim): # (B, j0, 3, 4) - ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True) + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device=device, requires_grad=True + ) # Check shape correctness reduce_dims = ( @@ -3451,8 +3966,9 @@ def test_sum_int_DimList(self, device, keepdim): for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims: if (0 in rd) ^ (1 in rd): with self.assertRaisesRegex( - RuntimeError, - "applying over the ragged dimension, but not the batch dimension"): + RuntimeError, + "applying over the ragged dimension, but not the batch dimension", + ): nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) out = torch.sum(nt, dim=rd, keepdim=keepdim) continue @@ -3483,18 +3999,17 @@ def test_sum_int_DimList(self, device, keepdim): self.assertNotIsInstance(out, NestedTensor) self.assertTrue(torch.allclose(out, out_ref)) - - @dtypes(torch.float, torch.double, torch.half) @parametrize("requires_grad", [False, True]) @parametrize("weights_only", [False, True]) def test_serialization(self, device, dtype, requires_grad, weights_only): - def compare_metadata(nt1, nt2): self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size()) self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides()) - self.assertEqual(nt1._nested_tensor_storage_offsets(), - nt2._nested_tensor_storage_offsets()) + self.assertEqual( + nt1._nested_tensor_storage_offsets(), + nt2._nested_tensor_storage_offsets(), + ) nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) for a in [nt_contiguous, nt_noncontiguous]: @@ -3509,7 +4024,9 @@ def compare_metadata(nt1, nt2): self.assertEqual(b, nt_contiguous) self.assertEqual(b, nt_noncontiguous) - @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) @onlyCUDA def test_pin_memory(self, device): nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) @@ -3524,7 +4041,9 @@ def test_pin_memory(self, device): self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) @torch.compiler.disable - def _validate_nt(self, nt, device, dtype, layout, requires_grad, dim, batch_size, base=None): + def _validate_nt( + self, nt, device, dtype, layout, requires_grad, dim, batch_size, base=None + ): # Validate a bunch of properties after NT construction. 
device = torch.device(device) self.assertEqual(nt.dim(), dim) @@ -3546,20 +4065,30 @@ def _validate_nt(self, nt, device, dtype, layout, requires_grad, dim, batch_size @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) def test_jagged_layout_construction_nested_tensor( - self, device, dtype, requires_grad, components_require_grad): + self, device, dtype, requires_grad, components_require_grad + ): for tensor_list in self._get_example_tensor_lists( - include_list_of_lists=True, include_requires_grad=components_require_grad): + include_list_of_lists=True, include_requires_grad=components_require_grad + ): nt = torch.nested.nested_tensor( tensor_list, device=device, dtype=dtype, layout=torch.jagged, - requires_grad=requires_grad) + requires_grad=requires_grad, + ) expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 expected_batch_size = len(tensor_list) self._validate_nt( - nt, device, dtype, torch.jagged, requires_grad, expected_dim, expected_batch_size) + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, + ) # Make sure grads -don't- flow back into original tensors for nested_tensor() if requires_grad: @@ -3571,15 +4100,15 @@ def test_jagged_layout_construction_nested_tensor( @dtypes(torch.float, torch.double, torch.half) @parametrize("components_require_grad", [False, True]) def test_jagged_layout_construction_as_nested_tensor( - self, device, dtype, components_require_grad): + self, device, dtype, components_require_grad + ): # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list for tensor_list in self._get_example_tensor_lists( - include_list_of_lists=False, include_requires_grad=components_require_grad): + include_list_of_lists=False, include_requires_grad=components_require_grad + ): nt = torch.nested.as_nested_tensor( - tensor_list, - device=device, - dtype=dtype, - layout=torch.jagged) + tensor_list, device=device, dtype=dtype, layout=torch.jagged + ) # nt.requires_grad=True should be set if at least one component requires grad expected_dim = tensor_list[0].dim() + 1 @@ -3591,7 +4120,8 @@ def test_jagged_layout_construction_as_nested_tensor( torch.jagged, components_require_grad, expected_dim, - expected_batch_size) + expected_batch_size, + ) # Make sure grads flow back into original tensors for as_nested_tensor() if components_require_grad: @@ -3603,15 +4133,15 @@ def test_jagged_layout_construction_as_nested_tensor( self.assertTrue(t.grad is None) @xfailIfTorchDynamo - @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) @onlyCUDA def test_jagged_layout_construction_with_pinned_memory(self, device): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device="cpu", - pin_memory=True) + tensor_list, layout=torch.jagged, device="cpu", pin_memory=True + ) expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 expected_batch_size = len(tensor_list) @@ -3622,20 +4152,26 @@ def test_jagged_layout_construction_with_pinned_memory(self, device): layout=torch.jagged, requires_grad=False, dim=expected_dim, - batch_size=expected_batch_size) + batch_size=expected_batch_size, + ) self.assertTrue(nt.is_pinned()) @dtypes(torch.float, torch.double, torch.half) @parametrize("requires_grad", [False, True]) @parametrize("values_is_view", [False, True]) - def 
test_jagged_view_from_values_offsets(self, device, dtype, requires_grad, values_is_view): + def test_jagged_view_from_values_offsets( + self, device, dtype, requires_grad, values_is_view + ): if values_is_view: # make values a view of base base = torch.randn( - 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad) + 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad + ) values = base.flatten(0, -2) else: - values = torch.randn(10, 5, device=device, dtype=dtype, requires_grad=requires_grad) + values = torch.randn( + 10, 5, device=device, dtype=dtype, requires_grad=requires_grad + ) offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64) nt = nested_view_from_values_offsets(values, offsets) @@ -3644,9 +4180,15 @@ def test_jagged_view_from_values_offsets(self, device, dtype, requires_grad, val expected_batch_size = offsets.shape[0] - 1 expected_base = base if values_is_view else values self._validate_nt( - nt, device, dtype, torch.jagged, requires_grad, expected_dim, expected_batch_size, + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, # ensure NT is a proper view - base=expected_base + base=expected_base, ) if requires_grad: @@ -3676,7 +4218,9 @@ def test_nested_tensor_from_jagged(self, device, dtype): # construct from (values, offsets, lengths) lengths = torch.tensor([2, 1, 1, 2], device=device) - nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets, lengths=lengths) + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) self.assertTrue(isinstance(nt, NestedTensor)) self.assertTrue(nt._is_view() and nt._base is values) self.assertEqual(nt.dim(), 3) @@ -3698,32 +4242,44 @@ def test_nested_tensor_from_jagged(self, device, dtype): # for now, if only lengths is specified, convert to offsets to integrate best with the # existing kernels expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device) - expected_nt = torch.nested.nested_tensor_from_jagged(values, offsets=expected_offsets) + expected_nt = torch.nested.nested_tensor_from_jagged( + values, offsets=expected_offsets + ) for n1, n2 in zip(nt.unbind(), expected_nt.unbind()): self.assertEqual(n1, n2) # error case: no offsets or lengths - with self.assertRaisesRegex(RuntimeError, "At least one of offsets or lengths is required"): + with self.assertRaisesRegex( + RuntimeError, "At least one of offsets or lengths is required" + ): torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) @dtypes(torch.float, torch.double, torch.half) @parametrize("dim", range(5)) - @parametrize("layout", [torch.strided, torch.jagged], - name_fn=lambda l: f"layout_{str(l).split('.')[1]}") + @parametrize( + "layout", + [torch.strided, torch.jagged], + name_fn=lambda l: f"layout_{str(l).split('.')[1]}", + ) @parametrize("requires_grad", [False, True]) @parametrize("contiguous", [False, True]) def test_as_nested_tensor_from_tensor( - self, device, dtype, dim, layout, requires_grad, contiguous): + self, device, dtype, dim, layout, requires_grad, contiguous + ): if dim == 0: - t = torch.tensor(3., requires_grad=requires_grad) + t = torch.tensor(3.0, requires_grad=requires_grad) else: t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad) assert t.dim() == dim if dim < 2: # 0-1 dim tensors can't be converted to NTs - with self.assertRaisesRegex(RuntimeError, "Expected tensor argument to have dim"): - nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, 
layout=layout) + with self.assertRaisesRegex( + RuntimeError, "Expected tensor argument to have dim" + ): + nt = torch.nested.as_nested_tensor( + t, device=device, dtype=dtype, layout=layout + ) return orig_t = t @@ -3734,7 +4290,8 @@ def test_as_nested_tensor_from_tensor( expected_dim = t.dim() expected_batch_size = t.size(0) self._validate_nt( - nt, device, dtype, layout, requires_grad, expected_dim, expected_batch_size) + nt, device, dtype, layout, requires_grad, expected_dim, expected_batch_size + ) if torch.device(device) == t.device and dtype == t.dtype and contiguous: # should be the non-copying (view) case @@ -3742,18 +4299,24 @@ def test_as_nested_tensor_from_tensor( # should be equivalent to construction from unbound tensor list nt_from_unbind = torch.nested.as_nested_tensor( - list(t.unbind(0)), device=device, dtype=dtype, layout=layout) + list(t.unbind(0)), device=device, dtype=dtype, layout=layout + ) self.assertEqual(nt, nt_from_unbind) # ensure call on a NT with the same properties returns the NT directly - nt2 = torch.nested.as_nested_tensor(nt, device=device, dtype=dtype, layout=layout) + nt2 = torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=layout + ) self.assertTrue(nt is nt2) # we don't support conversion between layouts this way atm other_layout = torch.strided if layout == torch.jagged else torch.jagged with self.assertRaisesRegex( - RuntimeError, "Converting between nested tensor layouts is not supported"): - torch.nested.as_nested_tensor(nt, device=device, dtype=dtype, layout=other_layout) + RuntimeError, "Converting between nested tensor layouts is not supported" + ): + torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=other_layout + ) if requires_grad: # make sure gradients flow back into inputs @@ -3767,10 +4330,8 @@ def test_device_dtype_transfer_updates_offsets(self, device, dtype): orig_device = torch.device("cpu") orig_dtype = torch.float32 nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=orig_device, - dtype=orig_dtype) + tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype + ) self.assertEqual(torch.int64, nt.offsets().dtype) nt = nt.to(device=device).to(dtype=dtype) @@ -3782,14 +4343,153 @@ def test_device_dtype_transfer_updates_offsets(self, device, dtype): def test_unbind(self, device): for tensor_list in self._get_example_tensor_lists(): nt = torch.nested.nested_tensor( - tensor_list, - layout=torch.jagged, - device=device) + tensor_list, layout=torch.jagged, device=device + ) # ragged_idx = 1 out = nt.unbind() self.assertEqual(len(out), len(tensor_list)) for i, t in enumerate(out): self.assertEqual(t, tensor_list[i]) + @parametrize("ragged_idx", [2, 3]) + def test_unbind_transpose(self, device, ragged_idx): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=device + ) + if ragged_idx < nt.dim(): + nt = nt.transpose(1, ragged_idx) # set ragged_idx + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual( + t.transpose(0, ragged_idx - 1), tensor_list[i] + ) # transpose back each element of result + + def test_unbind_transpose_ragged_idx_last_dim(self, device): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=device + ).transpose( + 1, -1 + ) # set ragged_idx = last dimension + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) 
+ for i, t in enumerate(out): + self.assertEqual( + t.transpose(0, -1), tensor_list[i] + ) # transpose back each element of result + + def test_unbind_lengths(self, device): + values = torch.randn(16, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) # 3D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_1(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 1 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor offsets and lengths.*", + lambda: nt.unbind(), + ) + + def test_unbind_lengths_ragged_idx_2(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 2, 4, 8], device=device) + lengths = torch.tensor([2, 1, 3], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_3(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 3 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + @skipIfTorchDynamo( + "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch" + ) + def test_unbind_lengths_ragged_idx_0(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 0 + nt = 
torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor.*out of bounds", + lambda: nt.unbind(), + ) + @xfailIfTorchDynamo def test_layer_norm_2(self, device): test_tensor_list = self._get_list_for_jagged_tensor( @@ -3818,15 +4518,12 @@ def test_narrow(self, device): lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) buffer = ( torch.arange(0, 10, device=device, dtype=torch.int64) - .unsqueeze(0).expand(5, -1).clone().detach() - ) - nt = torch.nested.narrow( - buffer, - 1, - starts, - lengths, - layout=torch.jagged + .unsqueeze(0) + .expand(5, -1) + .clone() + .detach() ) + nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged) self.assertTrue(nt._is_view() and nt._base is buffer) @@ -3836,8 +4533,10 @@ def test_narrow(self, device): # self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i]) for i in range(starts.shape[0]): self.assertEqual( - torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), - nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])] + torch.arange( + starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64 + ), + nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])], ) def test_is_contiguous(self, device): @@ -3848,23 +4547,20 @@ def test_is_contiguous(self, device): starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) - narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone() + narrow_base = ( + torch.arange(0, 10, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(5, -1) + .clone() + ) nt_noncontiguous = torch.nested.narrow( - narrow_base, - 1, - starts_nc, - lengths_nc, - layout=torch.jagged + narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged ) starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64) lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64) nt_contiguous_narrow = torch.nested.narrow( - narrow_base, - 1, - starts_c, - lengths_c, - layout=torch.jagged + narrow_base, 1, starts_c, lengths_c, layout=torch.jagged ) # Test contiguous case @@ -3875,23 +4571,36 @@ def test_is_contiguous(self, device): assert nt_contiguous_narrow.is_contiguous() # Test querying by memory_format - self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format) + ) def test_layout_under_torch_dispatch_mode(self): - from torch.testing._internal.logging_tensor import capture_logs_with_logging_tensor_mode + from torch.testing._internal.logging_tensor import ( + capture_logs_with_logging_tensor_mode, + ) - nt = random_nt_from_dims([2, None, 3], 
torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) with capture_logs_with_logging_tensor_mode(): self.assertEqual(nt.layout, torch.jagged) @skipIfTorchDynamo("Not a suitable test for TorchDynamo") - @parametrize("func", [torch.empty_like, torch.randn_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__ + ) def test_like_shape(self, func): - nt = random_nt_from_dims([2, None, 3], torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) nt_like = func(nt) for nt_ub in nt_like.unbind(): @@ -3899,10 +4608,13 @@ def test_like_shape(self, func): self.assertEqual(nt_ub.shape, t_like.shape) @skipIfTorchDynamo("Not a suitable test for TorchDynamo") - @parametrize("func", [torch.ones_like, torch.zeros_like], - name_fn=lambda f: f.__name__) + @parametrize( + "func", [torch.ones_like, torch.zeros_like], name_fn=lambda f: f.__name__ + ) def test_like_value(self, func): - nt = random_nt_from_dims([2, None, 3], torch.device('cpu'), torch.float32, layout=torch.jagged) + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) nt_like = func(nt) for nt_ub in nt_like.unbind(): @@ -3938,8 +4650,13 @@ def check_nt_equality(x, y): def test_to_copy(self, device): nt = torch.nested.nested_tensor( - [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) - for i in range(3)], layout=torch.jagged + [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ], + layout=torch.jagged, ) nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16) @@ -3966,16 +4683,20 @@ def test_profiler_sequence_nr(self): fwd_seq_nrs = [] for evt in prof.events(): - if "linear" in evt.name.lower() and "backward" not in evt.name.lower() and evt.sequence_nr != -1: + if ( + "linear" in evt.name.lower() + and "backward" not in evt.name.lower() + and evt.sequence_nr != -1 + ): fwd_seq_nrs.append(evt.sequence_nr) bwd_seq_nrs = [] for evt in prof.events(): if ( - "linear" in evt.name.lower() and - "backward" in evt.name.lower() and - "evaluate_function" not in evt.name.lower() and - evt.sequence_nr != -1 + "linear" in evt.name.lower() + and "backward" in evt.name.lower() + and "evaluate_function" not in evt.name.lower() + and evt.sequence_nr != -1 ): bwd_seq_nrs.append(evt.sequence_nr) @@ -3990,7 +4711,12 @@ def test_profiler_sequence_nr(self): def test_is_same_size(self, device): def get_3_tensors(): - return [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)] + return [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ] nt1, offsets1 = jagged_from_list(get_3_tensors(), None) nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) @@ -4008,6 +4734,69 @@ def check_size(nt1, nt2, nt3, nt4): nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) check_size(nt1_t, nt2_t, nt3_t, nt4_t) + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape(self, device): + values = torch.randn((18, 16), device=device) + offsets = 
torch.tensor([0, 2, 3, 6, 15, 18], device=device) + like_values = torch.randn_like(values) + + # this marks values as dynamic + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(values, same_size): + # here, the dynamic shape is specialized by same_size's shape + # https://github.com/pytorch/pytorch/issues/127097 + # make sure this doesn't error out in torch.compile + return values + same_size + + self.assertEqual( + fn(values, like_values), + torch.compile(fn)(values, like_values), + ) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape_recompile(self, device): + def generate_inp(total_len): + values = torch.randn((total_len, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device) + like_values = torch.randn_like(values) + return values, offsets, like_values + + def check_results(ref_fn, res_fn, args): + values, offsets, like_values = args + # this may add dynamic shape markings + # goal of this test is to make sure that whatever markings are there, + # we eventually stop recompiling as shape changes. + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + self.assertEqual( + ref_fn(values, like_values), + res_fn(values, like_values), + ) + + def fn(values, same_size): + return values + same_size + + compile_counter = torch._dynamo.testing.CompileCounter() + + compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn) + check_results(fn, compiled_fn, generate_inp(18)) + self.assertEqual(compile_counter.frame_count, 1) + + check_results(fn, compiled_fn, generate_inp(19)) + # we'll probably recompile here with dynamic shapes - it's okay if not though. + frame_count_2 = compile_counter.frame_count + self.assertIn(frame_count_2, [1, 2]) + + # make sure that by now we've already compiled with dynamic shapes, so additional + # shapes should not trigger additional recompiles. 
+ check_results(fn, compiled_fn, generate_inp(20)) + self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA @@ -4016,8 +4805,12 @@ def check_size(nt1, nt2, nt3, nt4): TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32], + ) def test_sdpa(self, device, dtype): batch_size = 1 emb_dims = 128 @@ -4027,27 +4820,63 @@ def test_sdpa(self, device, dtype): sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) - query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) # Simplest case: 1 sentence, no batching x_d1 = sen1.unsqueeze(0) x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) # See note below for why we detach here. - q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + q_d1 = ( + query(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_d1_t = q_d1.transpose(1, 2) - k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + k_d1 = ( + key(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_d1_t = k_d1.transpose(1, 2) - v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + v_d1 = ( + value(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_d1_t = v_d1.transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_nt_t = q_nt.transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_nt_t = k_nt.transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_nt_t = v_nt.transpose(1, 2) # High Precision Math Reference @@ -4057,11 +4886,15 @@ def test_sdpa(self, device, dtype): q_d1_f32_t = q_d1_f32.transpose(1, 2) k_d1_f32_t = k_d1_f32.transpose(1, 2) v_d1_f32_t = v_d1_f32.transpose(1, 2) - out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32_t, k_d1_f32_t, v_d1_f32_t)[0] + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32_t, k_d1_f32_t, v_d1_f32_t + )[0] grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) # Low Precision Math Reference - out_lp_ref = 
torch.ops.aten._scaled_dot_product_attention_math(q_d1_t, k_d1_t, v_d1_t)[0] + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_t, k_d1_t, v_d1_t + )[0] grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) # Compute tolerances @@ -4072,10 +4905,19 @@ def test_sdpa(self, device, dtype): grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] - attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1_t, k_d1_t, v_d1_t).transpose(1, 2) - attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1_t, k_d1_t, v_d1_t + ).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) - self.assertEqual(attn_d1, attn_nt.unbind()[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nt.unbind()[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) # Simple case: 2 sentences, no extra params x_d2 = sen2.unsqueeze(0) @@ -4084,46 +4926,106 @@ def test_sdpa(self, device, dtype): # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before # it is transposed. This is because today we cannot backward through view or unbind a # transposed tensor. - q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + q_d2 = ( + query(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_d2_t = q_d2.transpose(1, 2) - k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + k_d2 = ( + key(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_d2_t = k_d2.transpose(1, 2) - v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) + v_d2 = ( + value(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_d2_t = v_d2.transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) q_nt_t = q_nt.transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) k_nt_t = k_nt.transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) v_nt_t = v_nt.transpose(1, 2) - attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2_t, k_d2_t, v_d2_t).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2_t, k_d2_t, v_d2_t + ).transpose(1, 2) d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) def check_forward_backward(): - attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) attn_nts = attn_nt.unbind() - self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(attn_d2, 
attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt)) - for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols): + for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( + nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols + ): unbound_nt_grads = nt_grad.unbind() - self.assertEqual(d1_grad, unbound_nt_grads[0].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) - self.assertEqual(d2_grad, unbound_nt_grads[1].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) + self.assertEqual( + d1_grad, + unbound_nt_grads[0].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) + self.assertEqual( + d2_grad, + unbound_nt_grads[1].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) # Default check_forward_backward() # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=True, enable_math=True + ): check_forward_backward() # Test math fallback - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Math fallback doesn't work with bfloat16 on CUDA because # "group_gemm_dispatch" not implemented for 'BFloat16' if not (str(device).startswith("cuda") and dtype == torch.bfloat16): @@ -4135,8 +5037,13 @@ def check_forward_backward(): # Guarding with sqrt() doesn't work on ROCm? 
@skipCUDAIfRocm @onlyCUDA - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) def test_sdpa_compile(self, device, dtype): batch_size = 1 emb_dims = 1024 @@ -4146,9 +5053,15 @@ def test_sdpa_compile(self, device, dtype): sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) - query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) - value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) # Simplest case: 1 sentence, no batching x_d1 = sen1.unsqueeze(0) @@ -4162,28 +5075,61 @@ def test_sdpa_compile(self, device, dtype): k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) - q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) - k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) - v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().transpose(1, 2) + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) # High Precision Math Reference q_d1_f32 = q_d1.to(torch.float32) k_d1_f32 = k_d1.to(torch.float32) v_d1_f32 = v_d1.to(torch.float32) - out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0] + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32, k_d1_f32, v_d1_f32 + )[0] # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0] + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1, k_d1, v_d1 + )[0] output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) - attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2) - attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2) + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1, k_d1, v_d1 + ).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2, k_d2, v_d2 + ).transpose(1, 2) compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) attn_nts = attn_nt.unbind() - self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) @dtypes(torch.float32, 
torch.double, torch.half) def test_sdpa_with_constant_sequence_length(self, device, dtype): @@ -4193,14 +5139,17 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): # S: (constant) sequence length # D: embedding size query = random_nt_from_dims( - [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged) + [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged + ) key = random_nt_from_similar(query) value = random_nt_from_similar(query) output = F.scaled_dot_product_attention(query, key, value) self.assertTrue(isinstance(output, NestedTensor)) # should be equivalent to just running the buffers through - output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values) + output_dense = F.scaled_dot_product_attention( + query._values, key._values, value._values + ) self.assertEqual(output._values, output_dense) # Doesn't work until we have real views @@ -4208,20 +5157,28 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, - "Platform doesn't support flash or mem-efficient attention" + "Platform doesn't support flash or mem-efficient attention", + ) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) ) - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) def test_sdpa_with_packed_in_proj(self, device, dtype): # shape (B, *, D) input_packed = random_nt_from_dims( - [5, None, 10], device=device, dtype=dtype, layout=torch.jagged) + [5, None, 10], device=device, dtype=dtype, layout=torch.jagged + ) # Do input projection. num_heads = 2 # should be multiple of 4 for efficient kernels (e.g. 
flash / mem-efficient) head_dim = 8 - qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(device=device, dtype=dtype) + qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to( + device=device, dtype=dtype + ) def in_proj(input_packed, qkv_linear=qkv_linear): qkv_post_proj = qkv_linear(input_packed) @@ -4237,18 +5194,22 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # compare to individually running unbound components through for in_component, out_component in zip( - input_packed.unbind(), - output.transpose(-2, -3).unbind() + input_packed.unbind(), output.transpose(-2, -3).unbind() ): q, k, v = in_proj(in_component) out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3) # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( - q, k, v)[0].transpose(-2, -3) - output_ref_atol, output_ref_rtol = get_tolerances(out, out_lp_ref, fudge_factor=2) + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[ + 0 + ].transpose(-2, -3) + output_ref_atol, output_ref_rtol = get_tolerances( + out, out_lp_ref, fudge_factor=2 + ) - self.assertEqual(out, out_component, atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual( + out, out_component, atol=output_ref_atol, rtol=output_ref_rtol + ) @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @@ -4256,8 +5217,13 @@ def in_proj(input_packed, qkv_linear=qkv_linear): # mha_varlen_fwd not supported on ROCm @skipCUDAIfRocm @onlyCUDA - @dtypes(*([torch.float16, torch.bfloat16, torch.float32] if SM80OrLater - else [torch.float16, torch.float32])) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) def test_sdpa_backwards(self, device, dtype): values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype) offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64) @@ -4299,7 +5265,6 @@ def __init__(self): self.linear = torch.nn.Linear(d2, d3, device=device) def forward(self, query, value, offsets): - value = self.linear(value) key = convert_jagged_to_nested_tensor(value, offsets, max_length_1) value = convert_jagged_to_nested_tensor(value, offsets, max_length_2) @@ -4348,11 +5313,37 @@ def forward(self, query, value, offsets): self.assertTrue(torch.allclose(attn_output_eager, attn_output)) self.assertTrue(torch.allclose(value_grad, value.grad)) + @dtypes(torch.float64, torch.float32, torch.half) + @onlyCUDA + def test_fbgemm_jagged_to_padded_dense_kernels(self, device, dtype): + values = torch.randn(10, 5, device=device, dtype=dtype) + offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64) + max_length = offsets.diff().max().item() + padding_value = 1.3 + + # convert jagged -> padded dense + padded = torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets], [max_length], padding_value + ) + + batch_size = offsets.shape[0] - 1 + expected_padded_shape = (batch_size, max_length, values.shape[-1]) + self.assertEqual(padded.shape, expected_padded_shape) + + # convert padded dense -> jagged + total_L = values.shape[0] + output_jagged = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], total_L + ) + + # should be equivalent to the original values + self.assertEqual(values, output_jagged) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) 
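For context, the call pattern exercised by the new test_fbgemm_jagged_to_padded_dense_kernels test above can be summarized as the following minimal round-trip sketch; the concrete values, offsets, and the CUDA device are illustrative assumptions taken from the test, not part of the patch itself.

import torch

# Four variable-length sequences packed into a single (10, 5) values buffer;
# offsets mark where each sequence starts, so the lengths are 1, 2, 5 and 2.
values = torch.randn(10, 5, device="cuda")
offsets = torch.tensor([0, 1, 3, 8, 10], device="cuda", dtype=torch.int64)
max_len = offsets.diff().max().item()  # longest sequence: 5

# jagged -> padded dense: result has shape (batch, max_len, feature) = (4, 5, 5),
# with positions past each sequence's length filled with the padding value.
padded = torch.ops.aten._jagged_to_padded_dense_forward(values, [offsets], [max_len], 0.0)

# padded dense -> jagged: drops the padding again and recovers the packed values.
total_L = values.shape[0]
restored = torch.ops.aten._padded_dense_to_jagged_forward(padded, [offsets], total_L)
assert torch.equal(restored, values)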
instantiate_device_type_tests(TestNestedTensorAutograd, globals()) instantiate_device_type_tests(TestNestedTensorSubclass, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index d49c9bc1eec4..2553db01ee6b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1063,6 +1063,7 @@ def check(): self.assertRaises(NotImplementedError, module_dict) self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3)) + @skipIfTorchDynamo() def test_ParameterList(self): def make_param(): return Parameter(torch.randn(2, 2)) @@ -1594,19 +1595,29 @@ def add_one_inplace(t): finally: torch.__future__.set_overwrite_module_params_on_conversion(False) - def test_swap_module_params_fails_after_forward(self): + def test_swap_module_params_poisons_acc_grad(self): try: torch.__future__.set_swap_module_params_on_conversion(True) + # (1) backward cannot be run after _apply + # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors + # additionally, if any Tensors are saved for backward, their use_count will be bumped m = torch.nn.Linear(2, 3) inp = torch.randn(2, 2) - # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors out = m(inp) - with self.assertRaisesRegex(RuntimeError, re.escape("_apply(): Couldn't swap Linear.weight")): - m.half() - del out - # works as expected now m.half() self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters())) + with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"): + out.sum().backward() + # (2) _apply can be run after backward() + # After running backward, all the references generated by "save for backward" will be cleared + # So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors + # should allow this. + inp2 = torch.randn(2, 2, dtype=torch.half) + out2 = m(inp2) + out2.sum().backward() + m.float() + self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters())) + out3 = m(inp) finally: torch.__future__.set_swap_module_params_on_conversion(False) @@ -2271,58 +2282,6 @@ def test_state_dict(self): # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545 self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata") - def _test_register_state_dict_pre_hook(self, model, submodule): - _state_dict_prefix = "foo." - state_dict_pre_hook_count = 0 - keep_var_setting = False - - def my_state_dict_pre_hook(module, prefix, keep_vars): - self.assertEqual(keep_vars, keep_var_setting) - nonlocal state_dict_pre_hook_count - state_dict_pre_hook_count += 1 - self.assertTrue(prefix.startswith(_state_dict_prefix)) - - model.register_state_dict_pre_hook(my_state_dict_pre_hook) - # Test to ensure submodules run the hook as well. 
- submodule.register_state_dict_pre_hook(my_state_dict_pre_hook) - - def check_results(model): - nonlocal state_dict_pre_hook_count, keep_var_setting - for keep_var_setting in [True, False]: - _ = model.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting) - self.assertEqual(2, state_dict_pre_hook_count) - state_dict_pre_hook_count = 0 - # Test state dict works as expected after model construction - check_results(model) - # Test state dict works as expected after forward - model(torch.ones(10, 3)) - check_results(model) - - def test_register_state_dict_pre_hook(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)) - - def forward(self, x): - return self.a(x) - - mod = MyModule() - self._test_register_state_dict_pre_hook(mod, mod.a) - - def test_register_state_dict_pre_hook_lazy_module(self): - class MyLazyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = nn.LazyLinear(8) - self.layer2 = nn.LazyLinear(5) - - def forward(self, x): - return self.layer2(self.layer1(x)) - - mod = MyLazyModule() - self._test_register_state_dict_pre_hook(mod, mod.layer1) - def test_extra_state(self): class SubModule(torch.nn.Module): @@ -8936,7 +8895,9 @@ def test_linear_empty(self, device): _test_module_empty_input(self, mod, inp) def test_one_hot(self, device): - if self.device_type != 'cuda': # cuda throws device assert for invalid data + # cuda throws device assert for invalid data + # xla ignores out of bound indices + if self.device_type != 'cuda' and self.device_type != 'xla': with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 940938c79dde..bf81701f37dc 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -476,13 +476,18 @@ def test_multiplication_numpy_scalar(self, device) -> None: self.assertTrue(r2.requires_grad) @onlyCPU - def test_parse_numpy_int(self, device): + @skipIfTorchDynamo() + def test_parse_numpy_int_overflow(self, device): + # assertRaises uses a try-except which dynamo has issues with # Only concrete class can be given where "Type[number[_64Bit]]" is expected self.assertRaisesRegex( RuntimeError, "(Overflow|an integer is required)", lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)), ) # type: ignore[call-overload] + + @onlyCPU + def test_parse_numpy_int(self, device): # https://github.com/pytorch/pytorch/issues/29252 for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]: scalar = 3 diff --git a/test/test_ops.py b/test/test_ops.py index 44f503ae9b6e..cbec88136ed2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2522,8 +2522,8 @@ def map_to_fake(e): or name in sometimes_dynamic_output_op_test ) self.assertTrue( - mode.shape_env is None - or not mode.shape_env.allow_dynamic_output_shape_ops + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops or name not in supported_dynamic_output_op_tests ) except torch._subclasses.fake_tensor.DataDependentOutputException: diff --git a/test/test_optim.py b/test/test_optim.py index 9e3ee50ff302..d61c33e2adce 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -20,7 +20,7 @@ register_optimizer_step_post_hook, register_optimizer_step_pre_hook, ) -from torch.testing._internal.common_cuda import _create_scaling_case, TEST_MULTIGPU +from torch.testing._internal.common_cuda import 
TEST_MULTIGPU from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, largeTensorTest, @@ -287,7 +287,7 @@ def test_param_group_with_lrscheduler_goes_right_direction( inpt = torch.randn(5, device=device, dtype=dtype) # avoid endless recompiles by wrapping LR in a tensor if we're compiling - lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01 + lr = torch.tensor(0.01) if torch._utils.is_compiling() else 0.01 optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}]) schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c] @@ -1954,100 +1954,6 @@ def test_fused_cpu_matches_cuda(self, device, dtype, optim_info): optimizers.append(optimizer) self._compare_between(inpts, models, optimizers) - @onlyNativeDeviceTypes - @optims( - [optim for optim in optim_db if "fused" in optim.supported_impls], - dtypes=[torch.float32], - ) - def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): - # This ut is from test_cuda.py test_grad_scaling_autocast_fused_optimizers - # but only test Adam/AdamW on CPU - # TODO: haozhe, support SGD and unified this ut with the CUDA only one - if device not in optim_info.supports_fused_on: - self.skipTest( - f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" - ) - optim_inputs = optim_info.optim_inputs_func(device=device) - optim_cls = optim_info.optim_cls - for optim_input in optim_inputs: - kwargs = optim_input.kwargs - kwargs["fused"] = True - for _separate_unscale in (True, False): - self._grad_scaling_autocast_fused_optimizers( - device=device, - optimizer_ctor=optim_cls, - optimizer_kwargs=kwargs, - separate_unscale=_separate_unscale, - ) - - def _grad_scaling_autocast_fused_optimizers( - self, device, optimizer_ctor, optimizer_kwargs, separate_unscale - ): - torch.manual_seed(20) - ( - mod_control, - mod_scaling, - opt_control, - opt_scaling, - data, - loss_fn, - _, - ) = _create_scaling_case( - optimizer_ctor=optimizer_ctor, - optimizer_kwargs=optimizer_kwargs, - device="cpu", - ) - kwargs = deepcopy(optimizer_kwargs) - kwargs["fused"] = False - if "lr" not in optimizer_kwargs: - # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr - kwargs["lr"] = 1.0 - opt_control = optimizer_ctor(mod_control.parameters(), **kwargs) - - scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) - scaler_control = torch.amp.GradScaler(device, init_scale=128.0) - tracker = TensorTracker() - for input, target in data: - opt_control.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output_control = mod_control(input) - loss_control = loss_fn(output_control, target) - scaler_control.scale(loss_control).backward() - scaler_control.step(opt_control) - scaler_control.update() - - opt_scaling.zero_grad() - with torch.autocast(device_type=device, dtype=torch.half): - output_scaling = mod_scaling(input) - loss_scaling = loss_fn(output_scaling, target) - scaler_scaling.scale(loss_scaling).backward() - if separate_unscale: - scaler_scaling.unscale_(opt_scaling) - scaler_scaling.step(opt_scaling) - scaler_scaling.update() - - tracker.add(loss_control) - tracker.pop_check_set(loss_scaling, self) - for param_control, param_scaling in zip( - mod_control.parameters(), mod_scaling.parameters() - ): - tracker.add(param_control.grad) - tracker.pop_check_set(param_scaling.grad, self) - tracker.add(param_control) - tracker.pop_check_set(param_scaling, self) - - state_control, state_scaling = ( - opt_control.state[param_control], 
- opt_scaling.state[param_scaling], - ) - - for k in state_control: - actual = state_scaling[k] - if k == "step": - actual = actual.squeeze() - tracker.add(state_control[k]) - tracker.pop_check_set(actual, self) - @onlyCUDA @optims( [o for o in optim_db if "foreach" in o.supported_impls], dtypes=[torch.float32] diff --git a/test/test_overrides.py b/test/test_overrides.py index cb46ca6ed880..a55688b95f31 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -8,8 +8,9 @@ import pickle import collections import unittest +import contextlib -from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF +from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO from torch.overrides import ( handle_torch_function, has_torch_function, @@ -377,6 +378,27 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs) class TestTorchFunctionOverride(TestCase): + @classmethod + def setUpClass(cls): + cls._stack = contextlib.ExitStack() + if TEST_WITH_TORCHDYNAMO: + # Add classes to the wrapped tensor subclasses + @contextlib.contextmanager + def setup_subclasses(): + old = set(torch._dynamo.config.traceable_tensor_subclasses) + torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor) + try: + yield + finally: + torch._dynamo.config.traceable_tensor_subclasses.clear() + torch._dynamo.config.traceable_tensor_subclasses.update(old) + + cls._stack.enter_context(setup_subclasses()) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + def test_mean_semantics(self): """Test that a function with one argument can be overrided""" t1 = DiagonalTensor(5, 2) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index d8aa8863d566..04483ffba0fc 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,7 +1618,8 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - pow_1 = sym_size_int ** 0.5; sym_size_int = None + sym_float = torch.sym_float(sym_size_int); sym_size_int = None + pow_1 = sym_float ** 0.5; sym_float = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") @@ -2003,7 +2004,6 @@ def f(t): xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition - xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but... 
@@ -2034,8 +2034,6 @@ def f(t): inplace_symbolic_tensor_failures = { # bugs xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double - # decomp not implemented - xfail('unique', ''), } out_symbolic_tensor_failures = { diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 1db0e5718ce6..8ab2ac1f511f 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -329,7 +329,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.fake_pg", "torch.testing._internal.distributed.multi_threaded_pg", "torch.testing._internal.distributed.nn.api.remote_module_test", - "torch.testing._internal.distributed.pipe_with_ddp_test", "torch.testing._internal.distributed.rpc.dist_autograd_test", "torch.testing._internal.distributed.rpc.dist_optimizer_test", "torch.testing._internal.distributed.rpc.examples.parameter_server_test", @@ -408,7 +407,6 @@ def test_modules_can_be_imported(self): "torch.distributed.nn.api.remote_module", "torch.distributed.optim", "torch.distributed.optim.optimizer", - "torch.distributed.pipeline.sync", "torch.distributed.rendezvous", "torch.distributed.rpc.api", "torch.distributed.rpc.backend_registry", diff --git a/test/test_serialization.py b/test/test_serialization.py index e83cafd3f3d8..f22331831c39 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -18,7 +18,6 @@ from collections import OrderedDict from copy import deepcopy from itertools import product -from types import ModuleType from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor @@ -4111,23 +4110,6 @@ def __setstate__(self, state): class TestEmptySubclass(torch.Tensor): ... -# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them -# Cannot define locally in test or pickle will fail. -class TestEmptySubclassSpoof(TestEmptySubclass): - ... - -class TestWrapperSubclassSpoof(TestWrapperSubclass): - ... - -class RebuildFromTypeV2Spoof(torch.Tensor): - def __new__(cls, elem, naughty, **kwargs): - if naughty: - raise RuntimeError("naughty") - return super().__new__(cls, elem) - - def __reduce_ex__(self, protocol): - return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {})) - class TestSubclassSerialization(TestCase): def test_tensor_subclass_wrapper_serialization(self): @@ -4207,201 +4189,42 @@ def test_empty_class_serialization(self): f.seek(0) tensor2 = torch.load(f) - def _create_bad_func(self, name): - def bad_func(self, *args, **kwargs): - raise RuntimeError(f"running {name}") - return bad_func - - @parametrize("wrapper", (True, False)) - def test_tensor_subclass_method_spoofing(self, wrapper): - ''' - This tests seeks to do the following: - - determine which methods of a tensor subclass might be called during unpickling (weights_only=False) - we consider these methods "risky" for weights_only - - ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True) - - ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True) - - We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError - when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that - only the RuntimeErrors that we expect are thrown. 
- - We then load with weights_only and ensure that weights_only will fail unless all the risky methods - are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load. - The final weights_only load call when all the risky methods are no longer overriden. - ''' - subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof - t = subclass(torch.randn(2, 3)) - # To trigger setattr for the non-wrapper case - if not wrapper: - t.foo = 'bar' - inp = {'weight': t} - - with TemporaryFileName() as f: - torch.save(inp, f) - loaded = torch.load(f, weights_only=True) - self.assertEqual(loaded['weight'], inp['weight']) - - restore_methods = dict() - methods = [func for func in dir(subclass) if callable(getattr(subclass, func))] - for method in methods: - if method != "__class__": - restore_methods[method] = getattr(subclass, method) - setattr(subclass, method, self._create_bad_func(method)) - # These additional methods might be called during getattr or setattr - # but are not in methods above (not defined on tensor base class) - subclass.__get__ = self._create_bad_func("__get__") - subclass.__set__ = self._create_bad_func("__set__") - subclass.__getattr__ = self._create_bad_func("__getattr__") - restore_methods["__get__"] = None - restore_methods["__getattr__"] = None - restore_methods["__set__"] = None - - try: - # Check that weights_only=False load raises the RuntimeErrors we expect - with self.assertRaisesRegex(RuntimeError, "running __getattribute__"): - torch.load(f, weights_only=False) - subclass.__getattribute__ = restore_methods['__getattribute__'] - with self.assertRaisesRegex(RuntimeError, "running __setstate__"): - torch.load(f, weights_only=False) - subclass.__setstate__ = restore_methods['__setstate__'] - with self.assertRaisesRegex(RuntimeError, "running __setattr__"): - torch.load(f, weights_only=False) - subclass.__setattr__ = restore_methods['__setattr__'] - # should finally work - torch.load(f, weights_only=False) - - # Check that weights_only=True catches that risky methods are overriden - subclass.__setstate__ = self._create_bad_func("__setstate__") - subclass.__getattribute__ = self._create_bad_func("__getattribute__") - subclass.__setattr__ = self._create_bad_func("__setattr__") - with self.assertRaisesRegex(pickle.UnpicklingError, - "methods: __getattribute__=True __getattr__=True __get__=True " - "__setattr__=True __set__=True __setstate__=True"): - torch.load(f, weights_only=True) - risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__'] - for i, meth in enumerate(risky_methods): - setattr(subclass, meth, restore_methods[meth]) - if i != len(risky_methods) - 1: - # When the given methods are not all back to default, load should still throw - # but reflect which methods are no longer overriden - with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"): - torch.load(f, weights_only=True) - else: - # When the given methods are all back to default, weights_only load should finally work - loaded = torch.load(f, weights_only=True) - finally: - for method, func in restore_methods.items(): - setattr(subclass, method, func) - a = subclass(torch.randn(2, 3)) - @skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined") def test_safe_globals_for_weights_only(self): ''' Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs ''' - # Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment - global TwoTensor t = 
TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) p = torch.nn.Parameter(t) sd = OrderedDict([('t', t), ('p', p)]) with tempfile.NamedTemporaryFile() as f: torch.save(sd, f) - # unimport TwoTensor - try: - del sys.modules['torch.testing._internal.two_tensor'] - - # Loading tensor subclass with weights_only=True should fail - # if tensor subclass has not been imported - with self.assertRaisesRegex(pickle.UnpicklingError, - "expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"): - f.seek(0) - sd = torch.load(f, weights_only=True) - # Loading tensor subclass with weights_only=True should work - # if target methods are not overriden and user has imported the subclass - from torch.testing._internal.two_tensor import TwoTensor + # Loading tensor subclass with weights_only=True should fail + # since tensor subclass is not in safe_globals + with self.assertRaisesRegex(pickle.UnpicklingError, + "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"): f.seek(0) sd = torch.load(f, weights_only=True) + + # Loading tensor subclass should work if the class is marked safe + f.seek(0) + try: + torch.serialization.add_safe_globals([TwoTensor]) + self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) + sd = torch.load(f, weights_only=True) self.assertEqual(sd['t'], t) self.assertEqual(sd['p'], p) - # Loading tensor subclass with weights_only=True should fail - # if __setstate__ is overriden + # Should fail again when safe globals are cleared + torch.serialization.clear_safe_globals() f.seek(0) - restore_setstate = TwoTensor.__setstate__ - try: - TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state) - with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): - torch.load(f, weights_only=True) - - # Loading tensor subclass with overriden __setstate__ with weights_only=True should work - # if the class is marked safe - f.seek(0) - torch.serialization.add_safe_globals([TwoTensor]) - self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor]) - sd = torch.load(f, weights_only=True) - self.assertEqual(sd['t'], t) - self.assertEqual(sd['p'], p) - - # Should fail again when safe globals are cleared - torch.serialization.clear_safe_globals() - f.seek(0) - with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"): - torch.load(f, weights_only=True) - finally: - TwoTensor.__setstate__ = restore_setstate + with self.assertRaisesRegex(pickle.UnpicklingError, + "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"): + torch.load(f, weights_only=True) finally: - from torch.testing._internal.two_tensor import TwoTensor - - - def test_tensor_subclass_parent_module_method_spoofing(self): - ''' - Tests that weights_only load does not call any methods of the parent module - that contains the tensor subclass. - - We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError - when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that - no RuntimeErrors are thrown. 
- ''' - # Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass` - class SpoofModule(ModuleType): - pass - - spoof_mod = SpoofModule('bla') - spoof_mod.TestEmptySubclass = TestEmptySubclass - inp = {'weight': TestEmptySubclass(torch.randn(2, 3))} - TestEmptySubclass.__module__ = 'spoof_mod' - sys.modules['spoof_mod'] = spoof_mod - - try: - with TemporaryFileName() as f: - torch.save(inp, f) - torch.load(f, weights_only=True) - restore_methods = dict() - methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))] - for method in methods: - if method != "__class__": - restore_methods[method] = getattr(SpoofModule, method) - setattr(SpoofModule, method, self._create_bad_func(method)) - SpoofModule.__get__ = self._create_bad_func("__get__") - SpoofModule.__getattr__ = self._create_bad_func("__getattr__") - loaded = torch.load(f, weights_only=True) - self.assertEqual(loaded['weight'], inp['weight']) - finally: - TestEmptySubclass.__module__ = __name__ - del sys.modules['spoof_mod'] - - def test_rebuild_from_type_v2_spoof(self): - t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False) - inp = {'weight': t} - - with TemporaryFileName() as f: - torch.save(inp, f) - # subclass will be pushed onto unpickler's stack as a string - # and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2 - with self.assertRaisesRegex(TypeError, "'str' object is not callable"): - loaded = torch.load(f, weights_only=True) + torch.serialization.clear_safe_globals() @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda") def test_tensor_subclass_map_location(self): diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 47acfff9c6d4..2353d6841bbb 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -1,22 +1,41 @@ # Owner(s): ["module: tests"] -import torch -import numpy as np - -from itertools import product, combinations, permutations, chain -from functools import partial import random -import warnings import unittest +import warnings +from functools import partial + +from itertools import chain, combinations, permutations, product + +import numpy as np + +import torch from torch import nan from torch.testing import make_tensor -from torch.testing._internal.common_utils import ( - TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes, - dtypesIfCUDA, largeTensorTest) -from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and + dtypes, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + all_types_and_complex_and, +) +from torch.testing._internal.common_utils import ( + IS_JETSON, + run_tests, + skipIfTorchDynamo, + TEST_PRIVATEUSE1_DEVICE_TYPE, + TestCase, + torch_to_numpy_dtype_dict, +) + # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -29,17 +48,19 @@ def _generate_input(shape, dtype, device, with_extremal): x = torch.randn(*shape, device=device) * random.randint(30, 100) x = x.to(torch.bfloat16) else: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( + 30, 100 + ) 
x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values - x[torch.randn(*shape) > 0.5] = float('nan') - x[torch.randn(*shape) > 0.5] = float('inf') - x[torch.randn(*shape) > 0.5] = float('-inf') + x[torch.randn(*shape) > 0.5] = float("nan") + x[torch.randn(*shape) > 0.5] = float("inf") + x[torch.randn(*shape) > 0.5] = float("-inf") elif with_extremal and dtype.is_complex: - x[torch.randn(*shape) > 0.5] = complex('nan') - x[torch.randn(*shape) > 0.5] = complex('inf') - x[torch.randn(*shape) > 0.5] = complex('-inf') + x[torch.randn(*shape) > 0.5] = complex("nan") + x[torch.randn(*shape) > 0.5] = complex("inf") + x[torch.randn(*shape) > 0.5] = complex("-inf") elif dtype == torch.bool: x = torch.zeros(shape, dtype=dtype, device=device) x[torch.randn(*shape) > 0.5] = True @@ -48,8 +69,8 @@ def _generate_input(shape, dtype, device, with_extremal): return x -class TestShapeOps(TestCase): +class TestShapeOps(TestCase): # TODO: update to work on CUDA, too @onlyCPU def test_unbind(self, device): @@ -71,7 +92,7 @@ def test_tolist(self, device): tensor0D = torch.tensor(list0D) self.assertEqual(tensor0D.tolist(), list0D) - table1D = [1., 2., 3.] + table1D = [1.0, 2.0, 3.0] tensor1D = torch.tensor(table1D) storage = torch.Storage(table1D) self.assertEqual(tensor1D.tolist(), table1D) @@ -102,19 +123,29 @@ def test_movedim_invalid(self, device, dtype): fn(x, 0, 5) # Mismatch in size of `source` and `destination` - with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"): - fn(x, (1, 0), (0, )) - - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: Invalid source or destination dims:" + ): + fn(x, (1, 0), (0,)) + + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 0), (0, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `source`" + ): fn(x, (0, 1, 0), (0, 1, 2)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1), (1, 1)) - with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"): + with self.assertRaisesRegex( + RuntimeError, "movedim: repeated dim in `destination`" + ): fn(x, (0, 1, 2), (1, 0, 1)) @dtypes(torch.int64, torch.float, torch.complex128) @@ -137,8 +168,12 @@ def test_movedim(self, device, dtype): # Integer `source` and `destination` torch_fn = partial(fn, source=src_dim, destination=dst_dim) - np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + np_fn = partial( + np.moveaxis, source=src_dim, destination=dst_dim + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) if nd == 0: continue @@ -148,9 +183,13 @@ def make_index_negative(sequence, idx): sequence[random_idx] = sequence[random_idx] - nd return tuple(src_sequence) - for src_sequence in permutations(range(nd), r=random.randint(1, nd)): + for src_sequence in permutations( + range(nd), r=random.randint(1, nd) + ): # Sequence `source` and `destination` - dst_sequence = tuple(random.sample(range(nd), len(src_sequence))) + dst_sequence = tuple( + random.sample(range(nd), len(src_sequence)) + ) # Randomly change a dim to a negative dim representation of 
itself. random_prob = random.random() @@ -166,9 +205,15 @@ def make_index_negative(sequence, idx): random_idx = random.randint(0, len(src_sequence) - 1) src_sequence = make_index_negative(src_sequence, random_idx) - torch_fn = partial(fn, source=src_sequence, destination=dst_sequence) - np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence) - self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) + torch_fn = partial( + fn, source=src_sequence, destination=dst_sequence + ) + np_fn = partial( + np.moveaxis, source=src_sequence, destination=dst_sequence + ) + self.compare_with_numpy( + torch_fn, np_fn, x, device=None, dtype=None + ) # Move dim to same position x = torch.randn(2, 3, 5, 7, 11) @@ -213,10 +258,7 @@ def test_diagonal(self, device): def test_diagonal_multidim(self, device, dtype): x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) xn = x.numpy() - for args in [(2, 2, 3), - (2,), - (-2, 1, 2), - (0, -2, -1)]: + for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]: result = torch.diagonal(x, *args) expected = xn.diagonal(*args) self.assertEqual(expected.shape, result.shape) @@ -270,14 +312,22 @@ def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nan max_vals = max_vals.cpu().numpy() # Use NumPy implementation as reference - X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device) + X_clamped = torch.tensor( + np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device + ) return X, X_clamped # Tests clamp and its alias, clip @dtypes(torch.int64, torch.float32) def test_clamp(self, device, dtype): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -287,10 +337,9 @@ def test_clamp(self, device, dtype): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, dtype, - min_vals=min_val, - max_vals=max_val, - with_nans=False) + X, Y_expected = self.generate_clamp_baseline( + device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False + ) # Test op X1 = X.clone() # So that the in-place ops do not change X @@ -304,8 +353,14 @@ def test_clamp(self, device, dtype): self.assertEqual(Y_expected, Y_out) def test_clamp_propagates_nans(self, device): - op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_, - torch.clip, torch.Tensor.clip, torch.Tensor.clip_) + op_list = ( + torch.clamp, + torch.Tensor.clamp, + torch.Tensor.clamp_, + torch.clip, + torch.Tensor.clip, + torch.Tensor.clip_, + ) # min/max argument product args = product((-10, None), (10, None)) @@ -315,10 +370,13 @@ def test_clamp_propagates_nans(self, device): if min_val is None and max_val is None: continue - X, Y_expected = self.generate_clamp_baseline(device, torch.float, - min_vals=min_val, - max_vals=max_val, - with_nans=True) + X, Y_expected = self.generate_clamp_baseline( + device, + torch.float, + min_vals=min_val, + max_vals=max_val, + with_nans=True, + ) Y_expected = torch.isnan(Y_expected) # Test op @@ -334,7 +392,7 @@ def test_clamp_propagates_nans(self, device): def test_clamp_raises_arg_errors(self, device): X = torch.randn(100, dtype=torch.float, device=device) - error_msg = 'At least one of \'min\' or \'max\' must not be None' + error_msg = "At least one of 'min' 
or 'max' must not be None" with self.assertRaisesRegex(RuntimeError, error_msg): X.clamp() with self.assertRaisesRegex(RuntimeError, error_msg): @@ -369,18 +427,22 @@ def all_t(): self.assertEqual(in_t.flip(p_dims), out_t) if len(p_dims) > 0: # Wrap 1st dim - self.assertEqual(in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t) + self.assertEqual( + in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t + ) def gen_data(): # Basic tests data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2) nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data) - dims_result = ((0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), - (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), - (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), - ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), - ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2))) + dims_result = ( + (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), + (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), + (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), + ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), + ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)), + ) for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result): yield in_tensor, dims, out_tensor @@ -393,7 +455,9 @@ def gen_data(): yield in_t, 1, in_t # Transposed - in_t = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + in_t = ( + make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) + ) dims = (0, 1, 2) out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2) yield in_t, dims, out_t @@ -411,7 +475,9 @@ def gen_data(): if device == "cpu" and dtype != torch.bfloat16: for mf in [torch.contiguous_format, torch.channels_last]: for c in [2, 3, 8, 16]: - in_t = make_from_size((2, c, 32, 32)).contiguous(memory_format=mf) + in_t = make_from_size((2, c, 32, 32)).contiguous( + memory_format=mf + ) np_in_t = in_t.numpy() np_out_t = np_in_t[:, :, :, ::-1].copy() @@ -464,7 +530,9 @@ def gen_data(): size = [2, 3, 4] data = make_from_size(size) possible_dims = range(len(size)) - test_dims = chain(combinations(possible_dims, 1), combinations(possible_dims, 2)) + test_dims = chain( + combinations(possible_dims, 1), combinations(possible_dims, 2) + ) for dims in test_dims: self.assertEqual(size, list(data.flip(dims).size())) @@ -483,7 +551,6 @@ def test_flip_errors(self, device, dtype): self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) self.assertRaises(IndexError, lambda: data.flip(3)) - def _rand_shape(self, dim, min_size, max_size): return tuple(torch.randint(min_size, max_size + 1, (dim,))) @@ -504,8 +571,10 @@ def test_flip_numpy(self, device, dtype): self.compare_with_numpy(torch_fn, np_fn, data) @onlyCUDA # CPU is too slow - @largeTensorTest('17GB') # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB - @largeTensorTest("81GB", "cpu") # even for CUDA test, sufficient system memory is required + @largeTensorTest("17GB") # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB + @largeTensorTest( + "81GB", "cpu" + ) # even for CUDA test, sufficient system memory is required @unittest.skipIf(IS_JETSON, "Too large for Jetson") def test_flip_large_tensor(self, device): t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() @@ -569,7 +638,9 @@ def test_rot90(self, device): # test tensor with more than 2D data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 
8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual( + torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]) + ) self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) # test for errors @@ -601,7 +672,6 @@ def test_nonzero_no_warning(self, device): @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16)) def test_nonzero(self, device, dtype): - shapes = [ torch.Size((12,)), torch.Size((12, 1)), @@ -616,7 +686,9 @@ def gen_nontrivial_input(shape, dtype, device): return torch.randint(2, shape, device=device, dtype=dtype) else: # windows does not work for bfloat16 randing - return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) + return torch.randint(2, shape, device=device, dtype=torch.float).to( + dtype + ) for shape in shapes: tensor = gen_nontrivial_input(shape, dtype, device) @@ -624,20 +696,31 @@ def gen_nontrivial_input(shape, dtype, device): dst2 = tensor.nonzero(as_tuple=False) dst3 = torch.empty([], dtype=torch.long, device=device) torch.nonzero(tensor, out=dst3) - if self.device_type != 'xla': + if self.device_type != "xla": # xla does not raise runtime error self.assertRaisesRegex( RuntimeError, "scalar type Long", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.float, device=device) + ), ) - if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE: + if ( + self.device_type == "cuda" + or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE + ): self.assertRaisesRegex( RuntimeError, "on the same device", - lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long)) + lambda: torch.nonzero( + tensor, out=torch.empty([], dtype=torch.long) + ), ) - np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() + np_array = ( + tensor.cpu().numpy() + if dtype != torch.bfloat16 + else tensor.float().cpu().numpy() + ) np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) @@ -656,7 +739,9 @@ def test_nonzero_astuple_out(self, device): with self.assertRaises(RuntimeError): torch.nonzero(t, as_tuple=True, out=out) - self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) + self.assertEqual( + torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out) + ) # Verifies that JIT script cannot handle the as_tuple kwarg # See Issue https://github.com/pytorch/pytorch/issues/45499. 
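As background for the as_tuple checks in the reformatted test_nonzero_astuple_out hunk above, the two calling conventions of torch.nonzero differ only in output layout; a minimal sketch with made-up values:

import torch

t = torch.tensor([[0, 1], [2, 0]])

# Default (as_tuple=False): a single (num_nonzero, ndim) LongTensor of coordinates.
idx = torch.nonzero(t)  # tensor([[0, 1], [1, 0]])

# as_tuple=True: a tuple of 1-D index tensors, one per dimension,
# directly usable for advanced indexing.
rows, cols = torch.nonzero(t, as_tuple=True)
assert torch.equal(t[rows, cols], t[idx[:, 0], idx[:, 1]])  # both select tensor([1, 2])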
@@ -684,7 +769,9 @@ def _foo(t): def test_nonzero_discontiguous(self, device): shape = (4, 4) tensor = torch.randint(2, shape, device=device) - tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_( + tensor + ) dst1 = tensor.nonzero(as_tuple=False) dst2 = tensor_nc.nonzero(as_tuple=False) self.assertEqual(dst1, dst2, atol=0, rtol=0) @@ -695,7 +782,9 @@ def test_nonzero_discontiguous(self, device): self.assertEqual(data_ptr, dst3.data_ptr()) self.assertEqual(dst1, dst3, atol=0, rtol=0) # discontiguous out - dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] + dst4 = torch.empty( + dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device + )[:, ::2] data_ptr = dst4.data_ptr() strides = dst4.stride() torch.nonzero(tensor, out=dst4) @@ -710,7 +799,7 @@ def test_nonzero_non_diff(self, device): @dtypes(torch.int64, torch.float, torch.complex128) def test_sparse_dense_dim(self, device, dtype): - for shape in [(), (2, ), (2, 3)]: + for shape in [(), (2,), (2, 3)]: if dtype.is_complex or dtype.is_floating_point: x = torch.rand(shape, device=device, dtype=dtype) else: @@ -718,7 +807,8 @@ def test_sparse_dense_dim(self, device, dtype): self.assertEqual(x.sparse_dim(), 0) self.assertEqual(x.dense_dim(), len(shape)) + instantiate_device_type_tests(TestShapeOps, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_show_pickle.py b/test/test_show_pickle.py index 929584943007..48b459e12eac 100644 --- a/test/test_show_pickle.py +++ b/test/test_show_pickle.py @@ -1,15 +1,16 @@ # Owner(s): ["oncall: mobile"] -import unittest import io import tempfile +import unittest + import torch import torch.utils.show_pickle -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -class TestShowPickle(TestCase): +class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): @@ -26,11 +27,13 @@ def forward(self, x): torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() - torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf) + torch.utils.show_pickle.main( + ["", tmp.name + "@*/data.pkl"], output_stream=buf + ) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 7709131e6102..211c7b998608 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -1,44 +1,69 @@ # Owner(s): ["module: tests"] -import torch +import random +from itertools import permutations, product + import numpy as np -import random +import torch from torch import nan -from itertools import permutations, product from torch.testing import make_tensor -from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types -from torch.testing._internal.common_utils import \ - (TestCase, run_tests, slowTest, skipIfTorchDynamo) -from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, onlyNativeDeviceTypes, - onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest) +from torch.testing._internal.common_device_type import ( + 
dtypes, + dtypesIfCPU, + dtypesIfCUDA, + instantiate_device_type_tests, + largeTensorTest, + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, +) +from torch.testing._internal.common_dtype import ( + all_types, + all_types_and, + floating_types_and, + integral_types, +) +from torch.testing._internal.common_utils import ( + run_tests, + skipIfTorchDynamo, + slowTest, + TestCase, +) # TODO: remove this SIZE = 100 -class TestSortAndSelect(TestCase): +class TestSortAndSelect(TestCase): def assertIsOrdered(self, order, x, mxx, ixx, task): SIZE = x.size(1) - if order == 'descending': + if order == "descending": + def check_order(a, b): # `a != a` because we put NaNs # at the end of ascending sorted lists, # and the beginning of descending ones. return ((a != a) | (a >= b)).all().item() - elif order == 'ascending': + + elif order == "ascending": + def check_order(a, b): # see above return ((b != b) | (a <= b)).all().item() + else: - error(f'unknown order "{order}", must be "ascending" or "descending"') # noqa: F821 + error( # noqa: F821 + f'unknown order "{order}", must be "ascending" or "descending"' + ) are_ordered = True for k in range(1, SIZE): - self.assertTrue(check_order(mxx[:, k - 1], mxx[:, k]), - f'torch.sort ({order}) values unordered for {task}') + self.assertTrue( + check_order(mxx[:, k - 1], mxx[:, k]), + f"torch.sort ({order}) values unordered for {task}", + ) seen = set() indicesCorrect = True @@ -50,8 +75,11 @@ def check_order(a, b): for k in range(size0): seen.clear() for j in range(size): - self.assertEqual(x[k][ixx[k][j]], mxx[k][j], - msg=f'torch.sort ({order}) indices wrong for {task}') + self.assertEqual( + x[k][ixx[k][j]], + mxx[k][j], + msg=f"torch.sort ({order}) indices wrong for {task}", + ) seen.add(ixx[k][j]) self.assertEqual(len(seen), size) @@ -79,19 +107,22 @@ def test_sort(self, device): self.assertEqual(x.argsort(), res1ind) # Test sorting of random numbers - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random") # Test simple sort self.assertEqual( torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0], torch.tensor((10, 20, 30, 40, 50), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys x = torch.floor(torch.rand(4, SIZE, device=device) * 10) torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "ascending", x, res2val, res2ind, "random with duplicate keys" + ) # DESCENDING SORT x = torch.rand(4, SIZE, device=device) @@ -107,35 +138,41 @@ def test_sort(self, device): self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) # Test sorting of random numbers - self.assertIsOrdered('descending', x, res2val, res2ind, 'random') + self.assertIsOrdered("descending", x, res2val, res2ind, "random") # Test simple sort task self.assertEqual( - torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0], + torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[ + 0 + ], torch.tensor((50, 40, 30, 20, 10), device=device), - atol=0, rtol=0 + atol=0, + rtol=0, ) # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') + self.assertIsOrdered( + "descending", x, res2val, res2ind, "random with duplicate keys" + ) # Test argument sorting with and without stable x = torch.tensor([1, 10, 2, 2, 3, 7, 7, 8, 9, 9] * 3) - 
self.assertEqual(torch.argsort(x, stable=True), torch.sort(x, stable=True).indices) - self.assertEqual(torch.argsort(x, stable=False), torch.sort(x, stable=False).indices) + self.assertEqual( + torch.argsort(x, stable=True), torch.sort(x, stable=True).indices + ) + self.assertEqual( + torch.argsort(x, stable=False), torch.sort(x, stable=False).indices + ) self.assertEqual(torch.argsort(x), torch.sort(x).indices) - # Test sorting with NaNs x = torch.rand(4, SIZE, device=device) - x[1][2] = float('NaN') - x[3][0] = float('NaN') + x[1][2] = float("NaN") + x[3][0] = float("NaN") torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("ascending", x, res2val, res2ind, "random with NaNs") torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered('descending', x, res2val, res2ind, - 'random with NaNs') + self.assertIsOrdered("descending", x, res2val, res2ind, "random with NaNs") def test_sort_stable_none(self): # Called sort with stable=None used to trigger an assertion @@ -169,19 +206,19 @@ def test_stable_sort(self, device, dtype): _, idx = x.sort(stable=True) self.assertEqual( idx[:ncopies], - torch.arange(start=0, end=2 * ncopies, step=2, device=device) + torch.arange(start=0, end=2 * ncopies, step=2, device=device), ) self.assertEqual( idx[ncopies:], - torch.arange(start=1, end=2 * ncopies, step=2, device=device) + torch.arange(start=1, end=2 * ncopies, step=2, device=device), ) @onlyCUDA @dtypes(torch.uint8) - @largeTensorTest('200GB') # Unfortunately 80GB A100 is not large enough + @largeTensorTest("200GB") # Unfortunately 80GB A100 is not large enough def test_sort_large(self, device, dtype): t0 = torch.randperm(8192, device=device).to(dtype) - t = t0.view(1, 8192).expand(2 ** 18 + 1, -1).contiguous() + t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous() v, i = t.sort() del t iv, im = i.var_mean(dim=0) @@ -193,7 +230,6 @@ def test_sort_large(self, device, dtype): self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device)) self.assertEqual(im, t0.sort().indices) - @dtypes(torch.float32) def test_sort_restride(self, device, dtype): # Input: non-contiguous (stride: 5) 3-element array @@ -223,14 +259,24 @@ def _test_sort_discontiguous(self, device, dtype): n = t.size(dim) # assert ordered - self.assertTrue((r1.values.narrow(dim, 1, n - 1) >= r1.values.narrow(dim, 0, n - 1)).all()) + self.assertTrue( + ( + r1.values.narrow(dim, 1, n - 1) + >= r1.values.narrow(dim, 0, n - 1) + ).all() + ) # assert that different segments does not mix, which can easily happen # if the stride is not handled correctly - self.assertTrue((t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)).any(dim=dim).any(dim=-1).all()) + self.assertTrue( + (t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1)) + .any(dim=dim) + .any(dim=-1) + .all() + ) # assert stride is preserved - if self.device_type == 'cuda': + if self.device_type == "cuda": # FIXME: this behavior should be true for all cases, not # just the one specified in if condition self.assertEqual(r1.values.stride(), t.stride()) @@ -262,7 +308,9 @@ def test_sort_1d_output_discontiguous(self, device, dtype): @dtypes(*integral_types()) def test_sort_1d_parallel(self, device, dtype): low = 0 if dtype == torch.uint8 else -128 - tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype) + tensor = torch.randint( + low=low, high=127, size=(100000,), device=device, dtype=dtype + ) vals, _ = torch.sort(tensor, 
stable=True) self.assertEqual(True, torch.all(vals[:-1] <= vals[1:])) @@ -283,9 +331,9 @@ def test_topk_1d_output_discontiguous(self, device, dtype): @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_stable_sort_against_numpy(self, device, dtype): if dtype in floating_types_and(torch.float16, torch.bfloat16): - inf = float('inf') - neg_inf = -float('inf') - nan = float('nan') + inf = float("inf") + neg_inf = -float("inf") + nan = float("nan") else: if dtype != torch.bool: # no torch.iinfo support for torch.bool @@ -305,7 +353,7 @@ def generate_samples(): # binary strings yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0) - if self.device_type == 'cuda': + if self.device_type == "cuda": return yield (torch.tensor([0, 1] * 100, dtype=dtype, device=device), 0) @@ -326,13 +374,21 @@ def repeated_index_fill(t, dim, idxs, vals): # for each dimension. n_fill_vals = 3 # cardinality of (inf, neg_inf, nan) for dim in range(len(sizes)): - idxs = (torch.randint(high=size, size=(size // 10,)) for i in range(n_fill_vals)) + idxs = ( + torch.randint(high=size, size=(size // 10,)) + for i in range(n_fill_vals) + ) vals = (inf, neg_inf, nan) - subsets = chain.from_iterable(combinations(list(zip(idxs, vals)), r) - for r in range(1, n_fill_vals + 1)) + subsets = chain.from_iterable( + combinations(list(zip(idxs, vals)), r) + for r in range(1, n_fill_vals + 1) + ) for subset in subsets: idxs_subset, vals_subset = zip(*subset) - yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim) + yield ( + repeated_index_fill(x, dim, idxs_subset, vals_subset), + dim, + ) for sample, dim in generate_samples(): _, idx_torch = sample.sort(dim=dim, stable=True) @@ -340,7 +396,7 @@ def repeated_index_fill(t, dim, idxs, vals): sample_numpy = sample.float().cpu().numpy() else: sample_numpy = sample.cpu().numpy() - idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable') + idx_numpy = np.argsort(sample_numpy, axis=dim, kind="stable") self.assertEqual(idx_torch, idx_numpy) @dtypes(*all_types_and(torch.half, torch.bfloat16)) @@ -349,7 +405,9 @@ def test(shape): tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9) if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: - expected = torch.from_numpy(np.msort(tensor.float().cpu().numpy())).bfloat16() + expected = torch.from_numpy( + np.msort(tensor.float().cpu().numpy()) + ).bfloat16() else: expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) else: @@ -364,11 +422,15 @@ def test(shape): shapes = ( [], - [0, ], - [20, ], + [ + 0, + ], + [ + 20, + ], [1, 20], [30, 30], - [10, 20, 30] + [10, 20, 30], ) for shape in shapes: test(shape) @@ -414,9 +476,12 @@ def compare(t, k, dim, dir): sortKVal, sortKInd = topKViaSort(t, k, dim, dir) compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - t = torch.rand(random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE), device=device) + t = torch.rand( + random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE), + device=device, + ) for _kTries in range(3): for _dimTries in range(3): @@ -457,91 +522,94 @@ def test_topk_arguments(self, device): self.assertRaises(TypeError, lambda: q.topk(4, True)) def test_unique_dim(self, device): - self.assertFalse(hasattr(torch, 'unique_dim')) + self.assertFalse(hasattr(torch, "unique_dim")) def run_test(device, dtype): - x = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x = torch.tensor( + 
[ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) x_empty = torch.empty(5, 0, dtype=dtype, device=device) x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) + x_ill_formed_empty_another = torch.empty( + 5, 0, 5, dtype=dtype, device=device + ) if dtype in floating_types_and(torch.float16, torch.bfloat16): - x_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_unique_dim0 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + x_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_unique_dim0 = torch.tensor( + [[[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]]], + dtype=dtype, + device=device, + ) expected_inverse_dim0 = torch.tensor([0, 0]) expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor([[[0., 1.], - [1., 1.], - [2., 1.]], - [[0., 1.], - [1., 1.], - [2., 1.]]], - dtype=dtype, - device=device) - expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]], - [[False, True], [True, True]]], - dtype=torch.bool, - device=device) + expected_unique_dim1 = torch.tensor( + [ + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]], + ], + dtype=dtype, + device=device, + ) + expected_unique_dim1_bool = torch.tensor( + [[[False, True], [True, True]], [[False, True], [True, True]]], + dtype=torch.bool, + device=device, + ) expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0]) expected_counts_dim1 = torch.tensor([2, 1, 1]) expected_counts_dim1_bool = torch.tensor([2, 2]) - expected_unique_dim2 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) + expected_unique_dim2 = torch.tensor( + [ + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]], + ], + dtype=dtype, + device=device, + ) expected_inverse_dim2 = torch.tensor([0, 1]) expected_counts_dim2 = torch.tensor([1, 1]) expected_unique_empty = torch.empty(5, 0, dtype=dtype, device=device) expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) + expected_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) # dim0 x_unique = torch.unique(x, dim=0) self.assertEqual(expected_unique_dim0, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=0) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) x_unique, 
x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) + x, return_inverse=False, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_counts_dim0, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) + x, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_dim0, x_unique) self.assertEqual(expected_inverse_dim0, x_inverse) self.assertEqual(expected_counts_dim0, x_counts) @@ -553,10 +621,7 @@ def run_test(device, dtype): else: self.assertEqual(expected_unique_dim1, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=1) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -565,10 +630,8 @@ def run_test(device, dtype): self.assertEqual(expected_inverse_dim1, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) + x, return_inverse=False, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_counts_dim1_bool, x_counts) @@ -577,10 +640,8 @@ def run_test(device, dtype): self.assertEqual(expected_counts_dim1, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) + x, return_inverse=True, return_counts=True, dim=1 + ) if x.dtype == torch.bool: self.assertEqual(expected_unique_dim1_bool, x_unique) self.assertEqual(expected_inverse_dim1_bool, x_inverse) @@ -594,36 +655,27 @@ def run_test(device, dtype): x_unique = torch.unique(x, dim=2) self.assertEqual(expected_unique_dim2, x_unique) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=2) + x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) + x, return_inverse=False, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_counts_dim2, x_counts) x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) + x, return_inverse=True, return_counts=True, dim=2 + ) self.assertEqual(expected_unique_dim2, x_unique) self.assertEqual(expected_inverse_dim2, x_inverse) self.assertEqual(expected_counts_dim2, x_counts) # test empty tensor x_unique, x_inverse, x_counts = torch.unique( - x_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_empty, return_inverse=True, return_counts=True, dim=1 + ) self.assertEqual(expected_unique_empty, x_unique) self.assertEqual(expected_inverse_empty, x_inverse) self.assertEqual(expected_counts_empty, x_counts) @@ -631,10 +683,8 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): x_unique, x_inverse, x_counts = torch.unique( - x_nan, - return_inverse=True, - return_counts=True, - dim=0) + x_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_unique_nan, x_unique) self.assertEqual(expected_inverse_nan, x_inverse) self.assertEqual(expected_counts_nan, x_counts) @@ -643,10 +693,8 @@ def run_test(device, dtype): # Checking for runtime error, as this is the expected 
behaviour with self.assertRaises(RuntimeError): torch.unique( - x_ill_formed_empty, - return_inverse=True, - return_counts=True, - dim=1) + x_ill_formed_empty, return_inverse=True, return_counts=True, dim=1 + ) # test along dim2 with self.assertRaises(RuntimeError): @@ -654,46 +702,66 @@ def run_test(device, dtype): x_ill_formed_empty_another, return_inverse=True, return_counts=True, - dim=2) + dim=2, + ) # test consecutive version y = torch.tensor( - [[0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2]], + [ + [0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2], + ], dtype=dtype, - device=device + device=device, ) # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): - y_nan = torch.tensor([float("nan"), 0, 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) + y_nan = torch.tensor( + [float("nan"), 0, 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) expected_y_unique = torch.tensor( - [[0, 1], - [1, 2], - [3, 4], - [0, 1], - [3, 4], - [1, 2]], + [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]], dtype=dtype, - device=device + device=device, + ) + expected_y_inverse = torch.tensor( + [0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device + ) + expected_y_counts = torch.tensor( + [3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device + ) + expected_y_inverse_bool = torch.tensor( + [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device + ) + expected_y_counts_bool = torch.tensor( + [3, 3, 2, 2], dtype=torch.int64, device=device ) - expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device) - expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device) - expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device) - expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device) if dtype in floating_types_and(torch.float16, torch.bfloat16): - expected_y_unique_nan = torch.tensor([float("nan"), 0, float("nan"), float("nan"), 1], dtype=dtype, device=device) - expected_y_inverse_nan = torch.tensor([0, 1, 1, 2, 3, 4], dtype=torch.long, device=device) - expected_y_counts_nan = torch.tensor([1, 2, 1, 1, 1], dtype=torch.long, device=device) - - y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) + expected_y_unique_nan = torch.tensor( + [float("nan"), 0, float("nan"), float("nan"), 1], + dtype=dtype, + device=device, + ) + expected_y_inverse_nan = torch.tensor( + [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device + ) + expected_y_counts_nan = torch.tensor( + [1, 2, 1, 1, 1], dtype=torch.long, device=device + ) + + y_unique, y_inverse, y_counts = torch.unique_consecutive( + y, return_inverse=True, return_counts=True, dim=0 + ) if x.dtype == torch.bool: self.assertEqual(expected_y_inverse_bool, y_inverse) self.assertEqual(expected_y_counts_bool, y_counts) @@ -704,23 +772,27 @@ def run_test(device, dtype): # test tensor with nan if dtype in floating_types_and(torch.float16, torch.bfloat16): y_unique, y_inverse, y_counts = torch.unique_consecutive( - y_nan, - return_inverse=True, - return_counts=True, - dim=0) + y_nan, return_inverse=True, return_counts=True, dim=0 + ) self.assertEqual(expected_y_unique_nan, y_unique) self.assertEqual(expected_y_inverse_nan, y_inverse) self.assertEqual(expected_y_counts_nan, y_counts) # Test dim is sorted same as 
NumPy with dims >= 3 - x = torch.tensor([[[[1, 0, 1, 0, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 1, 1, 0, 0, 1], - [0, 0, 0, 1, 0, 0]]], - [[[0, 1, 0, 1, 1, 1], - [0, 1, 1, 0, 1, 1]], - [[0, 0, 1, 1, 0, 1], - [1, 1, 0, 0, 0, 0]]]], dtype=dtype, device=device) + x = torch.tensor( + [ + [ + [[1, 0, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0]], + ], + [ + [[0, 1, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1]], + [[0, 0, 1, 1, 0, 1], [1, 1, 0, 0, 0, 0]], + ], + ], + dtype=dtype, + device=device, + ) xn = x.cpu().numpy() for d in range(x.dim()): t = torch.unique(x, dim=d) @@ -750,15 +822,20 @@ def test_topk_noncontiguous_gpu(self, device): def _test_topk_dtype(self, device, dtype, integral, size): if integral: - a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, - size=(size,), dtype=dtype, device=device) + a = torch.randint( + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + size=(size,), + dtype=dtype, + device=device, + ) else: a = torch.randn(size=(size,), dtype=dtype, device=device) - sort_topk = a.sort()[0][-(size // 2):].flip(0) + sort_topk = a.sort()[0][-(size // 2) :].flip(0) topk = a.topk(size // 2) - self.assertEqual(sort_topk, topk[0]) # check values - self.assertEqual(sort_topk, a[topk[1]]) # check indices + self.assertEqual(sort_topk, topk[0]) # check values + self.assertEqual(sort_topk, a[topk[1]]) # check indices @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) def test_topk_integral(self, device, dtype): @@ -770,7 +847,6 @@ def test_topk_integral(self, device, dtype): @dtypes(torch.bfloat16, torch.half) def test_topk_lower_precision(self, device, dtype): - small = 10 large = 4096 verylarge = 8192 # multi_block topk on cuda @@ -780,14 +856,20 @@ def test_topk_lower_precision(self, device, dtype): @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) def test_topk_nonfinite(self, device, dtype): - x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype) + x = torch.tensor( + [float("nan"), float("inf"), 1e4, 0, -1e4, -float("inf")], + device=device, + dtype=dtype, + ) val, idx = x.topk(4) - expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype) + expect = torch.tensor( + [float("nan"), float("inf"), 1e4, 0], device=device, dtype=dtype + ) self.assertEqual(val, expect) self.assertEqual(idx, [0, 1, 2, 3]) val, idx = x.topk(4, largest=False) - expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype) + expect = torch.tensor([-float("inf"), -1e4, 0, 1e4], device=device, dtype=dtype) self.assertEqual(val, expect) self.assertEqual(idx, [5, 4, 3, 2]) @@ -796,13 +878,13 @@ def test_topk_4d(self, device): large = 8192 for size in (small, large): x = torch.ones(2, size, 2, 2, device=device) - x[:, 1, :, :] *= 2. + x[:, 1, :, :] *= 2.0 x[:, 10, :, :] *= 1.5 val, ind = torch.topk(x, k=2, dim=1) expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device) expected_ind[:, 1, :, :] = 10 expected_val = torch.ones(2, 2, 2, 2, device=device) - expected_val[:, 0, :, :] *= 2. 
+ expected_val[:, 0, :, :] *= 2.0 expected_val[:, 1, :, :] *= 1.5 self.assertEqual(val, expected_val, atol=0, rtol=0) self.assertEqual(ind, expected_ind, atol=0, rtol=0) @@ -838,7 +920,17 @@ def _test_unique_scalar_empty(self, dtype, device, f): self.assertEqual(inverse, expected_inverse) self.assertEqual(counts, expected_counts) - def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): + def _test_unique_with_expects( + self, + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + additional_shape, + ): def ensure_tuple(x): if isinstance(x, torch.Tensor): return (x,) @@ -847,7 +939,9 @@ def ensure_tuple(x): for return_inverse in [True, False]: for return_counts in [True, False]: # test with expected - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) self.assertEqual(expected_unique, ret[0]) if return_inverse: @@ -858,7 +952,9 @@ def ensure_tuple(x): # tests per-element unique on a higher rank tensor. y = x.view(additional_shape) - y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) + y_unique, y_inverse, y_counts = f( + y, return_inverse=True, return_counts=True + ) self.assertEqual(expected_unique, y_unique) self.assertEqual(expected_inverse.view(additional_shape), y_inverse) self.assertEqual(expected_counts, y_counts) @@ -872,9 +968,17 @@ def ensure_tuple(x): return x if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, False, True, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [False, True], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device + ) expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) else: x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) @@ -890,18 +994,29 @@ def ensure_tuple(x): x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) xs = (x, x_sliced) for f, x in product(fs, xs): - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (2, 2, 2), + ) self._test_unique_scalar_empty(dtype, device, f) # test unsorted unique fs = ( lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), - lambda x, **kwargs: x.unique(sorted=False, **kwargs) + lambda x, **kwargs: x.unique(sorted=False, **kwargs), ) for f, x in product(fs, xs): self._test_unique_scalar_empty(dtype, device, f) for return_inverse, return_counts in product((True, False), repeat=2): - ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + ret = ensure_tuple( + f(x, return_inverse=return_inverse, return_counts=return_counts) + ) self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) x_list = x.tolist() x_unique_list = ret[0].tolist() @@ -924,18 +1039,40 @@ def 
ensure_tuple(x): @dtypes(*all_types_and(torch.half, torch.bool)) def test_unique_consecutive(self, device, dtype): if dtype is torch.bool: - x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device) - expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device) - expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device) + x = torch.tensor( + [True, False, False, False, True, True, False, False, False], + dtype=torch.bool, + device=device, + ) + expected_unique = torch.tensor( + [True, False, True, False], dtype=torch.bool, device=device + ) + expected_inverse = torch.tensor( + [0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device + ) + expected_counts = torch.tensor( + [1, 3, 2, 3], dtype=torch.long, device=device + ) else: x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device) expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device) expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device) - for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]: - self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3)) + for f in [ + torch.unique_consecutive, + lambda x, **kwargs: x.unique_consecutive(**kwargs), + ]: + self._test_unique_with_expects( + device, + dtype, + f, + x, + expected_unique, + expected_inverse, + expected_counts, + (3, 3), + ) self._test_unique_scalar_empty(dtype, device, f) @dtypes(torch.double) @@ -991,7 +1128,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(x, x0, atol=0, rtol=0) # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) + y = torch.tensor((3.0, 5, 4, 1, 1, 5), dtype=dtype, device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0) @@ -1007,7 +1144,7 @@ def test_kthvalue(self, device, dtype): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0) @dtypes(torch.float) - @onlyNativeDeviceTypes # Fails on XLA + @onlyNativeDeviceTypes # Fails on XLA def test_kthvalue_scalar(self, device, dtype): # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818) # Tests that passing a scalar tensor or 1D tensor with 1 element work either way @@ -1029,7 +1166,9 @@ def assert_isin_equal(a, b): # multi-dim tensor, multi-dim tensor a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4]) - b = torch.tensor([[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype) + b = torch.tensor( + [[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype + ) assert_isin_equal(a, b) # zero-dim tensor @@ -1073,16 +1212,56 @@ def define_expected(lst, invert=False): c = torch.isin(a, b, assume_unique=True, invert=invert) self.assertEqual(c, ec) - a = torch.tensor([5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], device=device, dtype=dtype) + a = torch.tensor( + [5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], + device=device, + dtype=dtype, + ) b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype) - ec = define_expected([False, True, False, True, True, True, True, True, True, - False, True, False, False, False], invert=invert) + ec = define_expected( + [ + False, + True, + 
False, + True, + True, + True, + True, + True, + True, + False, + True, + False, + False, + False, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) - b = torch.tensor([2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype) - ec = define_expected([True, True, True, True, True, True, True, True, True, True, - True, False, True, True], invert=invert) + b = torch.tensor( + [2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype + ) + ec = define_expected( + [ + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + ], + invert=invert, + ) c = torch.isin(a, b, invert=invert) self.assertEqual(c, ec) @@ -1108,12 +1287,14 @@ def define_expected(lst, invert=False): for assume_unique in [False, True]: a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) b = torch.arange(3, 30, device=device, dtype=dtype) - ec = define_expected([[False, False, False], [True, True, True]], invert=invert) + ec = define_expected( + [[False, False, False], [True, True, True]], invert=invert + ) c = torch.isin(a, b, invert=invert, assume_unique=assume_unique) self.assertEqual(c, ec) def test_isin_different_dtypes(self, device): - supported_types = all_types() if device == 'cpu' else all_types_and(torch.half) + supported_types = all_types() if device == "cpu" else all_types_and(torch.half) for mult in [1, 10]: for assume_unique in [False, True]: for dtype1, dtype2 in product(supported_types, supported_types): @@ -1127,18 +1308,18 @@ def test_isin_different_dtypes(self, device): @dtypes(*all_types()) def test_isin_different_devices(self, device, dtype): a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3]) - b = torch.arange(3, 30, device='cpu', dtype=dtype) + b = torch.arange(3, 30, device="cpu", dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(a, b) - c = torch.arange(6, device='cpu', dtype=dtype).reshape([2, 3]) + c = torch.arange(6, device="cpu", dtype=dtype).reshape([2, 3]) d = torch.arange(3, 30, device=device, dtype=dtype) with self.assertRaises(RuntimeError): torch.isin(c, d) @dtypes(*integral_types()) def test_sort_overflow(self, device, dtype): - " Regression test for https://github.com/pytorch/pytorch/issues/111189 " + "Regression test for https://github.com/pytorch/pytorch/issues/111189" prev_num_threads = torch.get_num_threads() try: low = 0 if dtype == torch.uint8 else -1 @@ -1153,5 +1334,5 @@ def test_sort_overflow(self, device, dtype): instantiate_device_type_tests(TestSortAndSelect, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 434793508d47..863f3c37c217 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -330,7 +330,7 @@ def test_fork_wait_exception(self): raise RuntimeError( "Tried execution of add.Tensors with incompatible shape. " "Exception raised by forked runtime execution does " - f"not contain expected substring: \"{expected_error_msg}\"" + f'not contain expected substring: "{expected_error_msg}"' ) from error """ @@ -360,7 +360,7 @@ def test_fork_wait_exception_async(self): raise RuntimeError( "Tried execution of add.Tensors with incompatible shape. 
" "Exception raised by forked runtime execution does " - f"not contain expected substring: \"{expected_error_msg}\"" + f'not contain expected substring: "{expected_error_msg}"' ) from error def test_multihead_attention_layer(self): diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index c5da8f7fc0da..8b16b2c620fd 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,7 +36,12 @@ "floor", "ceil", ] -BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] +BINARY_OPS = [ + "truediv", "floordiv", + # "truncdiv", # TODO + # NB: pow is float_pow + "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" +] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -81,16 +86,24 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or ( # sympy will expand to x*x*... for integral b; don't do it if it's big - a <= 0 and b == -1 - ) - or (a == b == 0) # no imaginary numbers # 0**0 is undefined + # no imaginary numbers + or a <= 0 + # 0**0 is undefined + or (a == b == 0) ): return False - elif fn == "mod" and b == 0: + elif fn == "pow_by_natural" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big + b > 4 + or b < 0 + or (a == b == 0) + ): return False - elif (fn == "div" or fn == "truediv") and b == 0: + elif fn == "mod" and (a < 0 or b <= 0): + return False + elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True @@ -130,27 +143,26 @@ def test_pow_half(self): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype_a", ("int", "float")) - @parametrize("dtype_b", ("int", "float")) - def test_binary_ref(self, fn, dtype_a, dtype_b): + @parametrize("dtype", ("int", "float")) + def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - dtype_a = to_dtype[dtype_a] - dtype_b = to_dtype[dtype_b] + # Don't test float on int only methods + if dtype == "float" and fn in ["pow_by_natural", "mod"]: + return + dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype_a(a) - b = dtype_b(b) + a = dtype(a) + b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. 
wtf - if fn != "floordiv": - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -200,7 +212,8 @@ def test_binary_bool_ref_range(self, fn): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - vals = [-sympy.oo, *CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -216,40 +229,26 @@ def test_unary_ref_range(self, fn): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) - def test_rational_bounds(self): - # Repro from https://github.com/pytorch/pytorch/issues/105097 - from sympy import floor, Eq - shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) - new_expr = ( - Eq(30 * floor(4 * ((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + - 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), - 2880 * floor(((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + - 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) - new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} - self.assertTrue(new_expr.subs({shape_0: 95})) - self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) - class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -258,7 +257,13 @@ def test_interp(self, fn): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) + vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -300,29 +305,17 @@ def test_python_interp_fx(self, fn): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def 
trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -330,11 +323,28 @@ def trace_f(px, py): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn == "pow" and (args[0] == 0 and args[1] <= 0): + elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/test/test_testing.py b/test/test_testing.py index ba9558a3ddd1..1e1dce59a32e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2245,7 +2245,6 @@ def test_circular_dependencies(self) -> None: else: ignored_modules.append("torch.distributed.nn.api.") ignored_modules.append("torch.distributed.optim.") - ignored_modules.append("torch.distributed.pipeline.") ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed diff --git a/test/test_torch.py b/test/test_torch.py index 0d8a672b93cf..f252ddf4a574 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: tests"] import torch @@ -8397,7 +8398,7 @@ def test_resizable(self) -> None: def test_iter(self) -> None: x = torch.randn(5, 5) for i, sub in enumerate(x): - self.assertEqual(sub, x[i]) + self.assertEqual(sub, x[i]) # noqa: PLR1736 x = torch.tensor([]) self.assertEqual(list(x), []) @@ -10623,12 +10624,9 @@ def test_swap_basic(self): if t1.is_floating_point(): t3 = t1.clone().detach().requires_grad_(True) out = t3 * 2 - with self.assertRaisesRegex(RuntimeError, "Expected single reference to a's"): - torch.utils.swap_tensors(t3, t2) - del out - # Now succeeds torch.utils.swap_tensors(t3, t2) - torch.utils.swap_tensors(t1, t2) + with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"): + out.sum().backward() wr = weakref.ref(t1) with self.assertRaisesRegex(RuntimeError, "has weakref"): diff --git a/test/test_transformers.py b/test/test_transformers.py index 73f838143dd5..eea3b3fab8d9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -43,7 +43,8 @@ IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, - PLATFORM_SUPPORTS_CUDNN_ATTENTION + PLATFORM_SUPPORTS_CUDNN_ATTENTION, + tf32_on_and_off ) if TEST_FAIRSEQ: @@ -132,6 +133,10 @@ def get_platform_specific_sdpa(): return ret PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa() +# Indicate the Efficient attention backend can support: +# 1. sequence longher than 512 +# 2. 
head dimsion larger than 64 +MEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor: @@ -311,6 +316,7 @@ def test_transformerencoderlayer_src_mask(self, device, nhead): with torch.no_grad(): model(src, src_mask=src_mask) + @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -401,8 +407,9 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste # no garauntees on output corresponding to masked tokens, so they may vary between slow/fast path. set all to 0. fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) - torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5) + self.assertEqual(fastpath_output_expanded, slowpath_output) + @tf32_on_and_off(0.001) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -446,7 +453,7 @@ def test_transformerencoder_square_input(self, with_no_grad, training, enable_ne [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]] ).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + self.assertEqual(result, ref_output) @parametrize("batch_first", [True, False]) @parametrize("training", [True, False]) @@ -1393,7 +1400,7 @@ def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend): q = torch.randn(size, device=device, dtype=dtype) k = torch.randn(size, device=device, dtype=dtype) v = torch.randn(size, device=device, dtype=dtype) - with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"): + with self.assertWarnsRegex(UserWarning, "All fused kernels requires query, key and value to be 4 dimensional"): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1425,7 +1432,7 @@ def test_invalid_sequence_lengths(self, device, kernel: SDPBackend): make_tensor = partial(torch.rand, device=device, dtype=dtype) size = SdpaShape(2, 2, 0, 8) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) - with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."): + with self.assertWarnsRegex(UserWarning, "All fused kernels do not support zero seq_len_q or seq_len_kv."): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1440,7 +1447,7 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): size = SdpaShape(2, 2, 8, 8) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) q.as_strided_(size, [2, 2, 2, 2]) - with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."): + with self.assertWarnsRegex(UserWarning, "All fused kernels require the last dimension of the input to have stride 1."): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1973,7 +1980,7 @@ def ref(x): class TestSDPACudaOnly(NNTestCase): """ Used to test 
CUDA only functionality of scaled_dot_product_attention Quarks: - There is some trickiness with this function. It's runtime behavior + There is some trickiness with this function. Its runtime behavior is dependent on the CUDA architecture you are testing it on. See `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file. Summary: @@ -2140,9 +2147,34 @@ def convert_flash_attn_S_to_softmax( S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) return S_converted[:, :, :seqlen_q, :seqlen_k] + @skipIfRocm # No cuDNN Attention + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + def test_cudnn_attention_different_dk_dv(self, device): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) + batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 + seq_len = 640 + q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) + k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) + v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) + query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) + + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + actual = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = torch.nn.functional.scaled_dot_product_attention( + query.contiguous().to(torch.float32), + key.contiguous().to(torch.float32), + value.contiguous().to(torch.float32), + attn_mask=None, dropout_p=0.0, is_causal=False) + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) - def test_mem_efficient_attetntion_mask_variants(self, device, mask_dim: List[int]): + def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) batch, num_heads, head_dim = 8, 8, 64 @@ -2255,6 +2287,8 @@ def test_singelton_head_dim_stride_ne_1(self, device): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): + if TEST_WITH_ROCM and type == 'nested': + self.skipTest("ROCM does not support efficient attention on nested tensors, for now") make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 @@ -2347,9 +2381,9 @@ def rand_tensor(shape): math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) - self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + self.assertEqual(actual_test, math_ref_test, atol=7e-3, rtol=7e-3) - @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Flash Attention was not built for this system") + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) @parametrize("is_causal", [True, False]) def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool): @@ -2465,7 +2499,12 @@ def test_fused_sdp_choice(self, 
device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - if PLATFORM_SUPPORTS_FLASH_ATTENTION: + major, minor = torch.cuda.get_device_capability(device) + is_sm90_or_newer = major >= 9 + + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and is_sm90_or_newer: + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.CUDNN_ATTENTION.value + elif PLATFORM_SUPPORTS_FLASH_ATTENTION: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value else: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value @@ -2482,6 +2521,7 @@ def test_fused_sdp_choice(self, device, type: str): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value + @skipIfRocm # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") @parametrize("warn_only", [True, False]) def test_sdp_choice_with_determinism(self, device, warn_only): @@ -2494,6 +2534,7 @@ def test_sdp_choice_with_determinism(self, device, warn_only): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value + @skipIfRocm # Missing deterministic algo @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("warn_only", [True, False]) @@ -2503,7 +2544,8 @@ def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fus make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) - kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" + kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else \ + "Flash Attention" if fused_kernel == SDPBackend.FLASH_ATTENTION else "cuDNN Attention" warning_context = ( self.assertWarnsRegex( UserWarning, @@ -2515,7 +2557,12 @@ def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fus with use_deterministic_algorithims(True, warn_only=warn_only): with sdpa_kernel(backends=[fused_kernel]): with warning_context: - torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() + if warn_only or fused_kernel != SDPBackend.CUDNN_ATTENTION: + torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() + else: + # cuDNN attention has no deterministic fallback + self.assertRaises(RuntimeError, lambda: + torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()) @unittest.skip("This test is not behaving deterministaclly non-deterministaclly on CI/CD") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA") @@ -2572,13 +2619,16 @@ def test_mem_eff_backwards_determinism(self, device): @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") @parametrize("batch_size", [1, 8]) - @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512]) - 
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512]) - @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64]) + @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 256, 512]) + @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 256, 512]) + @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [8, 16, 32, 64]) @parametrize("is_causal", [False, True]) @parametrize("dropout_p", [0.0, 0.22]) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [torch.float16, torch.float32]) @parametrize("scale", [None, "l1"]) def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, @@ -2591,6 +2641,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: + torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2650,7 +2702,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) # Fudge Factor when dropout is enabled - dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0 + dropout_fudge_factor = 1.5 if dropout_p == 0.0 else 2.0 query_fudge_factor = dropout_fudge_factor grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) @@ -2660,6 +2712,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + if TEST_WITH_ROCM: + value_fudge_factor = max(2.0, value_fudge_factor) grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) @@ -2674,13 +2728,16 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") @parametrize("batch_size", [1, 8]) - @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512]) - @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512]) - @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64]) + @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 8, 64, 128, 152, 256, 512]) + @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [4, 
8, 37, 64, 128, 256, 512]) + @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [8, 16, 32, 64]) @parametrize("is_causal", [False]) @parametrize("dropout_p", [0.0, 0.22]) - @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if - SM80OrLater else [torch.float16, torch.float32]) + @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 + else [torch.float16, torch.float32]) @parametrize("scale", [None, "l1"]) def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, @@ -2694,6 +2751,11 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and dtype == torch.float32: + unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.") + return + if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: + torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2763,8 +2825,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, # Fudge Factor when dropout is enabled dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75 mask_fudge_factor = 1.0 if attn_mask is None else 1.5 + query_fudge_factor = 2.0 - query_fudge_factor = dropout_fudge_factor grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) # TODO: Investigate why grad_k needs larger tolerances @@ -2772,6 +2834,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + if TEST_WITH_ROCM: + value_fudge_factor = max(2.0, value_fudge_factor) grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22 @@ -2806,6 +2870,8 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") + if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1: + torch.cuda.empty_cache() # Prevent memory fragmentation scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2962,7 +3028,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d device=device, dtype=dtype, requires_grad=True) fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention - if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention) + if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention + if fused_kernel == SDPBackend.FLASH_ATTENTION else torch.ops.aten._scaled_dot_product_cudnn_attention) # Run the math kernel on low precision references query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype) @@ -2980,6 +3047,10 @@ def get_dropout_mask(output, 
fused_kernel, batch_size, n_heads, q_len, kv_len, d kwargs["attn_bias"] = None if fused_kernel == SDPBackend.FLASH_ATTENTION: kwargs['return_debug_mask'] = dropout_p > 0.0 + if fused_kernel == SDPBackend.CUDNN_ATTENTION: + kwargs["compute_log_sumexp"] = True + if "return_debug_mask" in kwargs: + kwargs.pop("return_debug_mask") with torch.cuda.stream(s): # Create real output output_tuple = fused_op(query, key, value, **kwargs) @@ -3017,7 +3088,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d # Low Precision Math Reference out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale) - else: + # cuDNN attention doesn't support returning dropout mask + elif fused_kernel != SDPBackend.CUDNN_ATTENTION: # Create the dropout_mask dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, device) @@ -3035,37 +3107,38 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d with torch.cuda.graph(g1): out.backward(upstream_grad) g1.replay() - out_ref.backward(upstream_grad.to(out_ref.dtype)) - out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) - - # [Note] Fused Tolerances - # Establish the numerical error between the "true" high precision math output - # and the low precision math reference. We use this reference for the atol - # And we use the default rtol for the low precision type. - # We then provide a fudge factor for gradients respectively to account - # for the use of the fused kernel rather than the eager implemntation. - output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) - - # Fudge Factor when dropout is enabled - dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5 - - query_fudge_factor = dropout_fudge_factor - grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) - - # TODO: Investigate why grad_k needs larger tolerances - key_fudge_factor = 8 * dropout_fudge_factor - grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) - - value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 - grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) - - self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) - self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), - atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) - self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), - atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) - self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), - atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) + if fused_kernel != SDPBackend.CUDNN_ATTENTION or dropout_p == 0.0: + out_ref.backward(upstream_grad.to(out_ref.dtype)) + out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) + + # [Note] Fused Tolerances + # Establish the numerical error between the "true" high precision math output + # and the low precision math reference. We use this reference for the atol + # And we use the default rtol for the low precision type. + # We then provide a fudge factor for gradients respectively to account + # for the use of the fused kernel rather than the eager implemntation. 
+ output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) + + # Fudge Factor when dropout is enabled + dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5 + + query_fudge_factor = dropout_fudge_factor + grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) + + # TODO: Investigate why grad_k needs larger tolerances + key_fudge_factor = 8 * dropout_fudge_factor + grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) + + value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 + grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) + + self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) + self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), + atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) + self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), + atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) + self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), + atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) @skipIfRocm # Nested Tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @@ -3189,8 +3262,9 @@ def _broadcast(t, batch_broadcasted, num_heads_broadcasted): query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False) - self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2) + @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_query_dense(self, device): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) @@ -3351,6 +3425,7 @@ def run_test( forw_tolerances: Optional[Tolerances] = None, grad_tolerances: Optional[Tolerances] = None, backend=None, + causal_variant=None, ): if backend is not None: torch._dynamo.reset() @@ -3418,9 +3493,11 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis if causal_variant == CausalVariant.UPPER_LEFT: attn_bias = causal_upper_left(seq_len_q, seq_len_kv) else: + print(seq_len_q, seq_len_kv) attn_bias = causal_lower_right(seq_len_q, seq_len_kv) - self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]): + self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @@ -3451,7 +3528,8 @@ def test_causal_variants_compile(self, device, causal_variant: CausalVariant, sh else: attn_bias = causal_lower_right(seq_len_q, seq_len_kv) - self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) + with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]): + self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 
128, 32), (1, 1, 23, 56, 15)]) diff --git a/test/test_type_hints.py b/test/test_type_hints.py index a4ae1768cd2a..2fba1ba2f9e4 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: typing"] import doctest diff --git a/test/test_type_info.py b/test/test_type_info.py index 97bb23e89c99..9160c31b4fb8 100644 --- a/test/test_type_info.py +++ b/test/test_type_info.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: typing"] from torch.testing._internal.common_utils import ( diff --git a/test/test_utils.py b/test/test_utils.py index 66d66b8874f1..df41b9b538be 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import os @@ -1177,9 +1178,11 @@ def test_device_mode_ops(self, device, dtype, op): kwargs.pop("device", None) with torch.device("meta"): r = func(sample.input, *sample.args, **kwargs) - self.assertTrue( - tree_all_only(torch.Tensor, lambda x: x.device.type == "meta", r) - ) + + def is_meta_device(x: torch.Tensor) -> bool: + return x.device.type == "meta" + + self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) instantiate_device_type_tests(TestDeviceUtils, globals()) diff --git a/test/test_xpu.py b/test/test_xpu.py index a3838f1d5a05..86a0bc6fa2b9 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,11 +1,13 @@ # Owner(s): ["module: intel"] +import collections import sys import tempfile import unittest import torch import torch.xpu._gpu_trace as gpu_trace +from torch.testing._internal.autocast_test_lists import AutocastTestLists from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyXPU, @@ -309,6 +311,134 @@ def test_serialization_array_with_empty(self): instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") +class TestXpuAutocast(TestCase): + def setUp(self): + super().setUp() + self.autocast_lists = AutocastTestLists(torch.device("xpu")) + + def tearDown(self): + del self.autocast_lists + super().tearDown() + + def _run_autocast_outofplace( + self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None + ): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 + self.assertFalse(torch.is_autocast_enabled()) + with torch.amp.autocast("xpu", dtype=fast_dtype): + self.assertTrue(torch.is_autocast_enabled()) + + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue( + out_type == output.dtype, + f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", + ) + + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue( + out_type == output_method.dtype, + f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", + ) + + self.assertTrue( + (output is not None) or (output_method is not None), + 
f"{op} not found as an attribute on either Tensor or the requested module {module}", + ) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue( + comparison, f"torch.{op} result did not match Tensor.{op} result" + ) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.amp.autocast("xpu", enabled=False): + self.assertFalse(torch.is_autocast_enabled()) + + if module is not None and hasattr(module, op): + control = getattr(module, op)( + *cast(args, run_as_type), **add_kwargs + ) + else: + control = getattr(args[0].to(run_as_type), op)( + *cast(args[1:], run_as_type), **add_kwargs + ) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, f"torch.{op} result did not match control") + self.assertTrue(torch.is_autocast_enabled()) + self.assertFalse(torch.is_autocast_enabled()) + + def test_autocast_torch_fp16(self): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = True # skip cudnn op + if not skip_test: + self._run_autocast_outofplace(op, args, torch.float16) + + def test_autocast_torch_bf16(self): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = True # skip cudnn op + if not skip_test: + self._run_autocast_outofplace(op, args, torch.bfloat16) + + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.autocast_lists.torch_need_autocast_promote: + self._run_autocast_outofplace(op, args, torch.float32) + + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: + self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + + def test_xpu_autocast_dtype(self): + dtype = torch.get_autocast_dtype("xpu") + self.assertEqual(dtype, torch.float16) + mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu") + mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu") + with torch.amp.autocast("xpu"): + result = torch.mm(mat0_fp32, mat1_fp32) + self.assertEqual(result.dtype, torch.float16) + + class TestXpuTrace(TestCase): def setUp(self): torch._C._activate_gpu_trace() diff --git a/test/torch_np/numpy_tests/core/test_dtype.py b/test/torch_np/numpy_tests/core/test_dtype.py index ccff28135a1f..00ead3f705af 100644 --- a/test/torch_np/numpy_tests/core/test_dtype.py +++ b/test/torch_np/numpy_tests/core/test_dtype.py @@ -21,6 +21,7 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, xpassIfTorchDynamo, ) @@ -68,6 +69,7 @@ def test_equivalent_dtype_hashing(self): assert_(left == right) 
assert_(hash(left) == hash(right)) + @xfailIfTorchDynamo # TypeError -> InternalTorchDynamoError def test_invalid_types(self): # Make sure invalid type strings raise an error diff --git a/test/torch_np/numpy_tests/core/test_getlimits.py b/test/torch_np/numpy_tests/core/test_getlimits.py index ab5b08319db6..3be8bc2619ab 100644 --- a/test/torch_np/numpy_tests/core/test_getlimits.py +++ b/test/torch_np/numpy_tests/core/test_getlimits.py @@ -8,14 +8,17 @@ # from numpy.core.getlimits import _discovered_machar, _float_ma -from unittest import skipIf +from unittest import expectedFailure as xfail, skipIf import numpy from pytest import raises as assert_raises from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, run_tests, + subtest, TEST_WITH_TORCHDYNAMO, TestCase, xpassIfTorchDynamo, @@ -109,6 +112,7 @@ def test_basic_missing(self): getattr(finfo(dt), attr) +@instantiate_parametrized_tests class TestIinfo(TestCase): def test_basic(self): dts = list( @@ -129,11 +133,19 @@ def test_basic(self): with assert_raises((TypeError, ValueError)): iinfo("f4") - def test_unsigned_max(self): - types = np.sctypes["uint"] - for T in types: - max_calculated = T(0) - T(1) - assert_equal(iinfo(T).max, max_calculated) + @parametrize( + "T", + [ + np.uint8, + # xfail: unsupported add (uint[16,32,64]) + subtest(np.uint16, decorators=[xfail]), + subtest(np.uint32, decorators=[xfail]), + subtest(np.uint64, decorators=[xfail]), + ], + ) + def test_unsigned_max(self, T): + max_calculated = T(0) - T(1) + assert_equal(iinfo(T).max, max_calculated) class TestRepr(TestCase): diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index bf9aab8ebcee..76af79f62084 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -291,6 +291,7 @@ def test_otherflags(self): assert_equal(self.a.flags["X"], False) assert_equal(self.a.flags["WRITEBACKIFCOPY"], False) + @xfail # invalid dtype def test_string_align(self): a = np.zeros(4, dtype=np.dtype("|S4")) assert_(a.flags.aligned) @@ -298,6 +299,7 @@ def test_string_align(self): a = np.zeros(5, dtype=np.dtype("|S4")) assert_(a.flags.aligned) + @xfail # structured dtypes def test_void_align(self): a = np.zeros(4, dtype=np.dtype([("a", "i4"), ("b", "i4")])) assert_(a.flags.aligned) @@ -1856,7 +1858,7 @@ def test_searchsorted_floats(self, a): y = np.searchsorted(x, x[-1]) assert_equal(y, 2) - @xpassIfTorchDynamo # ( + @xfail # ( # reason="'searchsorted_out_cpu' not implemented for 'ComplexDouble'" # ) def test_searchsorted_complex(self): @@ -5983,6 +5985,11 @@ def test_unnamed_fields(self): self._check("i:f0:", [("f0", "i")]) +# NOTE: xpassIfTorchDynamo below +# 1. TODO: torch._numpy does not handle/model _CopyMode +# 2. order= keyword not supported (probably won't be) +# 3. Under TEST_WITH_TORCHDYNAMO many of these make it through due +# to a graph break leaving the _CopyMode to only be handled by numpy. @skipif(numpy.__version__ < "1.23", reason="CopyMode is new in NumPy 1.22") @xpassIfTorchDynamo @instantiate_parametrized_tests @@ -6011,6 +6018,7 @@ def test_scalars(self): with pytest.raises(ValueError): np.array(pyscalar, dtype=np.int64, copy=np._CopyMode.NEVER) + @xfail # TODO: handle `_CopyMode` properly in torch._numpy def test_compatible_cast(self): # Some types are compatible even though they are different, no # copy is necessary for them. 
This is mostly true for some integers diff --git a/test/torch_np/numpy_tests/core/test_scalarmath.py b/test/torch_np/numpy_tests/core/test_scalarmath.py index 9c535aefe016..d86595d9d3cc 100644 --- a/test/torch_np/numpy_tests/core/test_scalarmath.py +++ b/test/torch_np/numpy_tests/core/test_scalarmath.py @@ -739,6 +739,9 @@ def test_shift_all_bits(self, type_code, op): # gh-2449 dt = np.dtype(type_code) nbits = dt.itemsize * 8 + if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)): + raise SkipTest("NYI: bitshift uint64") + for val in [5, -5]: for shift in [nbits, nbits + 4]: val_scl = np.array(val).astype(dt)[()] diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index 34176ee3f3b7..73897bea6981 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -3,7 +3,7 @@ """Test functions for 1D array set operations. """ -from unittest import skipIf +from unittest import expectedFailure as xfail, skipIf import numpy @@ -34,7 +34,7 @@ @skipIf(numpy.__version__ < "1.24", reason="NP_VER: fails on NumPy 1.23.x") -@xpassIfTorchDynamo # (reason="TODO") +@skipIf(True, reason="TODO implement these ops") @instantiate_parametrized_tests class TestSetOps(TestCase): def test_intersect1d(self): @@ -531,6 +531,7 @@ def test_in1d_both_arrays_are_object(self): result = np.in1d(ar1, ar2) assert_array_equal(result, expected) + @xfail def test_in1d_both_arrays_have_structured_dtype(self): # Test arrays of a structured data type containing an integer field # and a field of dtype `object` allowing for arbitrary Python objects diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index d0eda87b0108..aea6c8ee38d9 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -3259,7 +3259,7 @@ def test_keepdims_2(self): subtest( [1, 7], decorators=[ - xpassIfTorchDynamo, + skip(reason="Keepdims wrapper incorrect for multiple q"), ], ), ], @@ -3273,13 +3273,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), ], @@ -3839,13 +3839,13 @@ def test_keepdims_2(self): subtest( (0, 1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), subtest( (-3, -1), decorators=[ - xpassIfTorchDynamo, + skip(reason="Tuple axes"), ], ), ], diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 954fbf111484..7f8c145a05de 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -353,7 +353,7 @@ def test_signed_overflow_bounds(self): self.do_signed_overflow_bounds(np.short) self.do_signed_overflow_bounds(np.intc) - @xpassIfTorchDynamo # (reason="int->float conversin loses precision") + @xfail # (reason="int->float conversin loses precision") def test_signed_overflow_bounds_2(self): self.do_signed_overflow_bounds(np.int_) self.do_signed_overflow_bounds(np.longlong) diff --git a/test/torch_np/test_dtype.py b/test/torch_np/test_dtype.py index 42866adbe5c2..e288e54286e7 100644 --- a/test/torch_np/test_dtype.py +++ b/test/torch_np/test_dtype.py @@ -18,7 +18,7 @@ dtype_names = [ "bool_", *[f"int{w}" for w in [8, 16, 32, 64]], - "uint8", + *[f"uint{w}" for w in [8, 16, 32, 64]], 
*[f"float{w}" for w in [16, 32, 64]], *[f"complex{w}" for w in [64, 128]], ] diff --git a/test/typing/pass/cuda_steam.py b/test/typing/pass/cuda_steam.py index 0953effebbc2..bf9a40481b16 100644 --- a/test/typing/pass/cuda_steam.py +++ b/test/typing/pass/cuda_steam.py @@ -1,6 +1,6 @@ import torch -def foo(x: torch.Tensor): +def foo(x: torch.Tensor) -> None: stream = torch.cuda.current_stream() x.record_stream(stream) diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD new file mode 100644 index 000000000000..3cd0c3dbe94b --- /dev/null +++ b/third_party/cpp-httplib.BUILD @@ -0,0 +1,10 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "cpp-httplib", + hdrs = ["httplib.h"], + includes = [ + "/", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend index 150798fe9765..b740542818f3 160000 --- a/third_party/cudnn_frontend +++ b/third_party/cudnn_frontend @@ -1 +1 @@ -Subproject commit 150798fe976556078f443fdb059a1ff0361f58a2 +Subproject commit b740542818f36857acf7f9853f749bbad4118c65 diff --git a/third_party/ios-cmake b/third_party/ios-cmake deleted file mode 160000 index 8abaed637d56..000000000000 --- a/third_party/ios-cmake +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8abaed637d56f1337d6e1d2c4026e25c1eade724 diff --git a/third_party/kineto b/third_party/kineto index be1317644c68..8681ff11e1fa 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit be1317644c68b4bfc4646024a6b221066e430031 +Subproject commit 8681ff11e1fa54da39023076c5c43eddd87b7a8a diff --git a/third_party/xpu.txt b/third_party/xpu.txt index d3e312dadded..07950b62467f 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -aba5d332bb88d422a1256bb2ca5f60243ffc270f +97d692eb8c4b3afab17700a2fd918adcea0cba45 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 4922513f295d..81bd19b8e185 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2820,13 +2820,13 @@ output_differentiability: [True, False, False, False, False] query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right) -- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? 
window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) -- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) - output_differentiability: [True, False, False, False, False, False, False, False, False] - query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) +- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset) + output_differentiability: [True, False, False, False] + query, key, value: _scaled_dot_product_cudnn_attention_backward(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, dropout_p, is_causal, scale) # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index b9651ea2da80..6abb13d244e9 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -305,6 +305,7 @@ "linalg_eig", "diagonal_copy", "diagonal_scatter", + "alias_copy", "select_backward", "diagonal_backward", "slice_backward", diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 3e3e06d54115..4d10b3db1aa3 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -13,7 +13,6 @@ from . import which from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file from .env import BUILD_DIR, check_negative_env_flag, IS_64BIT, IS_DARWIN, IS_WINDOWS -from .numpy_ import NUMPY_INCLUDE_DIR, USE_NUMPY def _mkdir_p(d: str) -> None: @@ -285,7 +284,7 @@ def generate( "BUILD_TEST": build_test, # Most library detection should go to CMake script, except this one, which Python can do a much better job # due to NumPy's inherent Pythonic nature. 
- "USE_NUMPY": USE_NUMPY, + "USE_NUMPY": not check_negative_env_flag("USE_NUMPY"), } ) @@ -309,7 +308,6 @@ def generate( args, Python_EXECUTABLE=sys.executable, TORCH_BUILD_VERSION=version, - NUMPY_INCLUDE_DIR=NUMPY_INCLUDE_DIR, **build_options, ) diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index d87e97a2bb5a..eed5198ca9f2 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -21,6 +21,8 @@ BUILD_DIR = "build" +LIBTORCH_PKG_NAME = "libtorchsplit" + def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] diff --git a/tools/setup_helpers/numpy_.py b/tools/setup_helpers/numpy_.py deleted file mode 100644 index e93fcfd24707..000000000000 --- a/tools/setup_helpers/numpy_.py +++ /dev/null @@ -1,24 +0,0 @@ -"""NumPy helper. - -Note: If you plan to add a library detection script like this one, consider it twice. Most library detection should go -to CMake script. This one is an exception, because Python code can do a much better job due to NumPy's inherent Pythonic -nature. -""" - -from .env import check_negative_env_flag - - -# Set USE_NUMPY to what the user wants, because even if we fail here, cmake -# will check for the presence of NumPy again (`cmake/Dependencies.cmake`). -USE_NUMPY = not check_negative_env_flag("USE_NUMPY") -NUMPY_INCLUDE_DIR = None - -if USE_NUMPY: - try: - import numpy as np - except ImportError: - pass - else: - # To reach here, the user must has not disabled NumPy build and the - # NumPy library is present in the system. - NUMPY_INCLUDE_DIR = np.get_include() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index d212b17e0e8e..10a44af747be 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -68,6 +68,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES ${TORCH_ROOT}/third_party/onnx ${TORCH_ROOT}/third_party/flatbuffers/include ${TORCH_ROOT}/third_party/kineto/libkineto/include + ${TORCH_ROOT}/third_party/cpp-httplib ${TORCH_SRC_DIR}/csrc ${TORCH_SRC_DIR}/csrc/api/include @@ -77,9 +78,10 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${LIBSHM_SRCDIR}) set(TORCH_PYTHON_LINK_LIBRARIES - python::python + Python::Module pybind::pybind11 opentelemetry::api + httplib shm fmt::fmt-header-only ATEN_CPU_FILES_GEN_LIB) @@ -296,11 +298,10 @@ endif() add_library(torch_python SHARED ${TORCH_PYTHON_SRCS}) -add_dependencies(torch_python Caffe2_PROTO) add_dependencies(torch_python onnx_proto) # Avoid numpy for the DEPLOY build if(USE_NUMPY) - target_link_libraries(torch_python PRIVATE numpy::numpy) + target_link_libraries(torch_python PRIVATE Python::NumPy) target_compile_definitions(torch_python PRIVATE USE_NUMPY) endif() diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 24f9f0f9e9fb..9476acb75791 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,5 +1,6 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs import builtins from typing import ( diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index d4dbee20466e..135ba3c27757 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1,5 +1,6 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs import builtins from enum import Enum, IntEnum @@ -1195,8 +1196,7 @@ def _conv_determine_backend_memory_format( def _has_storage(x: Tensor) -> _bool: ... 
def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... -def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... -def _check_tp_alloc_is_default(cls: Type) -> _bool: ... +def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ @@ -1749,6 +1749,7 @@ def _mps_emptyCache() -> None: ... def _mps_setMemoryFraction(fraction: _float) -> None: ... def _mps_currentAllocatedMemory() -> _int: ... def _mps_driverAllocatedMemory() -> _int: ... +def _mps_recommendedMaxMemory() -> _int: ... def _mps_is_available() -> _bool: ... def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ... def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ... @@ -1885,6 +1886,21 @@ def _nccl_reduce_scatter( comms: Optional[Sequence[object]], ) -> None: ... def _rocm_is_backward_pass() -> _bool: ... +def _cuda_tunableop_enable(val: _bool) -> None: ... +def _cuda_tunableop_is_enabled() -> _bool: ... +def _cuda_tunableop_tuning_enable(val: _bool) -> None: ... +def _cuda_tunableop_tuning_is_enabled() -> _bool: ... +def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_duration() -> _int: ... +def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_iterations() -> _int: ... +def _cuda_tunableop_set_filename(filename: str, insert_device_ordinal: Optional[_bool]) -> None: ... +def _cuda_tunableop_get_filename() -> str: ... +def _cuda_tunableop_write_file(filename: Optional[str]) -> _bool: ... +def _cuda_tunableop_read_file(filename: Optional[str]) -> _bool: ... +def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ... +def _cuda_tunableop_get_results() -> Tuple[str, str, str, _float]: ... +def _cuda_tunableop_get_validators() -> Tuple[str, str]: ... class _CudaDeviceProperties: name: str @@ -1921,6 +1937,7 @@ class _SDPBackend(Enum): EFFICIENT_ATTENTION = 2 CUDNN_ATTENTION = 3 +def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... @@ -2317,3 +2334,14 @@ def _save_pickle(obj: Any) -> bytes: ... # Defined in torch/csrc/jit/runtime/static/init.cpp def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ... def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ... + +# Defined in torch/csrc/fx/node.cpp +class _NodeBase: + _erased: _bool + _prev: "_NodeBase" + _next: "_NodeBase" + +class _NodeIter(Iterator): + def __init__(self, root: _NodeBase, reversed: _bool) -> None: ... + def __iter__(self) -> Iterator[_NodeBase]: ... + def __next__(self) -> _NodeBase: ... 
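The stub additions above declare the TunableOp control bindings (_cuda_tunableop_*), _can_use_cudnn_attention, and the FX _NodeBase/_NodeIter types. As a rough, non-authoritative sketch of how the newly typed TunableOp entry points compose, assuming a CUDA/ROCm build where these private bindings are registered; the iteration count and filename are invented for illustration, and in practice these functions are normally reached through a public wrapper rather than called directly:

import torch

# Turn TunableOp on and allow it to tune new solutions (signatures per the stubs above).
torch._C._cuda_tunableop_enable(True)
torch._C._cuda_tunableop_tuning_enable(True)
torch._C._cuda_tunableop_set_max_tuning_iterations(10)  # illustrative value
assert torch._C._cuda_tunableop_is_enabled()

# Persist results and inspect the recorded validators; the filename is illustrative only.
torch._C._cuda_tunableop_write_file("tunableop_results.csv")
print(torch._C._cuda_tunableop_get_validators())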
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 118d913f6815..05a791725608 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Any, Callable, List, Optional, Set diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 075fecf45d5a..37794bd7c10b 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -2,4 +2,6 @@ from torch.types import _bool # Defined in torch/csrc/cpu/Module.cpp -def _is_cpu_support_vnni() -> _bool: ... +def _is_cpu_support_avx2() -> _bool: ... +def _is_cpu_support_avx512() -> _bool: ... +def _is_cpu_support_avx512_vnni() -> _bool: ... diff --git a/torch/_C/_distributed_autograd.pyi b/torch/_C/_distributed_autograd.pyi index f4c91304a1b1..dc2a9e9488a9 100644 --- a/torch/_C/_distributed_autograd.pyi +++ b/torch/_C/_distributed_autograd.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Set import torch diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 1a3e4ea63342..cffbf22219c8 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum @@ -94,6 +95,10 @@ class Logger: def _set_uneven_input_join(self) -> None: ... def _set_static_graph(self) -> None: ... +class _WorkerServer: + def __init__(self, socket_path: str) -> None: ... + def shutdown(self) -> None: ... + def get_debug_level(): ... def set_debug_level(): ... def set_debug_level_from_env(): ... @@ -219,7 +224,7 @@ class _ControlCollectives: def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... def scatter_recv(self, key: str, timeout: timedelta) -> str: ... def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... - def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ... + def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ... class _StoreCollectives(_ControlCollectives): def __init__(self, store: Store, rank: int, world_size: int) -> None: ... 
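The _distributed_c10d.pyi hunk above corrects _ControlCollectives.all_sum to take an int payload (it was annotated as str) and declares the new _WorkerServer binding. A minimal sketch of what the corrected typing implies for a caller, assuming a single-rank TCPStore; the host, port, key, and socket path are illustrative only, not taken from the patch:

from datetime import timedelta

import torch.distributed as dist
from torch._C._distributed_c10d import _StoreCollectives, _WorkerServer

store = dist.TCPStore("127.0.0.1", 29500, 1, True)  # illustrative single-rank store
collectives = _StoreCollectives(store, 0, 1)  # (store, rank, world_size) per the stub
total = collectives.all_sum("steps", 1, timedelta(seconds=30))  # data is an int, returns the sum

server = _WorkerServer("/tmp/worker.sock")  # illustrative socket path
server.shutdown()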
diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi index 7909e0b8e33c..ded7061bbd49 100644 --- a/torch/_C/_distributed_rpc.pyi +++ b/torch/_C/_distributed_rpc.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from datetime import timedelta from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index f3ad6f722827..14321b2f946f 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from typing import List, NewType, Optional diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 2de2f10cd328..6b1cf00bce41 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Union import torch diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index 111113221a0c..0180586d0bc3 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Optional, Tuple diff --git a/torch/_C/_lazy.pyi b/torch/_C/_lazy.pyi index ceaaedee2102..f4f57ee56b34 100644 --- a/torch/_C/_lazy.pyi +++ b/torch/_C/_lazy.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch import Tensor diff --git a/torch/_C/_lazy_ts_backend.pyi b/torch/_C/_lazy_ts_backend.pyi index ce833c5ec2e4..b5e69583dbb9 100644 --- a/torch/_C/_lazy_ts_backend.pyi +++ b/torch/_C/_lazy_ts_backend.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # defined in torch/csrc/lazy/python/init.cpp from typing import Any, List, Tuple diff --git a/torch/_C/_nvtx.pyi b/torch/_C/_nvtx.pyi index f7ff779d8ad7..79c9cc2c4b9b 100644 --- a/torch/_C/_nvtx.pyi +++ b/torch/_C/_nvtx.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Defined in torch/csrc/cuda/shared/nvtx.cpp def rangePushA(message: str) -> int: ... def rangePop() -> int: ... diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in index 458a076d7bfe..fc1e2974bd4d 100644 --- a/torch/_C/return_types.pyi.in +++ b/torch/_C/return_types.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # ${generated_comment} from typing import ( diff --git a/torch/_C_flatbuffer/__init__.pyi b/torch/_C_flatbuffer/__init__.pyi index 3a2ff059b0ed..38750ed26aa2 100644 --- a/torch/_C_flatbuffer/__init__.pyi +++ b/torch/_C_flatbuffer/__init__.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import LiteScriptModule, ScriptModule def _load_mobile_module_from_file(filename: str): ... diff --git a/torch/__config__.py b/torch/__config__.py index f7e3e209654a..fdb091032759 100644 --- a/torch/__config__.py +++ b/torch/__config__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/__init__.py b/torch/__init__.py index c2bf4a802838..aa68247ed3a8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1,5 +1,4 @@ - -r""" +""" The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. Additionally, it provides many utilities for efficient serialization of @@ -9,14 +8,22 @@ on an NVIDIA GPU with compute capability >= 3.0. 
""" +# mypy: allow-untyped-defs + +import builtins +import ctypes +import glob +import importlib +import importlib.util +import inspect import math import os -import sys import platform +import sys import textwrap -import ctypes -import inspect import threading +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union + # multipy/deploy is setting this import before importing torch, this is the most # reliable way we have to detect if we're running within deploy. @@ -24,139 +31,235 @@ def _running_with_deploy(): return sys.modules.get("torch._meta_registrations", None) is object -from ._utils import _import_dotted_name, classproperty -from ._utils import _functionalize_sync as _sync -from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \ - USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS + +from torch._utils import ( + _functionalize_sync as _sync, + _import_dotted_name, + classproperty, +) +from torch._utils_internal import ( + get_file_path, + prepare_multiprocessing_environment, + USE_GLOBAL_DEPS, + USE_RTLD_GLOBAL_WITH_LIBTORCH, +) # TODO(torch_deploy) figure out how to freeze version.py in fbcode build if _running_with_deploy(): __version__ = "torch-deploy-1.8" else: - from .torch_version import __version__ as __version__ + from torch.torch_version import __version__ as __version__ -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union, List -import builtins __all__ = [ - 'typename', 'is_tensor', 'is_storage', - 'set_default_tensor_type', 'set_default_device', 'get_default_device', - 'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed', - 'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul', - 'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode', - 'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage', - 'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage', - 'TypedStorage', 'UntypedStorage', - 'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor', - 'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor', - 'lobpcg', 'use_deterministic_algorithms', - 'are_deterministic_algorithms_enabled', - 'is_deterministic_algorithms_warn_only_enabled', - 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', - 'set_float32_matmul_precision', 'get_float32_matmul_precision', - 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', - 'SymBool', 'sym_not', 'unravel_index', - 'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap', - 'export', 'autocast', 'cond', 'GradScaler', - 'get_device_module', + "BoolStorage", + "BoolTensor", + "ByteStorage", + "ByteTensor", + "CharStorage", + "CharTensor", + "DoubleStorage", + "DoubleTensor", + "FloatStorage", + "FloatTensor", + "GradScaler", + "IntStorage", + "IntTensor", + "LongStorage", + "LongTensor", + "ShortStorage", + "ShortTensor", + "SymBool", + "SymFloat", + "SymInt", + "Tensor", + "TypedStorage", + "UntypedStorage", + "are_deterministic_algorithms_enabled", + "autocast", + "chunk", + "compile", + "cond", + "enable_grad", + "export", + "get_default_device", + "get_deterministic_debug_mode", + "get_device_module", + "get_float32_matmul_precision", + "get_rng_state", + "inference_mode", + "initial_seed", + "is_deterministic_algorithms_warn_only_enabled", + "is_storage", + "is_tensor", + "is_warn_always_enabled", + "load", + "lobpcg", + "manual_seed", + "matmul", + "no_grad", + "rand", + "randn", + "save", + "seed", + "set_default_device", + 
"set_default_tensor_type", + "set_deterministic_debug_mode", + "set_float32_matmul_precision", + "set_printoptions", + "set_rng_state", + "set_warn_always", + "split", + "stack", + "sym_float", + "sym_int", + "sym_ite", + "sym_max", + "sym_min", + "sym_not", + "typename", + "unravel_index", + "use_deterministic_algorithms", + "vmap", ] +# Please keep this list sorted +assert __all__ == sorted(__all__) + ################################################################################ # Load the extension module ################################################################################ -if sys.platform == 'win32': - import sysconfig - pfiles_path = os.getenv('ProgramFiles', 'C:\\Program Files') - py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') - th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') - usebase_path = os.path.join(sysconfig.get_config_var("userbase"), 'Library', 'bin') - - # When users create a virtualenv that inherits the base environment, - # we will need to add the corresponding library directory into - # DLL search directories. Otherwise, it will rely on `PATH` which - # is dependent on user settings. - if sys.exec_prefix != sys.base_exec_prefix: - base_py_dll_path = os.path.join(sys.base_exec_prefix, 'Library', 'bin') - else: - base_py_dll_path = '' +if sys.platform == "win32": - dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, base_py_dll_path, usebase_path])) + def _load_dll_libraries(): + import sysconfig - if all(not os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths): - nvtoolsext_dll_path = os.path.join( - os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64') - else: - nvtoolsext_dll_path = '' - - from .version import cuda as cuda_version - import glob - if cuda_version and all(not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths): - cuda_version_1 = cuda_version.replace('.', '_') - cuda_path_var = 'CUDA_PATH_V' + cuda_version_1 - default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version) - cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin') - else: - cuda_path = '' + from torch.version import cuda as cuda_version - dll_paths.extend(filter(os.path.exists, [nvtoolsext_dll_path, cuda_path])) + pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files") + py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin") + th_dll_path = os.path.join(os.path.dirname(__file__), "lib") + usebase_path = os.path.join( + sysconfig.get_config_var("userbase"), "Library", "bin" + ) - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') - prev_error_mode = kernel32.SetErrorMode(0x0001) + # When users create a virtualenv that inherits the base environment, + # we will need to add the corresponding library directory into + # DLL search directories. Otherwise, it will rely on `PATH` which + # is dependent on user settings. 
+ if sys.exec_prefix != sys.base_exec_prefix: + base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin") + else: + base_py_dll_path = "" + + dll_paths = [ + p + for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) + if os.path.exists(p) + ] + + if not builtins.any( + os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths + ): + nvtoolsext_dll_path = os.path.join( + os.getenv( + "NVTOOLSEXT_PATH", + os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"), + ), + "bin", + "x64", + ) + else: + nvtoolsext_dll_path = "" + + if cuda_version and builtins.all( + not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths + ): + cuda_version_1 = cuda_version.replace(".", "_") + cuda_path_var = "CUDA_PATH_V" + cuda_version_1 + default_path = os.path.join( + pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}" + ) + cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin") + else: + cuda_path = "" - kernel32.LoadLibraryW.restype = ctypes.c_void_p - if with_load_library_flags: - kernel32.LoadLibraryExW.restype = ctypes.c_void_p + dll_paths.extend( + p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p) + ) - for dll_path in dll_paths: - os.add_dll_directory(dll_path) + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) - try: - ctypes.CDLL('vcruntime140.dll') - ctypes.CDLL('msvcp140.dll') - ctypes.CDLL('vcruntime140_1.dll') - except OSError: - print('''Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. - It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe''') - - dlls = glob.glob(os.path.join(th_dll_path, '*.dll')) - path_patched = False - for dll in dlls: - is_loaded = False + kernel32.LoadLibraryW.restype = ctypes.c_void_p if with_load_library_flags: - res = kernel32.LoadLibraryExW(dll, None, 0x00001100) - last_error = ctypes.get_last_error() - if res is None and last_error != 126: - err = ctypes.WinError(last_error) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' - raise err - elif res is not None: - is_loaded = True - if not is_loaded: - if not path_patched: - os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']]) - path_patched = True - res = kernel32.LoadLibraryW(dll) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error loading "{dll}" or one of its dependencies.' - raise err - - kernel32.SetErrorMode(prev_error_mode) + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + for dll_path in dll_paths: + os.add_dll_directory(dll_path) + + try: + ctypes.CDLL("vcruntime140.dll") + ctypes.CDLL("msvcp140.dll") + ctypes.CDLL("vcruntime140_1.dll") + except OSError: + print( + textwrap.dedent( + """ + Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. + It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe + """ + ).strip() + ) + + dlls = glob.glob(os.path.join(th_dll_path, "*.dll")) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' 
+ ) + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += ( + f' Error loading "{dll}" or one of its dependencies.' + ) + raise err + + kernel32.SetErrorMode(prev_error_mode) + + _load_dll_libraries() + del _load_dll_libraries def _preload_cuda_deps(lib_folder, lib_name): """Preloads cuda deps if they could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed - assert platform.system() == 'Linux', 'Should only be called on Linux' - import glob + assert platform.system() == "Linux", "Should only be called on Linux" + lib_path = None for path in sys.path: - nvidia_path = os.path.join(path, 'nvidia') + nvidia_path = os.path.join(path, "nvidia") if not os.path.exists(nvidia_path): continue - candidate_lib_paths = glob.glob(os.path.join(nvidia_path, lib_folder, 'lib', lib_name)) + candidate_lib_paths = glob.glob( + os.path.join(nvidia_path, lib_folder, "lib", lib_name) + ) if candidate_lib_paths and not lib_path: lib_path = candidate_lib_paths[0] if lib_path: @@ -168,41 +271,87 @@ def _preload_cuda_deps(lib_folder, lib_name): # See Note [Global dependencies] def _load_global_deps() -> None: - if _running_with_deploy() or platform.system() == 'Windows': + LIBTORCH_PKG_NAME = "libtorchsplit" + + def find_package_path(package_name): + spec = importlib.util.find_spec(package_name) + if spec: + # The package might be a namespace package, so get_data may fail + try: + loader = spec.loader + if loader is not None: + file_path = loader.get_filename() # type: ignore[attr-defined] + return os.path.dirname(file_path) + except AttributeError: + pass + return None + + def load_shared_libraries(library_path): + lib_dir = os.path.join(library_path, "lib") + if not os.path.exists(lib_dir): + return + + # Find all shared library files with the appropriate extension + library_files = [f for f in os.listdir(lib_dir) if f.endswith(lib_ext)] + if not library_files: + return + + for lib_file in library_files: + lib_path = os.path.join(lib_dir, lib_file) + try: + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + except OSError as err: + print(f"Failed to load {lib_path}: {err}") + + if _running_with_deploy() or platform.system() == "Windows": return - lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') + # Determine the file extension based on the platform + lib_ext = ".dylib" if platform.system() == "Darwin" else ".so" + lib_name = f"libtorch_global_deps{lib_ext}" here = os.path.abspath(__file__) - lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) + global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) + + split_build_lib_name = LIBTORCH_PKG_NAME + library_path = find_package_path(split_build_lib_name) + if library_path: + global_deps_lib_path = os.path.join(library_path, "lib", lib_name) try: - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is cuda_libs: Dict[str, str] = { - 'cublas': 'libcublas.so.*[0-9]', - 'cudnn': 'libcudnn.so.*[0-9]', - 'cuda_nvrtc': 'libnvrtc.so.*[0-9]', - 'cuda_runtime': 'libcudart.so.*[0-9]', - 'cuda_cupti': 'libcupti.so.*[0-9]', - 'cufft': 
'libcufft.so.*[0-9]', - 'curand': 'libcurand.so.*[0-9]', - 'cusolver': 'libcusolver.so.*[0-9]', - 'cusparse': 'libcusparse.so.*[0-9]', - 'nccl': 'libnccl.so.*[0-9]', - 'nvtx': 'libnvToolsExt.so.*[0-9]', + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + "nvtx": "libnvToolsExt.so.*[0-9]", } - is_cuda_lib_err = [lib for lib in cuda_libs.values() if lib.split('.')[0] in err.args[0]] + is_cuda_lib_err = [ + lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] + ] if not is_cuda_lib_err: raise err for lib_folder, lib_name in cuda_libs.items(): _preload_cuda_deps(lib_folder, lib_name) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) + + if library_path: + # loading libtorch_global_deps first due its special logic + load_shared_libraries(library_path) -if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ - (_running_with_deploy() or platform.system() != 'Windows'): +if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( + _running_with_deploy() or platform.system() != "Windows" +): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: # @@ -221,7 +370,9 @@ def _load_global_deps() -> None: # old_flags = sys.getdlopenflags() sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) + from torch._C import * # noqa: F403 + sys.setdlopenflags(old_flags) del old_flags @@ -239,10 +390,6 @@ def _load_global_deps() -> None: _load_global_deps() from torch._C import * # noqa: F403 -# Appease the type checker; ordinarily this binding is inserted by the -# torch._C module initialization code in C -if TYPE_CHECKING: - from . import _C as _C class SymInt: """ @@ -267,38 +414,125 @@ def __index__(self): # Magic methods installed by torch.fx.experimental.sym_node + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. 
Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. */ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + def __eq__(self, other: object) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __lt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __gt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __le__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __ge__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __add__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __mul__(self, other) -> "SymInt": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") + + def __pow_by_natural__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") def __sym_max__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_min__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_float__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __neg__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return str(self.node) @@ -310,6 +544,7 @@ def __hash__(self) -> builtins.int: # We could support constant SymInts as well, but not doing it for now raise TypeError("unhashable type: non-nested SymInt") + class SymFloat: """ Like 
an float (including magic methods), but redirects all operations on the @@ -322,45 +557,92 @@ def __init__(self, node): # class has a field named node that stores SymNode self.node = node + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(sym_float(other) / self)) + def __bool__(self): return self.node.bool_() + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __lt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __gt__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __le__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __ge__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") + + def __float_pow__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise TypeError("type stub not overridden") def __trunc__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_max__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_min__(self, other): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_int__(self): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def is_integer(self): """Return True if the float is an integer.""" - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return self.node.str() + class SymBool: """ Like an bool (including magic methods), but redirects all operations on the @@ -384,10 +666,10 @@ def __int__(self): # Magic methods installed by torch.fx.experimental.sym_node def 
__and__(self, other) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __or__(self, other) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") # We very carefully define __sym_not__, and not a number of other # plausible alternatives: @@ -407,13 +689,13 @@ def __or__(self, other) -> "SymBool": # so we reuse the conventional operators there for readability. # def __sym_not__(self) -> "SymBool": - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __sym_ite__(self, then_val, else_val): - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __eq__(self, other) -> builtins.bool: - raise AssertionError("type stub not overridden") + raise TypeError("type stub not overridden") def __repr__(self): return str(self.node) @@ -424,121 +706,140 @@ def __hash__(self): else: raise TypeError("unhashable type: SymBool") + def sym_not(a): - r""" SymInt-aware utility for logical negation. + r"""SymInt-aware utility for logical negation. Args: a (SymBool or bool): Object to negate """ import sympy - from .overrides import has_torch_function_unary, handle_torch_function - if has_torch_function_unary(a): - return handle_torch_function(sym_not, (a,), a) - if hasattr(a, '__sym_not__'): + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_not, (a,), a) + if hasattr(a, "__sym_not__"): return a.__sym_not__() if isinstance(a, sympy.Basic): return ~a # type: ignore[operator] return not a + def sym_float(a): - r""" SymInt-aware utility for float casting. + r"""SymInt-aware utility for float casting. Args: a (SymInt, SymFloat, or object): Object to cast """ - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(sym_float, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_float, (a,), a) if isinstance(a, SymFloat): return a - elif hasattr(a, '__sym_float__'): + elif hasattr(a, "__sym_float__"): return a.__sym_float__() - return py_float(a) # type: ignore[operator] + return builtins.float(a) # type: ignore[operator] def sym_int(a): - r""" SymInt-aware utility for int casting. + r"""SymInt-aware utility for int casting. Args: a (SymInt, SymFloat, or object): Object to cast """ - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(sym_int, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(sym_int, (a,), a) if isinstance(a, SymInt): return a elif isinstance(a, SymFloat): return math.trunc(a) - return py_int(a) # type: ignore[operator] + return builtins.int(a) # type: ignore[operator] -def sym_max(a, b): - """ SymInt-aware utility for max().""" - from .overrides import has_torch_function, handle_torch_function - if has_torch_function((a, b)): - return handle_torch_function(sym_max, (a, b), a, b) +def sym_max(a, b): + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). 
+ """ + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_max, (a, b), a, b) if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # NB: If you actually care about preserving output type exactly - # if you do something like max(0, 0.0), it is NOT sound to treat - # min/max as commutative + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 return b.__sym_max__(a) - return builtins.max(a, b) # type: ignore[operator] + # TODO: Probably can make bool work too, just lazy + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.max(a, b)) + else: + return builtins.max(a, b) -def sym_min(a, b): - """ SymInt-aware utility for max().""" - from .overrides import has_torch_function, handle_torch_function - if has_torch_function((a, b)): - return handle_torch_function(sym_min, (a, b), a, b) +def sym_min(a, b): + """SymInt-aware utility for min().""" + if overrides.has_torch_function((a, b)): + return overrides.handle_torch_function(sym_min, (a, b), a, b) if isinstance(a, (SymInt, SymFloat)): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - return builtins.min(a, b) # type: ignore[operator] + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.min(a, b)) + else: + return builtins.min(a, b) -# Drop in replacement for math.sqrt, math.sin, math.cos etc -current_module = sys.modules[__name__] +# Drop in replacement for math.sqrt, math.sin, math.cos etc def _get_sym_math_fn(name): def fn(a): - from .overrides import has_torch_function_unary, handle_torch_function - - if has_torch_function_unary(a): - return handle_torch_function(fn, (a,), a) + if overrides.has_torch_function_unary(a): + return overrides.handle_torch_function(fn, (a,), a) if hasattr(a, f"__sym_{name}__"): return getattr(a, f"__sym_{name}__")() return getattr(math, name)(a) return fn -for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"): - sym_name = f"_sym_{name}" - fn = _get_sym_math_fn(name) - fn.__qualname__ = fn.__name__ = sym_name - setattr(current_module, sym_name, fn) + +__fn, __name, __sym_name = None, "", "" +for __name in ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", +): + __sym_name = f"_sym_{__name}" + __fn = _get_sym_math_fn(__name) + __fn.__qualname__ = __fn.__name__ = __sym_name + globals()[__sym_name] = __fn + +del __fn, __name, __sym_name, _get_sym_math_fn # Adding temporary shortcut -sym_sqrt = current_module._sym_sqrt +sym_sqrt = globals()["_sym_sqrt"] __all__.append("sym_sqrt") -del fn, name, sym_name, current_module # type: ignore[possibly-undefined] - def sym_ite(b, t, f): - from .overrides import has_torch_function, handle_torch_function - - if has_torch_function((b, t, f)): - return handle_torch_function(sym_ite, (b, t, f), b, t, f) + if overrides.has_torch_function((b, t, f)): + return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f) assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) if isinstance(b, SymBool): return b.__sym_ite__(t, f) return t if b else f + # Check to see if we 
can load C extensions, and if not provide some guidance # on what the problem might be. try: @@ -549,44 +850,61 @@ def sym_ite(b, t, f): # The __file__ check only works for Python 3.7 and above. if _C_for_compiled_check.__file__ is None: - raise ImportError(textwrap.dedent(''' - Failed to load PyTorch C extensions: - It appears that PyTorch has loaded the `torch/_C` folder - of the PyTorch repository rather than the C extensions which - are expected in the `torch._C` namespace. This can occur when - using the `install` workflow. e.g. - $ python setup.py install && python -c "import torch" - - This error can generally be solved using the `develop` workflow - $ python setup.py develop && python -c "import torch" # This should succeed - or by running Python from a different directory. - ''').strip()) from None + raise ImportError( + textwrap.dedent( + """ + Failed to load PyTorch C extensions: + It appears that PyTorch has loaded the `torch/_C` folder + of the PyTorch repository rather than the C extensions which + are expected in the `torch._C` namespace. This can occur when + using the `install` workflow. e.g. + $ python setup.py install && python -c "import torch" + + This error can generally be solved using the `develop` workflow + $ python setup.py develop && python -c "import torch" # This should succeed + or by running Python from a different directory. + """ + ).strip() + ) from None raise # If __file__ is not None the cause is unknown, so just re-raise. -for name in dir(_C): - if name[0] != '_' and not name.endswith('Base'): - __all__.append(name) - obj = getattr(_C, name) - if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] - if (obj.__module__ != 'torch'): +# The torch._C submodule is already loaded via `from torch._C import *` above +# Make an explicit reference to the _C submodule to appease linters +from torch import _C as _C + +__name, __obj = "", None +for __name in dir(_C): + if __name[0] != "_" and not __name.endswith("Base"): + __all__.append(__name) + __obj = getattr(_C, __name) + if callable(__obj) or inspect.isclass(__obj): + if __obj.__module__ != __name__: # "torch" # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: - obj.__module__ = 'torch' - elif name == 'TensorBase': + if __name not in { + "DisableTorchFunctionSubclass", + "DisableTorchFunction", + "Generator", + }: + __obj.__module__ = __name__ # "torch" + elif __name == "TensorBase": # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. - delattr(sys.modules[__name__], name) + delattr(sys.modules[__name__], __name) + +del __name, __obj if not TYPE_CHECKING: # issue 38137 and python issue 43367. 
Submodules of a C extension are # non-standard, and attributes of those submodules cannot be pickled since # pickle expect to be able to import them as "from _C.sub import attr" # which fails with "_C is not a package - for attr in dir(_C): - candidate = getattr(_C, attr) - if type(candidate) is type(_C): + __name, __candidate = "", None + for __name in dir(_C): + __candidate = getattr(_C, __name) + if type(__candidate) is type(_C): # submodule - if f'torch._C.{attr}' not in sys.modules: - sys.modules[f'torch._C.{attr}'] = candidate + sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate) + + del __name, __candidate ################################################################################ @@ -598,15 +916,19 @@ def typename(o): if isinstance(o, torch.Tensor): return o.type() - module = '' - class_name = '' - if hasattr(o, '__module__') and o.__module__ != 'builtins' \ - and o.__module__ != '__builtin__' and o.__module__ is not None: - module = o.__module__ + '.' - - if hasattr(o, '__qualname__'): + module = "" + class_name = "" + if ( + hasattr(o, "__module__") + and o.__module__ != "builtins" + and o.__module__ != "__builtin__" + and o.__module__ is not None + ): + module = o.__module__ + "." + + if hasattr(o, "__qualname__"): class_name = o.__qualname__ - elif hasattr(o, '__name__'): + elif hasattr(o, "__name__"): class_name = o.__name__ else: class_name = o.__class__.__name__ @@ -717,6 +1039,7 @@ def set_default_device(device): device_context = None else: from torch.utils._device import DeviceContext + device_context = DeviceContext(device) device_context.__enter__() _GLOBAL_DEVICE_CONTEXT.device_context = device_context @@ -774,7 +1097,6 @@ def set_default_dtype(d): Args: d (:class:`torch.dtype`): the floating point dtype to make the default. - Either torch.float32 or torch.float64. Example: >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") @@ -806,8 +1128,13 @@ def set_default_dtype(d): """ _C._set_default_dtype(d) -def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.bool = False) -> None: - r""" Sets whether PyTorch operations must use "deterministic" + +def use_deterministic_algorithms( + mode: builtins.bool, + *, + warn_only: builtins.bool = False, +) -> None: + r"""Sets whether PyTorch operations must use "deterministic" algorithms. That is, algorithms which, given the same input, and when run on the same software and hardware, always produce the same output. When enabled, operations will use deterministic algorithms when available, @@ -943,12 +1270,14 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo """ _C._set_deterministic_algorithms(mode, warn_only=warn_only) + def are_deterministic_algorithms_enabled() -> builtins.bool: r"""Returns True if the global deterministic flag is turned on. Refer to :func:`torch.use_deterministic_algorithms` documentation for more details. """ return _C._get_deterministic_algorithms() + def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: r"""Returns True if the global deterministic flag is set to warn only. Refer to :func:`torch.use_deterministic_algorithms` documentation for more @@ -956,6 +1285,7 @@ def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: """ return _C._get_deterministic_algorithms_warn_only() + def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: r"""Sets the debug mode for deterministic operations. 
@@ -973,19 +1303,20 @@ def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: # NOTE: builtins.int is used here because int in this scope resolves # to torch.int if not isinstance(debug_mode, (builtins.int, str)): - raise TypeError(f'debug_mode must be str or int, but got {type(debug_mode)}') + raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}") if isinstance(debug_mode, str): - if debug_mode == 'default': + if debug_mode == "default": debug_mode = 0 - elif debug_mode == 'warn': + elif debug_mode == "warn": debug_mode = 1 - elif debug_mode == 'error': + elif debug_mode == "error": debug_mode = 2 else: raise RuntimeError( - 'invalid value of debug_mode, expected one of `default`, ' - f'`warn`, `error`, but got {debug_mode}') + "invalid value of debug_mode, expected one of `default`, " + f"`warn`, `error`, but got {debug_mode}" + ) if debug_mode == 0: _C._set_deterministic_algorithms(False) @@ -995,8 +1326,9 @@ def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None: _C._set_deterministic_algorithms(True) else: raise RuntimeError( - 'invalid value of debug_mode, expected 0, 1, or 2, ' - f'but got {debug_mode}') + "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}" + ) + def get_deterministic_debug_mode() -> builtins.int: r"""Returns the current value of the debug mode for deterministic @@ -1012,12 +1344,14 @@ def get_deterministic_debug_mode() -> builtins.int: else: return 0 + def get_float32_matmul_precision() -> builtins.str: r"""Returns the current value of float32 matrix multiplication precision. Refer to :func:`torch.set_float32_matmul_precision` documentation for more details. """ return _C._get_float32_matmul_precision() + def set_float32_matmul_precision(precision: str) -> None: r"""Sets the internal precision of float32 matrix multiplications. @@ -1083,6 +1417,7 @@ def set_float32_matmul_precision(precision: str) -> None: """ _C._set_float32_matmul_precision(precision) + def set_warn_always(b: builtins.bool) -> None: r"""When this flag is False (default) then some PyTorch warnings may only appear once per process. This helps avoid excessive warning information. @@ -1095,12 +1430,14 @@ def set_warn_always(b: builtins.bool) -> None: """ _C._set_warnAlways(b) + def is_warn_always_enabled() -> builtins.bool: r"""Returns True if the global warn_always flag is turned on. Refer to :func:`torch.set_warn_always` documentation for more details. """ return _C._get_warnAlways() + ################################################################################ # Define error checking functions ################################################################################ @@ -1108,11 +1445,17 @@ def is_warn_always_enabled() -> builtins.bool: # These error checking functions must be kept consistent with their C++ # equivalents. Their C++ equivalents are mentioned where applicable. 
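# Illustrative usage sketch for the torch._check* helpers defined below: the
# message argument must be a callable so the string is only built when the
# check actually fails.
import torch

t = torch.randn(3)
torch._check(t.dim() == 1, lambda: f"expected a 1-D tensor, got {t.dim()}-D")  # passes
# torch._check(t.dim() == 2, lambda: "rank mismatch")        # would raise RuntimeError
# torch._check_index(t.numel() > 5, lambda: "out of range")  # would raise IndexError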
-def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]): # noqa: F811 + +def _check_with( + error_type, + cond: Union[builtins.bool, SymBool], + message: Callable[[], str], +): # noqa: F811 if not isinstance(cond, (builtins.bool, torch.SymBool)): - raise TypeError(f'cond must be a bool, but got {type(cond)}') + raise TypeError(f"cond must be a bool, but got {type(cond)}") from torch.fx.experimental.symbolic_shapes import expect_true + if expect_true(cond): return @@ -1121,18 +1464,20 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab if message is None: message_evaluated = ( - 'Expected cond to be True, but got False. (Could this error ' - 'message be improved? If so, please report an enhancement request ' - 'to PyTorch.)') + "Expected cond to be True, but got False. (Could this error " + "message be improved? If so, please report an enhancement request " + "to PyTorch.)" + ) else: if not callable(message): - raise TypeError('message must be a callable') + raise TypeError("message must be a callable") message_evaluated = str(message()) raise error_type(message_evaluated) + def _check(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1150,6 +1495,7 @@ def _check(cond, message=None): # noqa: F811 """ _check_with(RuntimeError, cond, message) + def _check_is_size(i, message=None): """Checks that a given integer is a valid size (i.e., is non-negative). You should use this over _check(i >= 0) because we can use the semantic @@ -1163,8 +1509,10 @@ def _check_is_size(i, message=None): # This is responsible for the expect_true _check(i >= 0, message) from torch.fx.experimental.symbolic_shapes import _advise_is_size + _advise_is_size(i) + def _check_index(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1182,6 +1530,7 @@ def _check_index(cond, message=None): # noqa: F811 """ _check_with(IndexError, cond, message) + def _check_value(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1199,6 +1548,7 @@ def _check_value(cond, message=None): # noqa: F811 """ _check_with(ValueError, cond, message) + def _check_type(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. @@ -1216,6 +1566,7 @@ def _check_type(cond, message=None): # noqa: F811 """ _check_with(TypeError, cond, message) + def _check_not_implemented(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition is False. 
@@ -1233,16 +1584,17 @@ def _check_not_implemented(cond, message=None): # noqa: F811 """ _check_with(NotImplementedError, cond, message) + def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 if not torch.is_tensor(cond): - raise TypeError(f'cond must be a tensor, but got {type(cond)}') + raise TypeError(f"cond must be a tensor, but got {type(cond)}") if not cond.dtype == torch.bool: - raise TypeError( - f'cond tensor must have dtype torch.bool, but got {cond.dtype}') + raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}") _check_with(error_type, cond._is_all_true().item(), message) + # C++ equivalent: `TORCH_CHECK_TENSOR_ALL` def _check_tensor_all(cond, message=None): # noqa: F811 r"""Throws error containing an optional message if the specified condition @@ -1262,26 +1614,39 @@ def _check_tensor_all(cond, message=None): # noqa: F811 """ _check_tensor_all_with(RuntimeError, cond, message) + ################################################################################ # Define numeric constants ################################################################################ # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and # NumPy consistency (https://numpy.org/devdocs/reference/constants.html) -from math import e , nan , inf , pi +from math import e, inf, nan, pi + newaxis: None = None -__all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis']) + +__all__.extend(["e", "pi", "nan", "inf", "newaxis"]) ################################################################################ # Define Storage and Tensor classes ################################################################################ -from ._tensor import Tensor -from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal +from torch._tensor import Tensor # usort: skip + +# needs to be after torch.Tensor is defined to avoid circular dependencies +from torch import storage as storage # usort: skip +from torch.storage import ( + _LegacyStorage, + _StorageBase, + _warn_typed_storage_removal, + TypedStorage, + UntypedStorage, +) # NOTE: New Storage classes should never be added. When adding a new # dtype, use torch.storage.TypedStorage directly. 
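# The legacy storage classes below all follow one pattern: a `classproperty`
# descriptor so that `dtype` can be read off the class without instantiating a
# storage. A minimal self-contained sketch of that descriptor pattern
# (illustrative names, not PyTorch's actual implementation):
class _classproperty:
    def __init__(self, fget):
        self.fget = fget

    def __get__(self, instance, owner=None):
        return self.fget(owner if instance is None else type(instance))


class _FakeHalfStorage:
    @_classproperty
    def dtype(cls):
        return "float16"  # stand-in for torch.half


assert _FakeHalfStorage.dtype == "float16"  # readable with no instance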
+ class ByteStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1292,6 +1657,7 @@ def dtype(self): def _dtype(self): return torch.uint8 + class DoubleStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1302,6 +1668,7 @@ def dtype(self): def _dtype(self): return torch.double + class FloatStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1312,6 +1679,7 @@ def dtype(self): def _dtype(self): return torch.float + class HalfStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1322,6 +1690,7 @@ def dtype(self): def _dtype(self): return torch.half + class LongStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1332,6 +1701,7 @@ def dtype(self): def _dtype(self): return torch.long + class IntStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1342,6 +1712,7 @@ def dtype(self): def _dtype(self): return torch.int + class ShortStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1352,6 +1723,7 @@ def dtype(self): def _dtype(self): return torch.short + class CharStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1362,6 +1734,7 @@ def dtype(self): def _dtype(self): return torch.int8 + class BoolStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1372,6 +1745,7 @@ def dtype(self): def _dtype(self): return torch.bool + class BFloat16Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1382,6 +1756,7 @@ def dtype(self): def _dtype(self): return torch.bfloat16 + class ComplexDoubleStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1392,6 +1767,7 @@ def dtype(self): def _dtype(self): return torch.cdouble + class ComplexFloatStorage(_LegacyStorage): @classproperty def dtype(self): @@ -1402,6 +1778,7 @@ def dtype(self): def _dtype(self): return torch.cfloat + class QUInt8Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1412,6 +1789,7 @@ def dtype(self): def _dtype(self): return torch.quint8 + class QInt8Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1422,6 +1800,7 @@ def dtype(self): def _dtype(self): return torch.qint8 + class QInt32Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1432,6 +1811,7 @@ def dtype(self): def _dtype(self): return torch.qint32 + class QUInt4x2Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1442,6 +1822,7 @@ def dtype(self): def _dtype(self): return torch.quint4x2 + class QUInt2x4Storage(_LegacyStorage): @classproperty def dtype(self): @@ -1452,45 +1833,58 @@ def dtype(self): def _dtype(self): return torch.quint2x4 + _storage_classes = { - UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage, - ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage, - QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage, - ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage, - TypedStorage + UntypedStorage, + DoubleStorage, + FloatStorage, + LongStorage, + IntStorage, + ShortStorage, + CharStorage, + ByteStorage, + HalfStorage, + BoolStorage, + QUInt8Storage, + QInt8Storage, + QInt32Storage, + BFloat16Storage, + ComplexFloatStorage, + ComplexDoubleStorage, + QUInt4x2Storage, + QUInt2x4Storage, + TypedStorage, } # The _tensor_classes set is initialized by the call to initialize_python_bindings. 
_tensor_classes: Set[Type] = set() # If you edit these imports, please update torch/__init__.py.in as well -from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed -from .serialization import save, load -from ._tensor_str import set_printoptions +from torch import amp as amp, random as random, serialization as serialization +from torch._tensor_str import set_printoptions +from torch.amp import autocast, GradScaler +from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state +from torch.serialization import load, save ################################################################################ # Initialize extension ################################################################################ -def manager_path(): - if _running_with_deploy() or platform.system() == 'Windows': + +# Shared memory manager needs to know the exact location of manager executable +def _manager_path(): + if _running_with_deploy() or platform.system() == "Windows": return b"" - path = get_file_path('torch', 'bin', 'torch_shm_manager') - prepare_multiprocessing_environment(get_file_path('torch')) + path = get_file_path("torch", "bin", "torch_shm_manager") + prepare_multiprocessing_environment(get_file_path("torch")) if not os.path.exists(path): raise RuntimeError("Unable to find torch_shm_manager at " + path) - return path.encode('utf-8') + return path.encode("utf-8") -from torch.amp import autocast, GradScaler -# Initializing the extension shadows the built-in python float / int classes; -# store them for later use by SymInt / SymFloat. -py_float = float -py_int = int +_C._initExtension(_manager_path()) -# Shared memory manager needs to know the exact location of manager executable -_C._initExtension(manager_path()) -del manager_path +del _manager_path # Appease the type checker: it can't deal with direct setting of globals(). # Note that we will see "too many" functions when reexporting this way; there @@ -1501,30 +1895,31 @@ def manager_path(): # signatures already imported. For now these clashes are ignored; see # PR #43339 for details. from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 + # Fixup segment_reduce visibility _segment_reduce = segment_reduce del segment_reduce # noqa: F821 # Ops not to be exposed in `torch` namespace, # mostly helper ops. 
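# The loop a few lines below copies every public op from _C._VariableFunctions
# into the torch namespace, hiding selected helpers behind a leading
# underscore. A stripped-down sketch of that pattern (the namespace and op
# names here are made up for illustration):
import types

_fake_ops = types.SimpleNamespace(add=lambda a, b: a + b, unique_dim=lambda x: x)
_private_ops = {"unique_dim"}

_exported = {}
for _n in vars(_fake_ops):
    _obj = getattr(_fake_ops, _n)
    if _n in _private_ops:
        _n = "_" + _n  # still reachable, but kept out of the public surface
    _exported[_n] = _obj

assert "add" in _exported and "_unique_dim" in _exported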
-PRIVATE_OPS = ( - 'unique_dim', -) +PRIVATE_OPS = ("unique_dim",) -for name in dir(_C._VariableFunctions): - if name.startswith('__') or name in PRIVATE_OPS: +__name, __obj = "", None +for __name in dir(_C._VariableFunctions): + if __name.startswith("__") or __name in PRIVATE_OPS: continue - obj = getattr(_C._VariableFunctions, name) - obj.__module__ = 'torch' + __obj = getattr(_C._VariableFunctions, __name) + __obj.__module__ = __name__ # "torch" # Hide some APIs that should not be public - if name == "segment_reduce": + if __name == "segment_reduce": # TODO: Once the undocumented FC window is passed, remove the line bellow - globals()[name] = obj - name = "_" + name - globals()[name] = obj - if not name.startswith("_"): - __all__.append(name) + globals()[__name] = __obj + __name = "_" + __name + globals()[__name] = __obj + if not __name.startswith("_"): + __all__.append(__name) +del __name, __obj ################################################################################ # Add torch.dtype instances to the public API @@ -1532,24 +1927,24 @@ def manager_path(): import torch -for attribute in dir(torch): - if isinstance(getattr(torch, attribute), torch.dtype): - __all__.append(attribute) +__all__.extend( + name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype) +) ################################################################################ # Import TorchDynamo's lazy APIs to avoid circular dependenices ################################################################################ -# needs to be before from .functional import * to avoid circular dependencies -from ._compile import _disable_dynamo +# needs to be before from torch.functional import * to avoid circular dependencies +from torch._compile import _disable_dynamo # usort: skip ################################################################################ # Import interface functions defined in Python ################################################################################ # needs to be after the above ATen bindings so we can overwrite from Python side -from .functional import * # noqa: F403 - +from torch import functional as functional # usort: skip +from torch.functional import * # usort: skip # noqa: F403 ################################################################################ # Remove unnecessary members @@ -1562,16 +1957,19 @@ def manager_path(): # Define _assert ################################################################################ + # needs to be before the submodule imports to avoid circular dependencies def _assert(condition, message): - r"""A wrapper around Python's assert which is symbolically traceable. - """ - from .overrides import has_torch_function, handle_torch_function - - if type(condition) is not torch.Tensor and has_torch_function((condition,)): - return handle_torch_function(_assert, (condition,), condition, message) + r"""A wrapper around Python's assert which is symbolically traceable.""" + if type(condition) is not torch.Tensor and overrides.has_torch_function( + (condition,) + ): + return overrides.handle_torch_function( + _assert, (condition,), condition, message + ) assert condition, message + ################################################################################ # Import most common subpackages ################################################################################ @@ -1579,55 +1977,62 @@ def _assert(condition, message): # Use the redundant form so that type checkers know that these are a part of # the public API. 
The "regular" import lines are there solely for the runtime # side effect of adding to the imported module's members for other users. -from torch import cuda as cuda -from torch import cpu as cpu -from torch import mps as mps -from torch import xpu as xpu -from torch import mtia as mtia -from torch import autograd as autograd -from torch.autograd import ( - no_grad as no_grad, + +# needs to be before import torch.nn as nn to avoid circular dependencies +from torch.autograd import ( # usort: skip enable_grad as enable_grad, - set_grad_enabled as set_grad_enabled, inference_mode as inference_mode, + no_grad as no_grad, + set_grad_enabled as set_grad_enabled, ) -from torch import fft as fft -from torch import futures as futures -from torch import _awaits as _awaits -from torch import nested as nested -from torch import nn as nn -from torch.signal import windows as windows -from torch import optim as optim -import torch.optim._multi_tensor -from torch import multiprocessing as multiprocessing -from torch import sparse as sparse -from torch import special as special + import torch.utils.backcompat -from torch import jit as jit -from torch import linalg as linalg -from torch import hub as hub -from torch import random as random -from torch import distributions as distributions -from torch import testing as testing -from torch import backends as backends import torch.utils.data -from torch import __config__ as __config__ -from torch import __future__ as __future__ -from torch import profiler as profiler +from torch import ( + __config__ as __config__, + __future__ as __future__, + _awaits as _awaits, + autograd as autograd, + backends as backends, + cpu as cpu, + cuda as cuda, + distributions as distributions, + fft as fft, + futures as futures, + hub as hub, + jit as jit, + linalg as linalg, + mps as mps, + mtia as mtia, + multiprocessing as multiprocessing, + nested as nested, + nn as nn, + optim as optim, + overrides as overrides, + profiler as profiler, + sparse as sparse, + special as special, + testing as testing, + types as types, + xpu as xpu, +) +from torch.signal import windows as windows # Quantized, sparse, AO, etc. should be last to get imported, as nothing # is expected to depend on them. -from torch import ao as ao +from torch import ao as ao # usort: skip + # nn.quant* depends on ao -- so should be after those. +import torch.nn.intrinsic +import torch.nn.qat import torch.nn.quantizable import torch.nn.quantized -import torch.nn.qat -import torch.nn.intrinsic -_C._init_names(list(torch._storage_classes)) +_C._init_names(list(_storage_classes)) # attach docstrings to torch and tensor functions -from . 
import _torch_docs, _tensor_docs, _storage_docs, _size_docs +from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs + del _torch_docs, _tensor_docs, _storage_docs, _size_docs @@ -1636,53 +2041,56 @@ def compiled_with_cxx11_abi() -> builtins.bool: return _C._GLIBCXX_USE_CXX11_ABI -# Import the ops "namespace" -from torch._ops import ops -from torch._classes import classes import torch._library -# quantization depends on torch.fx +# Import the ops "namespace" +from torch._classes import classes as classes +from torch._ops import ops as ops # usort: skip + +# quantization depends on torch.fx and torch.ops # Import quantization -from torch import quantization as quantization +from torch import quantization as quantization # usort: skip # Import the quasi random sampler -from torch import quasirandom as quasirandom +from torch import quasirandom as quasirandom # usort: skip # If you are seeing this, it means that this call site was not checked if # the memory format could be preserved, and it was switched to old default # behaviour of contiguous -legacy_contiguous_format = contiguous_format +legacy_contiguous_format = contiguous_format # defined by _C._initExtension() # Register fork handler to initialize OpenMP in child processes (see gh-28389) from torch.multiprocessing._atfork import register_after_fork + register_after_fork(torch.get_num_threads) del register_after_fork # Import tools that require fully imported torch (for applying # torch.jit.script as a decorator, for instance): -from ._lobpcg import lobpcg as lobpcg +from torch._lobpcg import lobpcg as lobpcg # These were previously defined in native_functions.yaml and appeared on the # `torch` namespace, but we moved them to c10 dispatch to facilitate custom # class usage. We add these lines here to preserve backward compatibility. -quantized_lstm = torch.ops.aten.quantized_lstm -quantized_gru = torch.ops.aten.quantized_gru - -from torch.utils.dlpack import from_dlpack, to_dlpack +quantized_lstm = ops.aten.quantized_lstm +quantized_gru = ops.aten.quantized_gru # Import experimental masked operations support. See # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more # information. -from . import masked +from torch import masked as masked # Import removed ops with error message about removal -from ._linalg_utils import ( # type: ignore[misc] - matrix_rank, +from torch._linalg_utils import ( # type: ignore[misc] + _symeig as symeig, eig, - solve, lstsq, + matrix_rank, + solve, ) -from ._linalg_utils import _symeig as symeig # type: ignore[misc] + +from torch.utils.dlpack import from_dlpack, to_dlpack + class _TorchCompileInductorWrapper: compiler_name = "inductor" @@ -1695,6 +2103,7 @@ def __init__(self, mode, options, dynamic): # Stash the compiler_fn to be used for backend match guard. 
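# (Presumably dynamo's backend-match guard compares this stashed callable on
# later calls, so keeping a stable reference to compile_fx here lets a cached
# entry be recognized as "same backend" instead of forcing a recompile.)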
from torch._inductor.compile_fx import compile_fx + self.compiler_fn = compile_fx if self.config.get("triton.cudagraphs", False): os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" @@ -1705,15 +2114,18 @@ def __init__(self, mode, options, dynamic): os.environ["TEARDOWN_CUPTI"] = "0" def __eq__(self, other): - return (isinstance(other, _TorchCompileInductorWrapper) and - self.config == other.config and - self.dynamic == other.dynamic) + return ( + isinstance(other, _TorchCompileInductorWrapper) + and self.config == other.config + and self.dynamic == other.dynamic + ) def apply_mode(self, mode: Optional[str]): if mode is None or mode == "default": pass - elif mode in ("reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"): + elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}: from torch._inductor import list_mode_options + self.apply_options(list_mode_options(mode, self.dynamic)) else: raise RuntimeError( @@ -1725,6 +2137,7 @@ def apply_options(self, options: Optional[Dict[str, Any]]): return from torch._inductor import config + current_config: Dict[str, Any] = config.shallow_copy_dict() for key, val in options.items(): @@ -1748,15 +2161,19 @@ def __call__(self, model_, inputs_): def get_compiler_config(self): from torch._inductor.compile_fx import get_patched_config_dict + return get_patched_config_dict(config_patches=self.config) def reset(self): from torch._inductor import config + if "triton.cudagraphs" in self.config or config.triton.cudagraphs: if self.config.get("triton.cudagraphs", True): from torch._inductor.cudagraph_trees import reset_cudagraph_trees + reset_cudagraph_trees() + class _TorchCompileWrapper: def __init__(self, backend, mode, options, dynamic): from torch._dynamo.backends.registry import lookup_backend @@ -1777,10 +2194,12 @@ def __init__(self, backend, mode, options, dynamic): self.kwargs["options"] = options def __eq__(self, other): - return (isinstance(other, _TorchCompileWrapper) and - self.compiler_fn == other.compiler_fn and - self.kwargs == other.kwargs and - self.dynamic == other.dynamic) + return ( + isinstance(other, _TorchCompileWrapper) + and self.compiler_fn == other.compiler_fn + and self.kwargs == other.kwargs + and self.dynamic == other.dynamic + ) def __call__(self, model_, inputs_): return self.compiler_fn(model_, inputs_, **self.kwargs) @@ -1790,13 +2209,16 @@ def reset(self): self.compiler_fn.reset() -def compile(model: Optional[Callable] = None, *, - fullgraph: builtins.bool = False, - dynamic: Optional[builtins.bool] = None, - backend: Union[str, Callable] = "inductor", - mode: Union[str, None] = None, - options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, - disable: builtins.bool = False) -> Callable: +def compile( + model: Optional[Callable] = None, + *, + fullgraph: builtins.bool = False, + dynamic: Optional[builtins.bool] = None, + backend: Union[str, Callable] = "inductor", + mode: Union[str, None] = None, + options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, + disable: builtins.bool = False, +) -> Callable: """ Optimizes given model/function using TorchDynamo and specified backend. 
If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile` @@ -1887,20 +2309,26 @@ def foo(x): # Decorator mode if model is None: + def fn(model: Callable): if model is None: raise RuntimeError("Model can't be None") - return compile(model, - fullgraph=fullgraph, - dynamic=dynamic, - backend=backend, - mode=mode, - options=options, - disable=disable) + return compile( + model, + fullgraph=fullgraph, + dynamic=dynamic, + backend=backend, + mode=mode, + options=options, + disable=disable, + ) + return fn if mode is not None and options is not None: - raise RuntimeError("Either mode or options can be specified, but both can't be specified at the same time.") + raise RuntimeError( + "Either mode or options can be specified, but both can't be specified at the same time." + ) if mode is None and options is None: mode = "default" if backend == "inductor": @@ -1908,12 +2336,18 @@ def fn(model: Callable): else: backend = _TorchCompileWrapper(backend, mode, options, dynamic) - return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model) + return torch._dynamo.optimize( + backend=backend, + nopython=fullgraph, + dynamic=dynamic, + disable=disable, + )(model) from torch import export as export -from torch._higher_order_ops import cond +from torch._higher_order_ops import cond as cond + def _register_device_module(device_type, module): r"""Register an external runtime module of the specific :attr:`device_type` @@ -1926,20 +2360,23 @@ def _register_device_module(device_type, module): device_type = torch.device(device_type).type m = sys.modules[__name__] if hasattr(m, device_type): - raise RuntimeError(f"The runtime module of '{device_type}' has already " - f"been registered with '{getattr(m, device_type)}'") + raise RuntimeError( + f"The runtime module of '{device_type}' has already " + f"been registered with '{getattr(m, device_type)}'" + ) setattr(m, device_type, module) - torch_module_name = '.'.join([__name__, device_type]) + torch_module_name = ".".join([__name__, device_type]) sys.modules[torch_module_name] = module + # expose return_types -from . import return_types -from . import library +from torch import library as library, return_types as return_types + if not TYPE_CHECKING: - from . import _meta_registrations + from torch import _meta_registrations # Enable CUDA Sanitizer -if 'TORCH_CUDA_SANITIZER' in os.environ: +if "TORCH_CUDA_SANITIZER" in os.environ: import torch.cuda._sanitizer as csan csan.enable_cuda_sanitizer() @@ -1948,7 +2385,7 @@ def _register_device_module(device_type, module): import torch.fx.experimental.sym_node from torch import func as func -from torch.func import vmap +from torch.func import vmap as vmap # Register MPS specific decomps @@ -1983,9 +2420,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): # Import the following modules during type checking to enable code intelligence features, # such as auto-completion in tools like pylance, even when these modules are not explicitly # imported in user code. 
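# The lazy-module table just below builds on PEP 562: a module-level
# __getattr__ that imports heavy submodules only on first attribute access.
# A minimal sketch of the mechanism (the table contents here are illustrative,
# not torch's real lazy modules):
import importlib
import sys

_lazy_table = {"fake_module": "json"}  # attribute name -> module to import


def __getattr__(name):  # PEP 562: called only for attributes not found normally
    if name in _lazy_table:
        mod = importlib.import_module(_lazy_table[name])
        setattr(sys.modules[__name__], name, mod)  # cache so this hook isn't hit again
        return mod
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")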
- from torch import _dynamo as _dynamo - from torch import _inductor as _inductor - from torch import onnx as onnx + from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx else: _lazy_modules = { @@ -2001,16 +2436,20 @@ def __getattr__(name): replacement = _deprecated_attrs.get(name) if replacement is not None: import warnings - warnings.warn(f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", stacklevel=2) + + warnings.warn( + f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", + stacklevel=2, + ) return replacement() # Lazy modules if name in _lazy_modules: - import importlib return importlib.import_module(f".{name}", __name__) raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + def get_device_module(device: Optional[Union[torch.device, str]] = None): """ Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). @@ -2024,7 +2463,9 @@ def get_device_module(device: Optional[Union[torch.device, str]] = None): # Using default accelerator type. If no accelerator is available, it automatically returns CPU device. device_module_name = torch._C._get_accelerator().type else: - raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None") + raise RuntimeError( + f"Invalid value of device '{device}', expect torch.device, str, or None" + ) device_module = getattr(torch, device_module_name, None) if device_module is None: raise RuntimeError( @@ -2033,7 +2474,11 @@ def get_device_module(device: Optional[Union[torch.device, str]] = None): return device_module -def _constrain_as_size(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None): +def _constrain_as_size( + symbol, + min: Optional[builtins.int] = None, + max: Optional[builtins.int] = None, +): """ This indicates that a given int is size-like, and can be used in any context where a size is expected. You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist() @@ -2055,5 +2500,6 @@ def _constrain_as_size(symbol, min: Optional[builtins.int] = None, max: Optional torch.sym_constrain_range_for_size(symbol, min=min, max=max) -from . import _logging +from torch import _logging + _logging._init_logs() diff --git a/torch/_classes.py b/torch/_classes.py index 870073fea6ea..58b347453524 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types import torch._C diff --git a/torch/_compile.py b/torch/_compile.py index 354d64e9ff9f..0f0f51a3509a 100644 --- a/torch/_compile.py +++ b/torch/_compile.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ APIs related to torch.compile which lazily import torch._dynamo to avoid circular dependencies. @@ -19,9 +20,15 @@ def _disable_dynamo(fn=None, recursive=True): @functools.wraps(fn) def inner(*args, **kwargs): - import torch._dynamo + # cache this on the first invocation to avoid adding too much overhead. 
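# (The wrapper returned by torch._dynamo.disable(fn, recursive) is built once
# and stashed on the function as __dynamo_disable, so subsequent calls skip
# both the torch._dynamo import and the re-wrapping; see the lines that
# follow.)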
+ disable_fn = getattr(fn, "__dynamo_disable", None) + if disable_fn is None: + import torch._dynamo - return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + disable_fn = torch._dynamo.disable(fn, recursive) + fn.__dynamo_disable = disable_fn + + return disable_fn(*args, **kwargs) return inner else: diff --git a/torch/_custom_op/autograd.py b/torch/_custom_op/autograd.py index 116a4612a45e..35727197d03c 100644 --- a/torch/_custom_op/autograd.py +++ b/torch/_custom_op/autograd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree from collections import namedtuple diff --git a/torch/_custom_op/functional.py b/torch/_custom_op/functional.py index 26ef5b307bd5..57ff351e2e2d 100644 --- a/torch/_custom_op/functional.py +++ b/torch/_custom_op/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref import torch diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index df83c51bcfd9..2f3efce60a81 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import functools import inspect @@ -836,7 +837,7 @@ def _find_custom_op(qualname, also_check_torch_library=False): return global_registry[qualname] if not also_check_torch_library: raise RuntimeError( - f"Could not find custom op \"{qualname}\". Did you register it via " + f'Could not find custom op "{qualname}". Did you register it via ' f"the torch._custom_ops API?") overload = get_op(qualname) result = custom_op_from_existing(overload) diff --git a/torch/_custom_ops.py b/torch/_custom_ops.py index c09a8ae68543..b8231a186c0a 100644 --- a/torch/_custom_ops.py +++ b/torch/_custom_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from torch._custom_op.impl import ( diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index b277bb7eceb0..7674e5f466a8 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from collections import defaultdict from functools import wraps @@ -260,6 +261,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.addcmul_, aten.addr, aten.affine_grid_generator, + aten.alias_copy, aten.all, aten.aminmax, aten.arange.default, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 5bec539db06c..7c9d342ea0f0 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import numbers import operator @@ -2147,7 +2148,7 @@ def cudnn_batch_norm( def _broadcast_batch_norm_backward(x, broadcast_mask): for axis, mask in enumerate(broadcast_mask): - if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): x = x.unsqueeze(axis) return x diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index d430386ff360..ce47ac43d372 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Callable, Dict, List, Optional, Tuple diff --git a/torch/_decomp/decompositions_for_rng.py b/torch/_decomp/decompositions_for_rng.py index 1aa762351171..74eb9b9240ae 100644 --- a/torch/_decomp/decompositions_for_rng.py +++ b/torch/_decomp/decompositions_for_rng.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools 
from collections import defaultdict from typing import Callable, Dict diff --git a/torch/_deploy.py b/torch/_deploy.py index 35e8d4976940..3f8adc420672 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import torch diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index d80839dc7e47..1d36623ba861 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import unittest.mock from contextlib import contextmanager diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 6e22cafcc6dd..30e44b000fd5 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._C import DispatchKey from torch._higher_order_ops.utils import autograd_not_implemented diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index cf1204de1a5f..69e70198c7f5 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -14,18 +14,22 @@ log = logging.getLogger(__name__) -def aot_autograd(**kwargs): - def compiler_fn(gm: torch.fx.GraphModule, example_inputs): +class AotAutograd: + def __init__(self, **kwargs): + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): return flatten_graph_inputs( gm, example_inputs, - compiler_fn, + self, ) # Hack to get around circular import problems with aot_eager_decomp_partition - if callable(kwargs.get("decompositions")): - kwargs["decompositions"] = kwargs["decompositions"]() + if callable(self.kwargs.get("decompositions")): + self.kwargs["decompositions"] = self.kwargs["decompositions"]() # NB: dont delete counter increment counters["aot_autograd"]["total"] += 1 @@ -42,10 +46,10 @@ def _wrapped_bw_compiler(*args, **kwargs): # stop TorchDynamo from trying to compile our generated backwards pass return disable(disable(bw_compiler)(*args, **kwargs)) - bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] - kwargs["bw_compiler"] = _wrapped_bw_compiler - kwargs["inference_compiler"] = ( - kwargs.get("inference_compiler") or kwargs["fw_compiler"] + bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] + self.kwargs["bw_compiler"] = _wrapped_bw_compiler + self.kwargs["inference_compiler"] = ( + self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"] ) from functorch.compile import nop @@ -54,7 +58,7 @@ def _wrapped_bw_compiler(*args, **kwargs): # debug asserts slow down compile time noticeably, # So only default them on when the aot_eager backend is used. - if kwargs.get("fw_compiler", None) == nop: + if self.kwargs.get("fw_compiler", None) == nop: patch_config = patch("functorch.compile.config.debug_assert", True) else: patch_config = contextlib.nullcontext() @@ -62,14 +66,16 @@ def _wrapped_bw_compiler(*args, **kwargs): try: # NB: NOT cloned! 
with enable_aot_logging(), patch_config: - cg = aot_module_simplified(gm, example_inputs, **kwargs) + cg = aot_module_simplified(gm, example_inputs, **self.kwargs) counters["aot_autograd"]["ok"] += 1 return disable(cg) except Exception: counters["aot_autograd"]["not_ok"] += 1 raise - return compiler_fn + +def aot_autograd(**kwargs): + return AotAutograd(**kwargs) def mem_efficient_fusion_kwargs(use_decomps): diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index a0a86536c16d..6c024b114fe2 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -4,6 +4,7 @@ import importlib import logging import os +import sys import tempfile from types import MappingProxyType from typing import Optional @@ -23,7 +24,7 @@ def tvm( example_inputs, *, options: Optional[MappingProxyType] = MappingProxyType( - {"scheduler": None, "trials": 20000} + {"scheduler": None, "trials": 20000, "opt_level": 3} ), ): import tvm # type: ignore[import] @@ -50,6 +51,7 @@ def tvm( scheduler = os.environ.get("TVM_SCHEDULER", None) trials = options.get("trials", 20000) + opt_level = options.get("opt_level", 3) if scheduler == "auto_scheduler": from tvm import auto_scheduler @@ -82,7 +84,7 @@ def tvm( with auto_scheduler.ApplyHistoryBest(log_file): with tvm.transform.PassContext( - opt_level=3, config={"relay.backend.use_auto_scheduler": True} + opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} ): lib = relay.build(mod, target=target, params=params) elif scheduler == "meta_schedule": @@ -97,24 +99,27 @@ def tvm( ) # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + assert trials > 0 database = ms.relay_integration.tune_relay( mod=mod, target=target, work_dir=work_dir, - max_trials_global=20000, + max_trials_global=trials, num_trials_per_iter=64, params=params, strategy="evolutionary", + opt_level=opt_level, ) lib = ms.relay_integration.compile_relay( database=database, mod=mod, target=target, params=params, + opt_level=opt_level, ) elif scheduler == "default" or not scheduler: # no autotuning - with tvm.transform.PassContext(opt_level=10): + with tvm.transform.PassContext(opt_level=opt_level): lib = relay.build(mod, target=target, params=params) else: raise NotImplementedError( @@ -179,6 +184,10 @@ def has_tvm(): @functools.lru_cache(None) def llvm_target(): - if "avx512" in open("/proc/cpuinfo").read(): - return "llvm -mcpu=skylake-avx512" - return "llvm -mcpu=core-avx2" + if sys.platform == "linux": + cpuinfo = open("/proc/cpuinfo").read() + if "avx512" in cpuinfo: + return "llvm -mcpu=skylake-avx512" + elif "avx2" in cpuinfo: + return "llvm -mcpu=core-avx2" + return "llvm" diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 340378e7266b..541c3e0cc882 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import dataclasses import dis diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index dec673b0e910..63dbdf048f6c 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import dis @@ -1117,6 +1118,23 @@ def should_compute_arg(): instructions[i].arg = idx +def clear_instruction_args(instructions): + # Clear the instruction arg for instructions that have argvals. 
+ # Useful for using dis'd bytecode within generated bytecode. + for inst in instructions: + if ( + inst.argval is not _NotProvided + and ( + inst.opcode in HAS_LOCAL + or inst.opcode in HAS_NAME + or inst.opcode in HAS_FREE + or inst.opcode in HAS_CONST + ) + and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") + ): + inst.arg = None + + def get_code_keys() -> List[str]: # Python 3.11 changes to code keys are not fully documented. # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 @@ -1247,3 +1265,100 @@ def unique_id(name) -> str: def is_generator(code: types.CodeType) -> bool: co_generator = 0x20 return (code.co_flags & co_generator) > 0 + + +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): + """Generates bytecode from a template function `fn` for use in + dynamo bytecode generation. + + For example, we can generate Python-version-independent bytecode + for looping through a dictionary and copying the values to a new dictionary. + + def template(d1, d2): + for k, v in d1.items(): + d2[k] = v + + + or a try block: + + def template(): + try: + dummy1 + except: + dummy2 + raise + dummy3 + + Args: + fn: a function template to generate bytecode from + varname_map: a mapping of `fn`'s varnames to new names. This + map will be applied to the generated bytecode's varnames. + For example, local variables in `fn` can be replaced with + new names that are generated by `OutputGraph.new_var`. + noreturn: remove all RETURN_* bytecodes and replace them with a jump + to the end of the bytecode. + noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). + """ + insts = cleaned_instructions(fn.__code__) + clear_instruction_args(insts) + + if noprefix: + for i, inst in enumerate(insts): + if inst.opname == "RESUME": + insts = insts[i + 1 :] + break + + for inst in insts: + # If we don't reset starts_line, then the generated + # bytecode's line number will be based on fn's. + inst.starts_line = None + if varname_map and inst.argval in varname_map: + inst.argval = varname_map[inst.argval] + + if noreturn: + if sys.version_info >= (3, 12): + # replace RETURN_CONST with LOAD_CONST RETURN_VALUE + new_insts = [] + for inst in insts: + if inst.opname == "RETURN_CONST": + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + new_insts.append(inst) + # no need to propagate target/exn table + new_insts.append(create_instruction("RETURN_VALUE")) + else: + new_insts.append(inst) + insts = new_insts + + returns = [] + for inst in insts: + if inst.opname == "RETURN_VALUE": + returns.append(inst) + + if len(returns) == 1 and returns[0] is insts[-1]: + # only 1 return at the end - just pop it + insts.pop(-1) + elif len(returns) > 0: + # create jump target - if the last inst is a return, + # we can replace it with a NOP and make that the jump target. + if insts[-1] is returns[-1]: + insts[-1].opname = "NOP" + insts[-1].opcode = dis.opmap["NOP"] + insts[-1].arg = None + insts[-1].argval = _NotProvided + returns.pop(-1) + else: + insts.append(create_instruction("NOP")) + + # replace returns with jumps + for inst in returns: + # don't replace inst with new instruction + # due to targetting/exn table/etc. 
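# (The instruction is mutated in place rather than replaced because other
# instructions' `target` fields, and exception-table entries on newer CPython,
# hold references to this exact object; swapping in a fresh instruction would
# leave those references pointing at a node no longer in the list.)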
+ jump_inst = create_jump_absolute(insts[-1]) + inst.opname = jump_inst.opname + inst.opcode = jump_inst.opcode + inst.arg = jump_inst.arg + inst.argval = jump_inst.argval + inst.target = jump_inst.target + + return insts diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index 340f227a9956..ea5e2ae0ce10 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import types import weakref diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index a65e2844f215..35f447a80349 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class CompilationCallbackHandler: def __init__(self): self.start_callbacks = [] diff --git a/torch/_dynamo/code_context.py b/torch/_dynamo/code_context.py index 0fe19016ca13..59c912bd30f7 100644 --- a/torch/_dynamo/code_context.py +++ b/torch/_dynamo/code_context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from .utils import ExactWeakKeyDictionary diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 6dbd7f36b0b5..ac0d06d9f428 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import re diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 7a87a2c7d575..f13b53e7ed5f 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,6 +1,7 @@ +# mypy: allow-untyped-defs import contextlib import functools -from typing import List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING import torch from torch._dynamo.external_utils import call_backward, call_hook @@ -41,6 +42,10 @@ def cpp_verbose_log_fn(msg: str) -> None: verbose_log.debug(msg) +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs + + def maybe_clone(x): if x is not None: return clone_preserve_strides(x) @@ -203,6 +208,52 @@ def post_acc_grad_hook(self, input, hook_id): self.bind_tensors_to_proxies(input, proxies) return input + # Note: [Compiled autograd and cudagraphs] + # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. + # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph + # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the + # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. + def move_graph_nodes_to_cuda(self, graph) -> List[int]: + to_move: Dict[int, torch.fx.Node] = {} + has_cuda_inputs = False + nodes = list(graph.nodes) + assert nodes[0].target == "inputs" + inputs = nodes[0] + inputs_users = list(inputs.users.keys()) + # the ordering of the nodes should always [inputs, sizes, hooks, getitem, getitem1, ...] 
+ # where getitemi accesses inputs[i] + first_getitem_idx = 3 + assert nodes[first_getitem_idx] == inputs_users[0] + last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 + assert nodes[last_getitem_idx] == inputs_users[-1] + for i, node in enumerate(inputs_users): + if not has_cuda_inputs and node.meta["val"].device.type == "cuda": + has_cuda_inputs = True + continue + + is_cpu = node.meta["val"].device.type == "cpu" + is_scalar = len(node.meta["val"].size()) == 0 + if is_cpu and is_scalar: + node_users = list(node.users.keys()) + if all( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + for user in node_users + ): + # all users are prims/aten, can move safely + to_move[i] = node + + # only move cpu scalars to cuda if there were cuda activations in this graph, + # this is to handle the case where cudagraphs is enabled on a cpu-only graph + if has_cuda_inputs: + for node in to_move.values(): + node.meta["val"] = node.meta["val"].cuda() + + # return runtime indices we need to move to cuda + return list(to_move.keys()) + + return [] + def end_capture(self, outputs): self.stack.close() self.fx_tracer.create_node( @@ -212,6 +263,10 @@ def end_capture(self, outputs): {}, ) self.reorder_accumulate_grad_nodes() + runtime_inputs_to_move: List[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + graph = GraphModule( self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" ) @@ -220,13 +275,23 @@ def end_capture(self, outputs): "%s", lazy_format_graph_code("Compiled autograd graph", graph) ) verbose_log.debug( - "%s", lazy_format_graph_code("Compiled autograd graph", graph) + "%s", + lazy_format_graph_code( + "Compiled autograd graph", graph, include_device=True + ), ) trace_structured( "compiled_autograd_graph", payload_fn=lambda: graph.print_readable(print_output=False), ) - return self.compiler_fn(graph) + + def runtime_wrapper(compiled_fn, inputs, sizes, hooks): + for i in runtime_inputs_to_move: + inputs[i] = inputs[i].cuda() + + return compiled_fn(inputs, sizes, hooks) + + return runtime_wrapper, self.compiler_fn(graph) def reorder_accumulate_grad_nodes(self): """ diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 80880588b54e..ffb9fbc47cca 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This file establishes the public comptime interface to Dynamo. # This allows Dynamo users to execute arbitrary Python code while # Dynamo is symbolically evaluating their original programs. diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 62138127befd..bf3d35c334aa 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import getpass import inspect import os @@ -116,8 +117,9 @@ def is_fbcode(): # This feature doesn't really work. We offer this flag for experimental # purposes / if you want to help us build out support. # -# torchdynamo has very limited support for tensor subclasses that implement -# __torch_function__. Our current support is limited to tensor subclasses +# torchdynamo has limited support for tensor subclasses that implement +# __torch_function__ see [Note: __torch_function__] in torch_function.py. +# Our current support is limited to tensor subclasses # that DO NOT store metadata on the tensor (in general, dynamo does not # support Python code that stores extra attributes on tensors at present). 
# If your tensor subclass purely changes function call behavior via @@ -225,6 +227,15 @@ def is_fbcode(): os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1" ) +# hybrid backed unbacked symints +prefer_deferred_runtime_asserts_over_guards = False + +# For complex dynamic shapes guards that we're unable to specify with dynamo/export's +# range constraints + dims + derived dims language, we raise constraint violation +# errors or specialize by default. If set to True, this flag avoids crashing/specialization, +# and allows complex guards as runtime assertions in the graph. +_allow_complex_guards_as_runtime_asserts = False + # By default, dynamo will treat all ints as backed SymInts, which means (1) it # will wait to see the int change over multiple runs before generalizing and # (2) it will still always 0/1 specialize an int. When true, this knob @@ -445,6 +456,10 @@ def default_debug_dir_root(): # WARNING: this is an experimental flag and is subject to change. _experimental_support_context_fn_in_torch_utils_checkpoint = False +# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile(). +# Note: AOT Autograd will still trace joint graphs. +compiled_autograd = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 6dcb84fab8fc..a1d7e7e6e130 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import cProfile import dis @@ -178,9 +179,7 @@ def _fn(*args, **kwargs): finally: cleanup.close() torch._C._set_grad_enabled(prior_grad_mode) - torch.torch.autograd.grad_mode._enter_inference_mode( - prior_inference_mode - ) + torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) torch.use_deterministic_algorithms( prior_deterministic, warn_only=prior_warn_only ) @@ -361,17 +360,34 @@ def profile_wrapper(*args, **kwargs): return profile_wrapper -def convert_frame_assert( - compiler_fn: CompilerFn, - one_graph: bool = True, - export: bool = False, - export_constraints=None, -): - """Fully convert a frame into an FX graph""" - reset_graph_break_dup_checker() +class ConvertFrameAssert: + def __init__( + self, + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, + ): + reset_graph_break_dup_checker() + self._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + self._one_graph = one_graph + self._export = export + self._export_constraints = export_constraints + + @property + def _clone_with_backend(self): + return lambda backend: convert_frame_assert( + backend, self._one_graph, self._export, self._export_constraints + ) - def _convert_frame_assert( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0 + def __call__( + self, + frame: types.FrameType, + cache_entry, + hooks: Hooks, + frame_state, + *, + skip: int = 0, ): increment_frame() @@ -458,10 +474,10 @@ def _convert_frame_assert( frame.f_globals, frame.f_locals, frame.f_builtins, - compiler_fn, - one_graph, - export, - export_constraints, + self._torchdynamo_orig_callable, + self._one_graph, + self._export, + self._export_constraints, hooks, cache_entry, cache_size, @@ -471,13 +487,15 @@ def _convert_frame_assert( skip=skip + 1, ) - _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - def _clone_with_backend(backend): - return 
convert_frame_assert(backend, one_graph, export, export_constraints) - - _convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined] - return _convert_frame_assert +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, +): + """Fully convert a frame into an FX graph""" + return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) from collections import OrderedDict @@ -876,6 +894,7 @@ def format_guard_failures(): dynamo_time_before_restart = time.time() - start_time metrics = CompilationMetrics( + str(compile_id), frame_key, code.co_name, code.co_filename, @@ -906,16 +925,27 @@ def format_guard_failures(): torch._dynamo.callback_handler.run_end_callbacks() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): - """Try to convert a frame into an FX graph, if error leave frame unmodified""" - inner_convert = convert_frame_assert(compiler_fn, one_graph=False) +class ConvertFrame: + def __init__(self, compiler_fn: CompilerFn, hooks: Hooks): + self._torchdynamo_orig_callable = compiler_fn + self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._hooks = hooks + + @property + def _clone_with_backend(self): + return lambda backend: convert_frame(backend, self._hooks) - def _convert_frame( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0 + def __call__( + self, + frame: types.FrameType, + cache_entry, + hooks: Hooks, + frame_state, + skip: int = 0, ): counters["frames"]["total"] += 1 try: - result = inner_convert( + result = self._inner_convert( frame, cache_entry, hooks, frame_state, skip=skip + 1 ) counters["frames"]["ok"] += 1 @@ -979,9 +1009,10 @@ def _convert_frame( log.warning(error_msg, exc_info=True) return None - _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks) # type: ignore[attr-defined] - return _convert_frame + +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + return ConvertFrame(compiler_fn, hooks) # TODO mlazos: add support for same args, or record them @@ -1022,9 +1053,13 @@ def first_real_inst_idx(code): raise RuntimeError("RESUME instruction not found in code") -def catch_errors_wrapper(callback, hooks: Hooks): - @functools.wraps(callback) - def catch_errors(frame, cache_entry, frame_state): +class CatchErrorsWrapper: + def __init__(self, callback, hooks): + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback + self.hooks = hooks + + def __call__(self, frame, cache_entry, frame_state): assert frame_state is not None is_skipfile = trace_rules.check(frame.f_code) @@ -1062,19 +1097,26 @@ def catch_errors(frame, cache_entry, frame_state): ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_module.bucket_bytes_cap, - backend_compile_fn=callback._torchdynamo_orig_callable, + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, ) assert hasattr( - callback, "_clone_with_backend" + self._torchdynamo_orig_callable, "_clone_with_backend" ), "DDPOptimizer only supports callback fns that know how to clone themselves." 
- hijacked_callback = callback._clone_with_backend( - ddp_optimizer.compile_fn, + hijacked_callback = ( + self._torchdynamo_orig_callable._clone_with_backend( + ddp_optimizer.compile_fn, + ) + ) + return hijacked_callback( + frame, cache_entry, self.hooks, frame_state ) - return hijacked_callback(frame, cache_entry, hooks, frame_state) with compile_lock, _disable_current_modes(): # skip=1: skip this frame - return callback(frame, cache_entry, hooks, frame_state, skip=1) + return self._torchdynamo_orig_callable( + frame, cache_entry, self.hooks, frame_state, skip=1 + ) - catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] - return catch_errors + +def catch_errors_wrapper(callback, hooks: Hooks): + return CatchErrorsWrapper(callback, hooks) diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py index 42981fcf1015..d30e4a37f003 100644 --- a/torch/_dynamo/create_parameter_op.py +++ b/torch/_dynamo/create_parameter_op.py @@ -1,3 +1,7 @@ +# mypy: allow-untyped-defs +import threading +from contextlib import contextmanager + import torch doc = """ @@ -36,3 +40,20 @@ def new_parameter_placeholder(size, dtype, device, requires_grad): # Allocating a zero tensor would causes assert failures in autograd. result.untyped_storage().resize_(0) return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter(): + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter(): + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/torch/_dynamo/current_scope_id.py b/torch/_dynamo/current_scope_id.py index 1289bdcdffe4..ad079875b58a 100644 --- a/torch/_dynamo/current_scope_id.py +++ b/torch/_dynamo/current_scope_id.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import threading diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 4b4b37a34da9..e262f8cbdb71 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" import copy diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 2c4417d9af50..ec25d06281fc 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,3 +1,5 @@ +# mypy: allow-untyped-defs +# ruff: noqa: TCH004 from dataclasses import dataclass from typing import TYPE_CHECKING @@ -7,6 +9,7 @@ from .comptime import comptime from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage +from .external_utils import is_compiling if TYPE_CHECKING: from torch._C._dynamo.eval_frame import ( # noqa: F401 @@ -73,22 +76,12 @@ def assume_constant_result(fn): def allow_in_graph(fn): """ - Customize which functions TorchDynamo will include in the generated - graph. Similar to `torch.fx.wrap()`. - :: - - torch._dynamo.allow_in_graph(my_custom_function) + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. - @torch._dynamo.optimize(...) - def fn(a): - x = torch.add(x, 1) - x = my_custom_function(x) - x = torch.add(x, 1) - return x - - fn(...) + See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation - Will capture a single graph containing `my_custom_function()`. 
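# --- Illustrative sketch (not from the patch) -------------------------------
# Behavior of the thread-local guard added in create_parameter_op.py above:
from torch._dynamo.create_parameter_op import (
    can_convert_to_tracable_parameter,
    do_not_convert_to_tracable_parameter,
)

assert can_convert_to_tracable_parameter()          # default: conversion allowed
with do_not_convert_to_tracable_parameter():
    assert not can_convert_to_tracable_parameter()  # disabled inside the block
assert can_convert_to_tracable_parameter()          # restored on exit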
+ WARNING: this API can be a footgun, please read the documentation carefully. """ if isinstance(fn, (list, tuple)): return [allow_in_graph(x) for x in fn] @@ -272,7 +265,7 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. """ - if torch.compiler.is_compiling(): + if is_compiling(): if index is None: for s in t.size(): comptime.force_static(s) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index d93a26546683..aa8848014b34 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fe06995771e0..2fc451cf3d17 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" """ @@ -168,6 +169,9 @@ def _initialize(self): self._forward = self.forward self.forward = self._call_lazy_check + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + def __getstate__(self): state = dict(self.__dict__) state.pop("forward", None) @@ -273,9 +277,11 @@ def __init__( super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback + self._backend_ctx_ctor = backend_ctx_ctor self.prior: Union[Unset, DynamoCallback] = unset self.first_ctx = first_ctx self.export = export + self._dynamic = dynamic self.compiler_config = compiler_config self.cleanup_fns: List[Callable[[], Any]] = [] self.enter_exit_hooks = [] @@ -379,7 +385,13 @@ def get_compiler_config(): # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) - callback = self.callback + def do_nothing(*arg, **kwargs): + pass + + if hasattr(self, "callback"): + callback = self.callback + else: + callback = do_nothing is_jit_tracing = torch._C._is_tracing is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing @@ -493,6 +505,9 @@ def __init__( export=False, dynamic=None, compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, ): def on_enter(): install_generation_tagging_init() @@ -508,6 +523,28 @@ def on_enter(): compiler_config=compiler_config, ) + if config.compiled_autograd: + + def call_compiled_autograd(): + assert rebuild_ctx is not None + compiler_fn = rebuild_ctx() + ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_compiled_autograd) + + def __reduce__(self): + return ( + self.__class__, + (self.callback, self._backend_ctx_ctor, self.first_ctx), + { + "export": self.export, + "dynamic": self._dynamic, + "compiler_config": self.compiler_config, + }, + ) + class RunOnlyContext(_TorchDynamoContext): def __init__(self): @@ -517,6 +554,9 @@ def on_enter(): super().__init__(callback=False, on_enter=on_enter) + def __reduce__(self): + return (self.__class__, ()) + class DisableContext(_TorchDynamoContext): def __init__(self): @@ -569,6 +609,9 @@ def _fn(*args, **kwargs): return _fn + def __reduce__(self): + return (self.__class__, ()) + def _optimize_catch_errors( compile_fn, @@ -577,6 +620,7 @@ def _optimize_catch_errors( export=False, dynamic=None, compiler_config=None, + rebuild_ctx=None, ): 
return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -585,6 +629,7 @@ def _optimize_catch_errors( export=export, dynamic=dynamic, compiler_config=compiler_config, + rebuild_ctx=rebuild_ctx, ) @@ -635,7 +680,15 @@ def is_inductor_supported(): return False -def optimize( +def optimize(*args, **kwargs): + def rebuild_ctx(): + return optimize(*args, **kwargs) + + return _optimize(rebuild_ctx, *args, **kwargs) + + +def _optimize( + rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], backend="inductor", *, nopython=False, @@ -643,7 +696,7 @@ def optimize( guard_fail_fn=None, disable=False, dynamic=None, -): +) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call backend() to optimize extracted graphs. @@ -691,6 +744,7 @@ def toy_example(a, b): backend, dynamic=dynamic, hooks=hooks, + rebuild_ctx=rebuild_ctx, ) # The backend function is stashed in the callable returned by # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can @@ -703,6 +757,7 @@ def toy_example(a, b): compiler_config=backend.get_compiler_config() if hasattr(backend, "get_compiler_config") else None, + rebuild_ctx=rebuild_ctx, ) @@ -774,6 +829,7 @@ def guard_export_print(guards): "If you don't migrate, we may break your explain call in the future if your user defined kwargs " "conflict with future kwargs added to explain(f).", FutureWarning, + stacklevel=2, ) return inner(*extra_args, **extra_kwargs) else: @@ -1129,6 +1185,8 @@ def export( assume_static_by_default: bool = False, same_signature: bool = True, disable_constraint_solver: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, _log_export_usage: bool = True, **extra_kwargs, ) -> Callable[..., ExportResult]: @@ -1304,6 +1362,8 @@ def result_capturing_wrapper(*graph_inputs): automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, ): opt_f = optimize_assert( dynamo_normalization_capturing_compiler, @@ -1449,6 +1509,7 @@ def graph_with_interpreter(*args): "If you don't migrate, we may break your export call in the future if your user defined kwargs " "conflict with future kwargs added to export(f).", FutureWarning, + stacklevel=2, ) return inner(*extra_args, **extra_kwargs) else: @@ -1462,6 +1523,7 @@ def optimize_assert( export=False, export_constraints=None, dynamic=None, + rebuild_ctx=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -1479,6 +1541,7 @@ def optimize_assert( backend_ctx_ctor, export=export, dynamic=dynamic, + rebuild_ctx=rebuild_ctx, ) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 2ca4c311540e..f3cc073b8a30 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import textwrap from enum import auto, Enum @@ -183,6 +184,10 @@ class IncorrectUsage(Exception): pass +class ObservedException(TorchDynamoException): + pass + + # These exceptions are ok to fallback to eager/graph_break. 
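# --- Illustrative sketch (not from the patch) -------------------------------
# Rough end-to-end picture of the new wiring: with config.compiled_autograd
# set, entering the torch.compile context also enables compiled autograd via
# the call_compiled_autograd enter/exit hook above, so .backward() calls made
# under the compiled region are traced too. Illustrative only; the backend and
# shapes are arbitrary.
import torch

torch._dynamo.config.compiled_autograd = True

@torch.compile(backend="eager")
def train_step(x):
    loss = (x * x).sum()
    loss.backward()
    return x.grad

train_step(torch.randn(4, requires_grad=True))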
exceptions_allowed_to_be_fallback = ( torch._subclasses.fake_tensor.DataDependentOutputException, diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 669f86c9ec59..caea92bc6be0 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,8 +1,8 @@ +# mypy: allow-untyped-defs # This module contains functions that *will be allowed* by dynamo import functools from typing import List -from typing_extensions import deprecated import torch import torch.utils._pytree as pytree @@ -13,10 +13,6 @@ np = None # type: ignore[assignment] -@deprecated( - "`is_compiling` is deprecated. Use `torch.compiler.is_compiling()` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a8fc77b92c11..fc3f12847a75 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import ast @@ -512,6 +513,11 @@ def __init__( # limit the number of cache entries with same ID_MATCH'd object. self.id_matched_objs: Dict[str, ReferenceType[object]] = {} + # Save the guard managers to avoid repeatedly traversing sources. + self._cached_guard_managers: Dict[ + str, torch._C._dynamo.guards.GuardManager + ] = {} + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): @@ -758,6 +764,10 @@ def get_guard_manager_from_source(self, source): example_value = None source_name = source.name() + + if source_name != "" and source_name in self._cached_guard_managers: + return self._cached_guard_managers[source_name] + if source_name != "": example_value = self.get(source_name) @@ -781,7 +791,7 @@ def get_guard_manager_from_source(self, source): # RootGuardManager accepts a dict but still its not a # DictGuardManager because we will eventually move to # fastlocals. - return root_guard_manager.dict_getitem_manager( + out = root_guard_manager.dict_getitem_manager( key=source.local_name, source=source_name, example_value=example_value, @@ -791,14 +801,14 @@ def get_guard_manager_from_source(self, source): # Global manager accepts a dict but it is not a DictGuardManager # because globals dict is big and we typically guard on a very # selected items on globals. 
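# --- Illustrative sketch (not from the patch) -------------------------------
# The many `return <manager>` -> `out = <manager>` edits in the rest of this
# guards.py hunk all serve the memoization introduced above: every branch now
# funnels into a single cache write keyed by the source name, roughly:
#
#   def get_guard_manager_from_source(self, source):
#       name = source.name()
#       if name and name in self._cached_guard_managers:
#           return self._cached_guard_managers[name]   # reuse the built manager
#       out = ...  # dispatch on the source type and build the manager
#       self._cached_guard_managers[name] = out
#       return out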
- return self.get_global_guard_manager().dict_getitem_manager( + out = self.get_global_guard_manager().dict_getitem_manager( key=source.global_name, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif istype(source, GlobalWeakRefSource): - return self.get_global_guard_manager().global_weakref_manager( + out = self.get_global_guard_manager().global_weakref_manager( global_name=source.global_name, source=source_name, example_value=example_value, @@ -812,7 +822,7 @@ def get_guard_manager_from_source(self, source): return root_guard_manager elif istype(source, TypeSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.type_manager( + out = base_guard_manager.type_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, @@ -822,10 +832,10 @@ def get_guard_manager_from_source(self, source): (OptimizerSource, NNModuleSource, NotNNModuleSource, FSDPNNModuleSource), ): assert base_guard_manager # to make mypy happy - return base_guard_manager + out = base_guard_manager elif istype(source, GradSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.grad_manager( + out = base_guard_manager.grad_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, @@ -834,7 +844,7 @@ def get_guard_manager_from_source(self, source): assert base_guard_manager # to make mypy happy if isinstance(base_example_value, torch.nn.Module): - return self.getattr_on_nn_module( + out = self.getattr_on_nn_module( source, base_guard_manager, base_example_value, @@ -843,13 +853,13 @@ def get_guard_manager_from_source(self, source): source_name, guard_manager_enum, ) - - return base_guard_manager.getattr_manager( - attr=source.member, - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) + else: + out = base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GetItemSource): assert base_guard_manager # to make mypy happy if isinstance(base_example_value, (dict, collections.OrderedDict)): @@ -858,7 +868,7 @@ def get_guard_manager_from_source(self, source): # dicts) so that GetItemSource is only for non dict objects. if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) - return getitem_on_dict_manager( + out = getitem_on_dict_manager( source, base_guard_manager, base_example_value, @@ -871,40 +881,40 @@ def get_guard_manager_from_source(self, source): "Expecting clean index here. 
Likely Dynamo forgot to mark" " a dict as guard_on_key_order" ) - return base_guard_manager.dict_getitem_manager( + out = base_guard_manager.dict_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif isinstance(base_example_value, list) and not source.index_is_slice: - return base_guard_manager.list_getitem_manager( + out = base_guard_manager.list_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) elif isinstance(base_example_value, tuple) and not source.index_is_slice: - return base_guard_manager.tuple_getitem_manager( + out = base_guard_manager.tuple_getitem_manager( key=source.index, source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) - - index = source.index - if source.index_is_slice: - index = source.unpack_slice() - return base_guard_manager.getitem_manager( - key=index, - source=source_name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) + else: + index = source.index + if source.index_is_slice: + index = source.unpack_slice() + out = base_guard_manager.getitem_manager( + key=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, ODictGetItemSource): if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) - return getitem_on_dict_manager( + out = getitem_on_dict_manager( source, base_guard_manager, base_example_value, @@ -913,7 +923,7 @@ def get_guard_manager_from_source(self, source): ) else: assert base_guard_manager # to make mypy happy - return base_guard_manager.dict_getitem_manager( + out = base_guard_manager.dict_getitem_manager( key=source.index, source=source_name, example_value=example_value, @@ -923,7 +933,7 @@ def get_guard_manager_from_source(self, source): assert base_guard_manager # to make mypy happy assert callable(base_example_value) if not source.is_kw: - return base_guard_manager.func_defaults_manager( + out = base_guard_manager.func_defaults_manager( source=base_source_name, example_value=base_example_value.__defaults__, guard_manager_enum=GuardManagerType.GUARD_MANAGER, @@ -947,7 +957,7 @@ def get_guard_manager_from_source(self, source): ) assert not isinstance(dict_mgr, DictGuardManager) - return dict_mgr.dict_getitem_manager( + out = dict_mgr.dict_getitem_manager( key=source.idx_key, source=source_name, example_value=example_value, @@ -955,7 +965,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, NumpyTensorSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=from_numpy, source=source_name, example_value=example_value, @@ -963,7 +973,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, FlattenScriptObjectSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=lambda x: x.__obj_flatten__(), source=source_name, example_value=example_value, @@ -971,7 +981,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, ScriptObjectQualifiedNameSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.lambda_manager( + out = base_guard_manager.lambda_manager( python_lambda=lambda x: x._type().qualified_name(), source=source_name, 
example_value=example_value, @@ -979,7 +989,7 @@ def get_guard_manager_from_source(self, source): ) elif istype(source, TupleIteratorGetItemSource): assert base_guard_manager # to make mypy happy - return base_guard_manager.tuple_iterator_getitem_manager( + out = base_guard_manager.tuple_iterator_getitem_manager( index=source.index, source=source_name, example_value=example_value, @@ -990,7 +1000,7 @@ def get_guard_manager_from_source(self, source): raise AssertionError( "ConstDictKeySource can only work on DictGuardManager" ) - return base_guard_manager.get_key_manager( + out = base_guard_manager.get_key_manager( index=source.index, source=source_name, example_value=example_value, @@ -1001,6 +1011,9 @@ def get_guard_manager_from_source(self, source): f"missing guard manager builder {source} - {source.name()}" ) + self._cached_guard_managers[source.name()] = out + return out + def get_guard_manager(self, guard: Guard): return self.get_guard_manager_from_source(guard.originating_source) diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 1e9a820785be..316b3ec817cb 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index c4a588888e11..9077ecd3d57f 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" import functools @@ -7,7 +8,10 @@ from torch.nn import Module from . import config -from .utils import ExactWeakKeyDictionary, is_lazy_module +from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks + + +unpatched_nn_module_init = torch.nn.Module.__init__ class MutationTracker: @@ -109,6 +113,9 @@ def is_dynamic_nn_module(obj, is_export): and not is_export ): return True + + if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks(): + return True dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( obj ) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e2bf4e2b3ed6..946bc52d7182 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import copy @@ -292,6 +293,8 @@ def __init__( tracked_fakes=self.tracked_fakes, allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts, co_fields=self.co_fields, ) @@ -749,7 +752,13 @@ def register_attr_or_module( **options, ): if is_dynamic_nn_module(target, self.root_tx.export): - return variables.UnspecializedNNModuleVariable(target, **options) + result = variables.UnspecializedNNModuleVariable(target, **options) + if not SideEffects.cls_supports_mutation_side_effects(type(target)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.root_tx.output.side_effects.track_object_existing( + target, result + ) options = dict(options) assert "source" in options @@ -1285,7 +1294,10 @@ def compile_and_call_fx_graph(self, tx, rv, root): "dynamo_flat_name_to_original_fqn" ] = self.dynamo_flat_name_to_original_fqn.copy() - graph_code_log.debug("%s", lazy_format_graph_code(name, 
gm)) + graph_code_log.debug( + "%s", + lazy_format_graph_code(name, gm, include_stride=True, include_device=True), + ) torch._logging.trace_structured( "dynamo_output_graph", lambda: {"sizes": self.get_graph_sizes_structured()}, @@ -1674,7 +1686,7 @@ def example_value_from_input_node(self, node: torch.fx.Node): "(and fall back to eager-mode PyTorch) on all ops " "that have do not have the 'pt2_compliant_tag'. " "Please see the following doc for how to mark this op as PT2 compliant " - "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ" + "https://pytorch.org/docs/main/notes/custom_operators.html" ) diff --git a/torch/_dynamo/profiler.py b/torch/_dynamo/profiler.py index b52551c67137..b7e9553ce219 100644 --- a/torch/_dynamo/profiler.py +++ b/torch/_dynamo/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import os from typing import Any, List diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py index 7a312e5d58a9..0049dfe7d3ef 100644 --- a/torch/_dynamo/replay_record.py +++ b/torch/_dynamo/replay_record.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from dataclasses import field from types import CodeType, ModuleType diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 0dbf3cd5c0e4..98149c72c02c 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy import functools diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 76b9128e6995..254f293951ee 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy import functools @@ -56,19 +57,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): ) -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): - """ - A minifier decorator that wraps the TorchDynamo produced Fx graph modules. - As opposed to wrap_compiler_debug, this wrapper intercepts at the - TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some - level, e.g., it is useful for minifying issues related to Aot Autograd - tracing. If an error is found, we minify and save the minified repro in - repro.tar.gz. 
- """ - - @functools.wraps(unconfigured_compiler_fn) - def debug_wrapper(gm, example_inputs, **kwargs): - compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) +class WrapBackendDebug: + def __init__(self, unconfigured_compiler_fn, compiler_name: str): + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__(self, gm, example_inputs, **kwargs): + compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": @@ -82,7 +84,7 @@ def add_paths(exc): ) if config.repro_level == 3: - dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) # Check for either accuracy (level 4) or other type of failures. if config.repro_level == 4: @@ -95,7 +97,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) exc = AccuracyError("Bad accuracy detected.") add_paths(exc) @@ -110,7 +112,7 @@ def add_paths(exc): ) if config.repro_level == 1: dump_state_fn = functools.partial( - dump_backend_state, compiler_name=compiler_name + dump_backend_state, compiler_name=self._compiler_name ) dump_state_fn( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs @@ -119,7 +121,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) add_paths(exc) raise @@ -128,12 +130,17 @@ def add_paths(exc): return compiled_gm - debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] - if hasattr(unconfigured_compiler_fn, "compiler_name"): - debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name - if hasattr(unconfigured_compiler_fn, "get_compiler_config"): - debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] - return debug_wrapper + +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. 
+ """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 387adc06272a..3dae1b3b9b10 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import sys diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 647fae379c54..4072f7641f84 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings from typing import Any, Dict, List, Optional, Union @@ -6,6 +7,7 @@ from . import utils, variables from .bytecode_transformation import ( + bytecode_from_template, create_call_function, create_call_method, create_instruction, @@ -58,6 +60,11 @@ def __init__(self, source: Optional[Source], cls_source: Optional[Source]): self.cls_source = cls_source +def _manual_update_dict(dict_from, dict_to): + for k, v in dict_from.items(): + dict_to[k] = v + + class SideEffects: """ Track side effects (list mutation, setattr, etc) that need to be @@ -346,13 +353,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") - if "__call_nn_module_init" in self.store_attr_mutations.get( - var.mutable_local, {} - ): - assert isinstance(var, variables.UnspecializedNNModuleVariable) - cg.load_import_from(utils.__name__, "nn_module_new") - else: - cg.load_import_from(utils.__name__, "object_new") + cg.load_import_from(utils.__name__, "object_new") cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, True)) cg.add_cache(var) @@ -459,6 +460,39 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) suffixes.append([create_instruction("STORE_SUBSCR")]) + elif isinstance(var, variables.CustomizedDictVariable): + # need to update the dict manually since update method may be invalid + varname_map = {} + for name in _manual_update_dict.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_to"])] + ) + + cg(var, allow_cache=False) + cg.extend_output( + [create_instruction("STORE_FAST", argval=varname_map["dict_from"])] + ) + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output([create_load_method("clear")]) + + # unfortunately can't just use DICT_MERGE due to possible custom behaviors + dict_update_insts = bytecode_from_template( + _manual_update_dict, varname_map=varname_map + ) + + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + *dict_update_insts, + create_instruction("POP_TOP"), + ] + ) + elif isinstance(var, variables.ConstDictVariable): cg.tx.output.update_co_names("clear") cg.tx.output.update_co_names("update") @@ -479,9 +513,25 @@ def codegen_update_mutated(self, cg: PyCodegen): ] ) elif self.is_attribute_mutation(var): - for name, value in self.store_attr_mutations.get( - var.mutable_local, {} - ).items(): + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. 
+ # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. + for name, value in reversed( + self.store_attr_mutations.get(var.mutable_local, {}).items() + ): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) @@ -489,8 +539,6 @@ def codegen_update_mutated(self, cg: PyCodegen): suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) - elif name == "__call_nn_module_init": - pass # handled in codegen_save_tempvars elif isinstance(value, variables.DeletedVariable): if isinstance( var.mutable_local, AttributeMutationExisting diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index ded62ba97d8a..69423712c53c 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import enum diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index bacb8dff9e36..7e129a05a090 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import collections.abc import contextlib @@ -101,7 +102,7 @@ PythonModuleVariable, UnknownVariable, ) -from .variables.nn_module import NNModuleVariable +from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.user_defined import ( RemovableHandleVariable, @@ -199,6 +200,8 @@ def _step_logger(): @dataclasses.dataclass class BlockStackEntry: + # Current instruction that pushes something to block_stack + inst: Instruction target: Instruction stack_index: Optional[int] = None with_context: Optional[ContextWrappingVariable] = None @@ -412,11 +415,22 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) + elif isinstance(value, UnspecializedNNModuleVariable): + mod = value.value + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): - x = value.var_getattr(self, "__bool__") - # if __bool__ is missing, trying __len__ to infer a truth value. - if isinstance(x, GetAttrVariable): + try: + x = value.var_getattr(self, "__bool__") + except exc.ObservedException: + # if __bool__ is missing, trying __len__ to infer a truth value. x = value.var_getattr(self, "__len__") + else: + if isinstance(x, GetAttrVariable): + # if __bool__ is missing, trying __len__ to infer a truth value. 
+ x = value.var_getattr(self, "__len__") # __bool__ or __len__ is function if isinstance(x, UserMethodVariable): @@ -651,6 +665,7 @@ class InstructionTranslatorBase( inconsistent_side_effects: bool current_speculation: Optional[SpeculationEntry] dispatch_table: List[Any] + exn_vt_stack: List[VariableTracker] exec_recorder: Optional[ExecutionRecorder] strict_checks_fn: Optional[Callable[[VariableTracker], bool]] @@ -800,6 +815,9 @@ def step(self): try: self.dispatch_table[inst.opcode](self, inst) return not self.output.should_exit + except exc.ObservedException: + self.exception_handler() + return True except ReturnValueOp: return False except Unsupported: @@ -989,9 +1007,6 @@ def LOAD_GLOBAL(self, inst): assert name in self.f_builtins self.exec_recorder.builtins[name] = self.f_builtins[name] - if inst.argval == "AssertionError": - unimplemented("assert with non-string message") - if name in self.symbolic_globals: variable = self.output.side_effects[self.symbolic_globals[name]] self.push(self.output.side_effects.load_global(variable, name)) @@ -1128,21 +1143,24 @@ def IMPORT_FROM(self, inst): self.DUP_TOP(inst) self._load_attr(inst) - def load_builtin(self, inst): - if inst.argval not in self.f_builtins: - raise NameError(f"name '{inst.argval}' is not defined") - val = self.f_builtins[inst.argval] + def load_builtin_from_argval(self, argval): + if argval not in self.f_builtins: + raise NameError(f"name '{argval}' is not defined") + val = self.f_builtins[argval] if callable(val): builtins_source = GlobalSource( self.output.name_of_builtins_dict_key_in_fglobals ) - var_source = GetItemSource(builtins_source, inst.argval) + var_source = GetItemSource(builtins_source, argval) self.push(VariableBuilder(self, var_source)(val)) else: assert is_builtin_constant(val) self.push(ConstantVariable.create(value=val)) + def load_builtin(self, inst): + self.load_builtin_from_argval(inst.argval) + def jump(self, inst): self.instruction_pointer = self.indexof[inst.target] @@ -1156,11 +1174,11 @@ def jump(self, inst): def SETUP_LOOP(self, inst): # only exists in python<=3.7 - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def SETUP_EXCEPT(self, inst): # only exists in python<=3.7 - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def POP_BLOCK(self, inst): self.block_stack.pop() @@ -1169,7 +1187,7 @@ def SETUP_WITH(self, inst): self.setup_or_before_with(inst) def SETUP_FINALLY(self, inst): - self.block_stack.append(BlockStackEntry(inst.target)) + self.block_stack.append(BlockStackEntry(inst, inst.target)) def BEGIN_FINALLY(self, inst): self.push(None) @@ -1234,16 +1252,213 @@ def RAISE_VARARGS(self, inst): unimplemented("re-raise") elif inst.arg == 1: val = self.pop() + + # TODO(anijain2305) - Merge StopIterationVariable to use the same exception infra. 
if ( isinstance(val, BuiltinVariable) and val.fn is StopIteration ) or isinstance(val, variables.StopIterationVariable): raise exc.UserStopIteration + + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + if isinstance(val, variables.BuiltinVariable): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) + + # Save the exception in a global data structure + self.exn_vt_stack.append(val) + + # 2) when user raises exception instance + if isinstance(val, variables.ExceptionVariable): + raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") else: unimplemented("raise ... from ...") + def exception_handler(self): + if sys.version_info >= (3, 11): + exn_tab_entry = self.current_instruction.exn_tab_entry + if exn_tab_entry: + # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + + # 1) pop values from the stack until it matches the stack depth + # for the handler + while len(self.stack) > exn_tab_entry.depth: + self.pop() + + # 2) if 'lasti' is true, then push the offset that the exception was raised at + if exn_tab_entry.lasti: + # This is untested. Any test that tests this end-to-end + # requires supporting more bytecodes. Therefore graph + # breaking for now. + unimplemented("lasti=True while exception handling") + self.push( + variables.ConstantVariable(self.current_instruction.offset) + ) + + # 3) push the exception to the stack + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + + # 4) jump to the handler + self.jump(exn_tab_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise exc.ObservedException + else: + if len(self.block_stack): + # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455 + + assert len(self.exn_vt_stack) + exception_var = self.exn_vt_stack[-1] + + block_stack_entry = self.block_stack.pop() + + while block_stack_entry.inst.opname == "EXCEPT_HANDLER": + # TODO(anijain2305) - This is not tested .. unable to create a testcase + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + self.popn(3) + if len(self.block_stack) == 0: + unimplemented( + "exception is raised when block stack " "is empty" + ) + block_stack_entry = self.block_stack.pop() + + if block_stack_entry.inst.opname != "SETUP_FINALLY": + unimplemented( + "exception is raised when top of the block stack " + "is not exception handler (e.g. try .. with .. except). 
" + f"Current TOS is {block_stack_entry.inst}" + ) + + # Push a dummy block stack entry of EXCEPT_HANDLER + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0) + self.block_stack.append(BlockStackEntry(except_handler_inst, None)) + + # Push old exception + if len(self.exn_vt_stack) >= 2: + old_exception = self.exn_vt_stack[-2] + + # Push the old exception on to stack - tb, value, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(old_exception) + self.push(variables.BuiltinVariable(old_exception.exc_type)) + else: + # Push empty exception tb, value, type + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + + # Push new exception - tb, val, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(exception_var) + self.push(variables.BuiltinVariable(exception_var.exc_type)) + + # Jump to target + self.jump(block_stack_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translater. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + raise Unsupported("Observed exception") + raise exc.ObservedException + + def PUSH_EXC_INFO(self, inst): + val = self.pop() + assert len(self.exn_vt_stack) + self.push(self.exn_vt_stack[-1]) + self.push(val) + + def POP_EXCEPT(self, inst): + if sys.version_info >= (3, 11): + val = self.pop() + assert isinstance(val, variables.ExceptionVariable) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + else: + assert len(self.block_stack) > 0 + if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER": + raise AssertionError( + "Bug in Dynamo tracing of exception handling." + "Top of the block stack is not EXCEPT_HANDLER." 
+ ) + self.block_stack.pop() + + self.popn(3) + + # This exception is handled and therefore we can clear the error indicator + assert len(self.exn_vt_stack) + self.exn_vt_stack.pop() + + def check_if_exc_matches(self): + assert len(self.stack) >= 2 + expected_exc_types = self.pop() + exc_instance = self.stack[-1] + + # Users can check exception in 2 ways + # 1) except NotImplementedError --> BuilinVariable + # 2) except (NotImplemetedError, AttributeError) -> TupleVariable + + if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)): + unimplemented( + f"except has an unsupported types of objects {expected_exc_types}" + ) + + if sys.version_info >= (3, 11): + if not isinstance(exc_instance, variables.ExceptionVariable): + unimplemented( + f"except expects to recieve an object of exception type but received {exc_instance}" + ) + + if isinstance(expected_exc_types, TupleVariable): + expected_types = expected_exc_types.items + else: + expected_types = [ + expected_exc_types, + ] + + for expected_type in expected_types: + if not isinstance(expected_type, BuiltinVariable): + unimplemented( + f"except has an unsupported types of object {expected_type}" + ) + if isinstance(exc_instance, variables.ExceptionVariable) and issubclass( + exc_instance.exc_type, expected_type.fn + ): + return True + elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( + exc_instance.fn, expected_type.fn + ): + return True + + return False + + def CHECK_EXC_MATCH(self, inst): + self.push(variables.ConstantVariable(self.check_if_exc_matches())) + + def JUMP_IF_NOT_EXC_MATCH(self, inst): + if not self.check_if_exc_matches(): + self.jump(inst) + def COMPARE_OP(self, inst): - self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) + if inst.argval == "exception match": + self.CHECK_EXC_MATCH(inst) + else: + self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) def GET_ITER(self, inst): self.call_function(BuiltinVariable(iter), [self.pop()], {}) @@ -1767,7 +1982,7 @@ def MATCH_KEYS(self, inst): self.push(ConstantVariable.create(False)) def LOAD_ASSERTION_ERROR(self, inst): - unimplemented("assert with non-string message") + self.load_builtin_from_argval("AssertionError") UNARY_POSITIVE = stack_op(operator.pos) UNARY_NEGATIVE = stack_op(operator.neg) @@ -1905,9 +2120,11 @@ def setup_or_before_with(self, inst): if target: if isinstance(self, InstructionTranslator): - self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx)) + self.block_stack.append( + BlockStackEntry(inst, target, len(self.stack), ctx) + ) else: - self.block_stack.append(BlockStackEntry(target)) + self.block_stack.append(BlockStackEntry(inst, target)) self.push(exit) self.push(ctx.enter(self)) @@ -2062,6 +2279,7 @@ def __init__( self.kw_names = None self.accept_prefix_inst = True self.prefix_insts = [] + self.exn_vt_stack = [] # Properties of the input/output code self.instructions: List[Instruction] = instructions @@ -2573,6 +2791,14 @@ def get_trace_call_log_str(): try: with strict_ctx: tracer.run() + except exc.ObservedException as e: + msg = f"Observed exception DURING INLING {code} : {e}" + # TODO(anijain2305) - This works but we should probably have a + # global/central data structure for the exception stack. + parent.exn_vt_stack.extend(tracer.exn_vt_stack) + log.debug(msg) + # bubble up the exception to the parent frame. 
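# --- Illustrative sketch (not from the patch) -------------------------------
# Taken together, RAISE_VARARGS, exception_handler, PUSH_EXC_INFO/POP_EXCEPT
# and CHECK_EXC_MATCH above target user code of roughly this shape; whether a
# given case traces without a graph break still depends on the paths marked
# unimplemented above (lasti handling, `raise ... from ...`, etc.).
import torch

@torch.compile(backend="eager")
def f(x):
    try:
        raise NotImplementedError("fallback")  # raise an exception instance
    except (NotImplementedError, AttributeError):
        x = x + 1
    return x

f(torch.ones(3))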
+ raise except exc.SkipFrame as e: msg = f"SKIPPED INLINING {code}: {e}" log.debug(msg) @@ -2753,8 +2979,6 @@ def LOAD_GLOBAL(self, inst): self.PUSH_NULL(inst) name = inst.argval - if inst.argval == "AssertionError": - unimplemented("assert with non-string message") _, fglobals_vt, global_source = self.get_globals_source_and_value(name) if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name): diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py index 4c4246474c1d..290f03ad0c6e 100644 --- a/torch/_dynamo/tensor_version_op.py +++ b/torch/_dynamo/tensor_version_op.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._prims import _make_prim, RETURN_TYPE from torch._subclasses import FakeTensorMode diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 297ea6e2bc2a..0489b6acc963 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import importlib import logging diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index d12e5a92315a..4736c75785cc 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import io import logging diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9e9abe84228b..527e0138fc25 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dis import functools @@ -343,6 +344,12 @@ def skipIfNotPy311(fn): return unittest.skip(fn) +def skipIfNotPy312(fn): + if sys.version_info >= (3, 12): + return fn + return unittest.skip(fn) + + def xfailIfPy312(fn): if sys.version_info >= (3, 12): return unittest.expectedFailure(fn) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6be6e4965ce1..c6e2a848adec 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _collections_abc import _weakrefset import abc @@ -405,7 +406,9 @@ "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", "torch._C._construct_storage_from_data_pointer", "torch._C._conv_determine_backend_memory_format", - "torch._C._cpu._is_cpu_support_vnni", + "torch._C._cpu._is_cpu_support_avx2", + "torch._C._cpu._is_cpu_support_avx512", + "torch._C._cpu._is_cpu_support_avx512_vnni", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -938,6 +941,7 @@ "torch._C._mps_currentAllocatedMemory", "torch._C._mps_deviceSynchronize", "torch._C._mps_driverAllocatedMemory", + "torch._C._mps_recommendedMaxMemory", "torch._C._mps_elapsedTimeOfEvents", "torch._C._mps_emptyCache", "torch._C._mps_get_default_generator", @@ -1995,7 +1999,6 @@ "torch.not_equal", "torch.nuclear_norm", "torch.numel", - "torch.obj", "torch.ones_like", "torch.ones", "torch.orgqr", @@ -2182,6 +2185,7 @@ "torch.xlogy", "torch.zero_", "torch.zeros", + "torch.zeros_like", "torch._fused_sgd_", "torch.slice_inverse", "torch._assert_scalar", @@ -2292,7 +2296,6 @@ "torch._linalg_utils._symeig", "torch._linalg_utils.basis", "torch._linalg_utils.bform", - "torch._linalg_utils.conjugate", "torch._linalg_utils.eig", "torch._linalg_utils.get_floating_dtype", "torch._linalg_utils.is_sparse", @@ -2302,8 +2305,6 @@ "torch._linalg_utils.qform", "torch._linalg_utils.solve", "torch._linalg_utils.symeig", - 
"torch._linalg_utils.transjugate", - "torch._linalg_utils.transpose", "torch._load_global_deps", "torch._lowrank._svd_lowrank", "torch._lowrank.get_approximate_basis", @@ -2419,7 +2420,9 @@ "torch.chain_matmul", "torch.compile", "torch.compiled_with_cxx11_abi", - "torch.cpu._is_cpu_support_vnni", + "torch.cpu._is_cpu_support_avx2", + "torch.cpu._is_cpu_support_avx512", + "torch.cpu._is_cpu_support_avx512_vnni", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", @@ -3094,6 +3097,18 @@ def is_numpy(obj) -> bool: return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids +def is_numpy_dtype(obj) -> bool: + if np is None: + return False + return isinstance(obj, np.dtype) + + +def is_numpy_type_info(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.finfo, np.iinfo)) + + BUILTIN_SKIPLIST = ( abc, collections, @@ -3510,7 +3525,11 @@ def lookup_inner( # The rules defined in `torch_name_rule_map` mainly includes two parts: # - Manually defined rules for any functions. # - The list of torch in graph functions. - if not hashable(obj): + try: + can_hash = hashable(obj) + except Exception: + can_hash = False + if not can_hash: if reasons is not None: reasons.add("obj is not hashable") return None diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 58768957af87..fe2f096ec488 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import atexit import collections import contextlib @@ -113,7 +114,9 @@ compilation_time_metrics: Dict[str, List[float]] = {} # profiling compilation time by frame phase -frame_phase_timing: Dict[str, Dict[str, float]] = {} +frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( + lambda: collections.defaultdict(float) +) timer_counter = itertools.count() @@ -185,6 +188,10 @@ def print_time_report(): print(out) +def _add_time_spent(key, phase_name, time_spent): + frame_phase_timing[key][phase_name] += time_spent + + # dynamo_timed API works as a function decorator # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics # where the key is the functions name. @@ -201,31 +208,82 @@ def print_time_report(): # phase_names record an extra record into a separate compilation timing structure, # one keyed on frame+name rather than function. # The frame is incremented outside of this function, in def increment_frame() above. +# `fwd_only` is used to identify if this phase or function is only called +# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. +# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. 
-def dynamo_timed(original_function=None, phase_name=None): +def dynamo_timed(original_function=None, phase_name=None, fwd_only=True): def dynamo_timed_inner(func): - if config.cprofile: - return func - @wraps(func) def time_wrapper(*args, **kwargs): key = func.__qualname__ if key not in compilation_time_metrics: compilation_time_metrics[key] = [] - with torch.profiler.record_function(f"{key} (dynamo_timed)"): - t0 = time.time() - r = func(*args, **kwargs) - time_spent = time.time() - t0 - compilation_time_metrics[key].append(time_spent) - if phase_name: - frame_key = str(curr_frame) - if frame_key not in frame_phase_timing: - frame_phase_timing[frame_key] = {} - if phase_name not in frame_phase_timing[frame_key]: - frame_phase_timing[frame_key][phase_name] = time_spent - else: - frame_phase_timing[frame_key][phase_name] += time_spent + + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + time_spent = float("-inf") + try: + with torch.profiler.record_function(f"{key} (dynamo_timed)"): + t0 = time.time() + r = func(*args, **kwargs) + time_spent = time.time() - t0 + compilation_time_metrics[key].append(time_spent) + except Exception as e: + fail_type = str(type(e)) + fail_reason = str(e) + raise + finally: + # Only record backward compilation metrics if phase_name is not None! + if phase_name: + frame_key = str(curr_frame) + # fwd only compilation stages: entire_frame_compile, backend_compile. + # use frame_key as time aggregation key. + if fwd_only and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + else: + # fwd + bwd compilation stages: inductor_compile, code_gen. + # use frame_key as time aggregation key for fwd graphs; + # use compile_id as time aggregation key for bwd graphs. + if torch._guards.TracingContext.try_get() is not None: + aot_graph_name = str( + torch._guards.TracingContext.get().aot_graph_name + ) + if ( + "forward" in aot_graph_name + or "inference" in aot_graph_name + ) and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + elif "backward" in aot_graph_name: + compile_id = str( + torch._guards.CompileContext.current_compile_id() + ) + if fail_type is None: + _add_time_spent(compile_id, phase_name, time_spent) + + # log backward compilation metrics at the end of `inductor_compile` of bwd graph, + # one record for one bwd graph. 
+ if phase_name == "inductor_compile": + if fail_type is None: + inductor_compile_time = frame_phase_timing[ + compile_id + ].get("inductor_compile", None) + code_gen_time = frame_phase_timing[ + compile_id + ].get("code_gen", None) + else: + inductor_compile_time = None + code_gen_time = None + metrics = BwdCompilationMetrics( + compile_id, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + ) + record_compilation_metrics(metrics) + return r return time_wrapper @@ -598,6 +656,7 @@ def proxy_args_kwargs(args, kwargs): @dataclasses.dataclass class CompilationMetrics: + compile_id: str frame_key: str co_name: str co_filename: str @@ -628,26 +687,44 @@ class CompilationMetrics: has_guarded_code: bool +@dataclasses.dataclass +class BwdCompilationMetrics: + compile_id: str + inductor_compile_time_s: Optional[float] + code_gen_time_s: Optional[float] + fail_type: Optional[str] + fail_reason: Optional[str] + + DEFAULT_COMPILATION_METRICS_LIMIT = 64 -_compilation_metrics: Deque[CompilationMetrics] = collections.deque( - maxlen=DEFAULT_COMPILATION_METRICS_LIMIT -) +_compilation_metrics: Deque[ + Union[CompilationMetrics, BwdCompilationMetrics] +] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) -def record_compilation_metrics(compilation_metrics: CompilationMetrics): +def record_compilation_metrics( + compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] +): global _compilation_metrics _compilation_metrics.append(compilation_metrics) - torch._logging.trace_structured( - "compilation_metrics", - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, - ) - if config.log_compilation_metrics: - log_compilation_event(compilation_metrics) + if isinstance(compilation_metrics, CompilationMetrics): + name = "compilation_metrics" + else: + name = "bwd_compilation_metrics" + # Currently only record fwd compilation metrics, will add bwd compilation metrics + # after the internal Scuba logging changes finish. 
+ if isinstance(compilation_metrics, CompilationMetrics): + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + ) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) def set_compilation_metrics_limit(new_size: int) -> None: @@ -663,7 +740,7 @@ def clear_compilation_metrics() -> None: _compilation_metrics.clear() -def get_compilation_metrics() -> List[CompilationMetrics]: +def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]: return list(_compilation_metrics) @@ -1829,7 +1906,7 @@ def make_error_message(e): assert nnmodule is not None return nnmodule(*args, **kwargs) elif op == "get_attr": - return tracer.get_submodule(node.target) + return tracer.output_graph.get_submodule(node.target) elif op == "placeholder": assert "example_value" in node.meta return node.meta["example_value"] @@ -1864,6 +1941,9 @@ def get_real_value(node, tracer): lambda n: get_real_value(n, tracer), ) + if op == "placeholder" and "grapharg" in node.meta: + return node.meta["grapharg"].example + if op == "call_module": nn_module = tracer.output_graph.nn_modules[node.target] if not is_lazy_module(nn_module): @@ -1939,12 +2019,12 @@ def object_has_getattribute(value: Any): return False -def get_custom_getattr(value: Any): +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: getattr_fn = None - if getattr_fn is torch.nn.Module.__getattr__: + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: # ignore this case of getattr getattr_fn = None return getattr_fn @@ -2027,6 +2107,14 @@ def format_bytecode(prefix, name, filename, line_no, code): all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. 
+ return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks + ) + + def nn_module_get_all_hooks( mod, check_forward_hooks=False, @@ -2529,12 +2617,17 @@ def is_torch_function_object(value): def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool: - from torch._dynamo.variables import UserDefinedObjectVariable + from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable - return isinstance(vt, TensorWithTFOverrideVariable) or ( - isinstance(vt, UserDefinedObjectVariable) - and hasattr(vt.value, "__torch_function__") + if isinstance(vt, TensorWithTFOverrideVariable): + return True + + if isinstance(vt, LazyVariableTracker): + LazyVariableTracker.realize(vt) + + return isinstance(vt, UserDefinedObjectVariable) and hasattr( + vt.value, "__torch_function__" ) diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 06f634efb348..9ffdd64fbc96 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -23,7 +23,6 @@ from .dicts import ( ConstDictVariable, CustomizedDictVariable, - DataClassVariable, DefaultDictVariable, SetVariable, ) @@ -63,6 +62,7 @@ AutogradFunctionVariable, ClosureVariable, DeletedVariable, + ExceptionVariable, GetAttrVariable, InspectSignatureVariable, LambdaVariable, @@ -112,7 +112,6 @@ "CountIteratorVariable", "CustomizedDictVariable", "CycleIteratorVariable", - "DataClassVariable", "DefaultDictVariable", "DeletedVariable", "DeterministicAlgorithmsVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 10a79ed8ff31..478fd3eb4010 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -70,7 +70,12 @@ Source, TupleIteratorGetItemSource, ) -from ..trace_rules import is_callable_allowed, is_numpy +from ..trace_rules import ( + is_callable_allowed, + is_numpy, + is_numpy_dtype, + is_numpy_type_info, +) from ..utils import ( build_checkpoint_variable, clone_input, @@ -106,7 +111,7 @@ ) from .dicts import ( ConstDictVariable, - DataClassVariable, + CustomizedDictVariable, DefaultDictVariable, HFPretrainedConfigVariable, PythonSysModulesVariable, @@ -151,6 +156,8 @@ LambdaVariable, LoggingLoggerVariable, MethodWrapperVariable, + NumpyDTypeVariable, + NumpyTypeInfoVariable, NumpyVariable, PythonModuleVariable, RegexPatternVariable, @@ -486,6 +493,11 @@ class Autotuner: elif value is sys.modules: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PythonSysModulesVariable(source=self.source) + elif CustomizedDictVariable.is_matching_cls_hf(type(value)): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = CustomizedDictVariable.wrap(self, value) + result.source = self.source + return self.tx.output.side_effects.track_object_existing(value, result) elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): if not value and self.get_source().is_nn_module(): # It is faster to guard on 'false' property than to guard @@ -625,6 +637,17 @@ def build_key_value(i, k, v): else GuardBuilder.TYPE_MATCH ) return NumpyVariable(value, source=self.source) + elif is_numpy_dtype(value): + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyDTypeVariable(value, source=self.source) + elif is_numpy_type_info(value): + if isinstance(value, np.iinfo): + self.install_guards(GuardBuilder.TYPE_MATCH) + dt_source = 
AttrSource(self.source, "dtype") + install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyTypeInfoVariable(value, source=self.source) # NB: These can't be put in type_dispatch, they have to run later elif CollectiveFunctionRewriteVariable.can_rewrite(value): self.install_guards(GuardBuilder.FUNCTION_MATCH) @@ -693,9 +716,6 @@ def build_key_value(i, k, v): ) elif np and isinstance(value, np.number): return self.wrap_unspecialized_primitive(value) - elif DataClassVariable.is_matching_object(value): - self.install_guards(GuardBuilder.TYPE_MATCH) - return DataClassVariable.wrap(self, value) elif HFPretrainedConfigVariable.is_matching_object(value): self.install_guards(GuardBuilder.TYPE_MATCH) return HFPretrainedConfigVariable(value) @@ -1110,6 +1130,19 @@ def wrap_module(self, value: torch.nn.Module): if mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it self.install_guards(GuardBuilder.TYPE_MATCH) + if ( + torch._dynamo.config.inline_inbuilt_nn_modules + and torch._inductor.config.freezing + and not torch.is_grad_enabled() + ): + from ..decorators import mark_static_address + + for p in value.parameters(): + mark_static_address(p) + + for b in value.buffers(): + mark_static_address(b) + result = UnspecializedNNModuleVariable(value, source=self.source) if not SideEffects.cls_supports_mutation_side_effects(type(value)): # don't allow STORE_ATTR mutation with custom __setattr__ @@ -1119,7 +1152,7 @@ def wrap_module(self, value: torch.nn.Module): value.__class__, torch.nn.parallel.distributed.DistributedDataParallel ): self.install_guards(GuardBuilder.TYPE_MATCH) - return UnspecializedNNModuleVariable(value) + return UnspecializedNNModuleVariable(value, source=self.get_source()) elif getattr(value, "_is_fsdp_managed_module", False): # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] # in fully_sharded_data_parallel.py for more information @@ -1164,6 +1197,10 @@ def wrap_literal(self, value): value in self._common_constants() # Assume integers from global variables want to be specialized or not self.source.guard_source().is_local() + # Assume that integers that came from NN modules want to be + # specialized (as we don't expect users to be changing the + # NN modules on the fly) + or self.source.guard_source().is_nn_module() or is_from_defaults(self.source) or is_cell_contents(self.source) ): @@ -1666,7 +1703,7 @@ def wrap_unspecialized_primitive(self, value): def _dataclasses_fields_lambda(obj): if isinstance(obj, UserDefinedObjectVariable): value = obj.value - elif isinstance(obj, DataClassVariable): + elif isinstance(obj, CustomizedDictVariable): value = obj.user_cls else: unimplemented(f"Dataclass fields handling fails for type {obj}") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 605f56b3047d..71744e95277f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -712,6 +712,20 @@ def _make_handler(fn, arg_types: List[type], has_kwargs: bool): tx, [v.realize() for v in args], kwargs ) + if inspect.isclass(fn) and issubclass(fn, Exception): + + def create_exception_class_object(tx, args, kwargs): + if fn is AssertionError and not all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, str) + for x in args + ): + unimplemented("assert with non-string message") + + return variables.ExceptionVariable(fn, args, **kwargs) + + return 
create_exception_class_object + if obj.can_insert_in_graph() and not ( fn is operator.getitem and not issubclass(arg_types[0], variables.TensorVariable) @@ -1129,6 +1143,14 @@ def call_pos(self, tx, arg: "VariableTracker"): ) return pos_method.call_function(tx, [], {}) + def call_index(self, tx, arg: "VariableTracker"): + if isinstance(arg, variables.TensorVariable): + unimplemented("unsupported index(tensor)") + + arg = guard_if_dyn(arg) + constant_value = operator.index(arg) + return variables.ConstantVariable.create(constant_value) + def call_round(self, tx, arg, *args, **kwargs): # Call arg.__round__() round_method = BuiltinVariable(getattr).call_function( @@ -1433,6 +1455,8 @@ def call_next(self, tx, arg: VariableTracker): def call_hasattr(self, tx, obj, attr): if attr.is_python_constant(): name = attr.as_python_constant() + if isinstance(obj, variables.BuiltinVariable): + return variables.ConstantVariable(hasattr(obj.fn, name)) return obj.call_hasattr(tx, name) def call_map(self, tx, fn, seq): @@ -1609,7 +1633,6 @@ def call_setattr( if isinstance( obj, ( - variables.DataClassVariable, variables.CustomizedDictVariable, variables.PlacementVariable, variables.UserDefinedObjectVariable, @@ -1678,6 +1701,9 @@ def _lower_version_count_by_1(x): return out tx.output.side_effects.store_attr(obj, name, val) + if name == "_grad": + tx.output.side_effects.store_attr(obj, "grad", val) + return val elif isinstance(obj, variables.UserDefinedObjectVariable): unimplemented( @@ -1813,6 +1839,12 @@ def call_id(self, tx, *args): nn_mod_variable = args[0] mod = tx.output.get_submodule(nn_mod_variable.module_key) return variables.ConstantVariable.create(id(mod)) + elif len(args) == 1 and isinstance( + args[0], variables.UserDefinedObjectVariable + ): + install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) + constant_result = id(args[0].value) + return variables.ConstantVariable.create(constant_result) else: unimplemented(f"call_id with args {args}") diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 0724a80621f7..50ea3f96379c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -174,7 +174,11 @@ def python_type(self): def __contains__(self, vt): assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker - return is_hashable(vt) and Hashable(vt) in self.items + return ( + is_hashable(vt) + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) def reconstruct(self, codegen): # instructions to load collections.OrderedDict if necessary @@ -545,6 +549,8 @@ def python_type(self): def _is_matching_transformers_cls(cls) -> bool: mod = sys.modules.get("transformers.file_utils") + if mod is None: + mod = sys.modules.get("transformers.utils.generic") return mod is not None and issubclass(cls, mod.ModelOutput) @@ -555,12 +561,20 @@ def _is_matching_diffusers_cls(cls) -> bool: def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs""" + if tx.output.side_effects.is_attribute_mutation(self): + try: + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except KeyError: + pass if name in self.items or hasattr(self.user_cls, name): return ConstantVariable(True) elif istype(self.mutable_local, MutableLocal) and self.source is None: # 
Something created locally can't have any extra fields on it return ConstantVariable(False) - elif self.mutable_local is None and self.source: + elif self.source: # Maybe add a guard try: example = tx.output.root_tx.get_example_value(self.source) @@ -577,152 +591,27 @@ def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": class DataClassVariable(ConstDictVariable): """ - This is a bit of a hack to deal with - transformers.file_utils.ModelOutput() from huggingface. + This class doesn't appear to be used anywhere. + It used to be used to deal with transformers.file_utils.ModelOutput + from huggingface. - ModelOutput causes trouble because it a a mix of a dataclass and a - OrderedDict and it calls super() methods implemented in C. + Keeping since we wish to support dataclasses in general in the future """ - # ModelOutput() excludes None, though generic datclasses don't - include_none = False + pass - @staticmethod - @functools.lru_cache(None) - def _patch_once(): - try: - from transformers.file_utils import ModelOutput - - for obj in ModelOutput.__dict__.values(): - if callable(obj): - skip_code(obj.__code__) - except ImportError: - pass - - try: - from diffusers.utils import BaseOutput - - for obj in BaseOutput.__dict__.values(): - if callable(obj): - skip_code(obj.__code__) - except ImportError: - pass +class CustomizedDictVariable(ConstDictVariable): @staticmethod - def is_matching_cls(cls): + def is_matching_cls_hf(cls): return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) - @classmethod - def is_matching_object(cls, obj): - return cls.is_matching_cls(type(obj)) - - @classmethod - def create(cls, user_cls, args, kwargs, options): - DataClassVariable._patch_once() - - skip_code(user_cls.__init__.__code__) - keys = [f.name for f in dataclasses.fields(user_cls)] - bound = inspect.signature(user_cls).bind(*args, **kwargs) - bound.apply_defaults() - assert set(bound.arguments.keys()) == set(keys) - items = {} - for key in keys: - val = bound.arguments[key] - key = ConstantVariable.create(key) - if isinstance(val, VariableTracker): - items[key] = val - else: - if cls.include_none: - assert variables.ConstantVariable.is_literal(val) - items[key] = variables.ConstantVariable.create(val) - else: - assert val is None, f"unexpected {val}" - - if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable): - unimplemented("DataClassVariable iterator constructor") - # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__ - - return cls(items, user_cls, **options) - - @classmethod - def wrap(cls, builder, obj): - user_cls = type(obj) - keys = [f.name for f in dataclasses.fields(user_cls)] - - excluded = [] - items = {} - for key in keys: - # __init__ function of a dataclass might not have yet defined the key - if hasattr(obj, key): - val = getattr(obj, key) - var = builder.__class__( - tx=builder.tx, source=AttrSource(builder.source, key) - )(val) - if val is not None or cls.include_none: - key = ConstantVariable.create(key) - items[key] = var - else: - excluded.append(var) - return cls(items, user_cls) - - def __init__(self, items, user_cls, **options): - super().__init__(items, user_cls, **options) - assert self.is_matching_cls(user_cls) - - def as_proxy(self): - raise NotImplementedError - - def reconstruct(self, codegen): - codegen.extend_output([codegen._create_load_const(self.user_cls)]) - # All the keys are just wrapped strings - d = self.keys_as_python_constant() - codegen.foreach(d.values()) - keys = tuple(d.keys()) - 
codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) - - def call_method( - self, - tx, - name, - args: "List[VariableTracker]", - kwargs: "Dict[str, VariableTracker]", - ) -> "VariableTracker": - if name == "__getitem__": - assert not kwargs and len(args) == 1 - val = args[0] - if val.python_type() == str: - return self.getitem_const(val) - else: - return self.call_method(tx, "to_tuple", [], {}).call_method( - tx, "__getitem__", args, kwargs - ) - elif name == "to_tuple": - assert not (args or kwargs) - return variables.TupleVariable(list(self.items.values())) - elif name == "__setattr__": - name = "__setitem__" - return super().call_method(tx, name, args, kwargs) - - def var_getattr(self, tx, name: str) -> "VariableTracker": - name_vt = ConstantVariable.create(name) - if name_vt in self: - return self.call_method(tx, "__getitem__", [name_vt], {}) - elif not self.include_none: - defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} - if name in defaults: - assert variables.ConstantVariable.is_literal(defaults[name]) - return variables.ConstantVariable.create(defaults[name]) - super().var_getattr(tx, name) - - call_hasattr = _call_hasattr_customobj - - -class CustomizedDictVariable(ConstDictVariable): @staticmethod def is_matching_cls(cls): # True if using default OrderedDict.__init__ and did not implement __post_init__ if ( issubclass(cls, collections.OrderedDict) + and cls is not collections.OrderedDict and cls.__init__ is collections.OrderedDict.__init__ and not hasattr(cls, "__post_init__") ): @@ -730,7 +619,7 @@ def is_matching_cls(cls): # hack for HF usecase: # assume dataclass annotation for ModelOutput subclass # assume self.create is AA to ModelOutput.__post_init__ - return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) + return CustomizedDictVariable.is_matching_cls_hf(cls) @classmethod def is_matching_object(cls, obj): @@ -764,9 +653,7 @@ def make_var(x): ) bound_args = {} - if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls( - user_cls - ): + if cls.is_matching_cls_hf(user_cls): # Skip none for k, v in bound.arguments.items(): if isinstance(v, ConstantVariable) and v.value is None or v is None: @@ -792,7 +679,27 @@ def make_var(x): # called from builder.py @classmethod def wrap(cls, builder, obj): - raise NotImplementedError + user_cls = type(obj) + + if not cls.is_matching_cls_hf(user_cls): + unimplemented("custom non-hf dict subclass wrap unimplemented") + + items = builder.__class__(tx=builder.tx, source=builder.source)( + collections.OrderedDict(obj) + ).items + + keys = [f.name for f in dataclasses.fields(user_cls)] + for key in keys: + # __init__ function of a dataclass might not have yet defined the key + if hasattr(obj, key): + val = getattr(obj, key) + var = builder.__class__( + tx=builder.tx, source=AttrSource(builder.source, key) + )(val) + if val is not None: + key = ConstantVariable.create(key) + items[key] = var + return cls(items, user_cls) def __init__(self, items, user_cls, **options): super().__init__(items, user_cls, **options) @@ -804,9 +711,7 @@ def as_proxy(self): # 'RETURN_VALUE triggered compile' # called from torch/_dynamo/codegen.py def reconstruct(self, codegen): - is_hf_model_output = _is_matching_transformers_cls( - self.user_cls - ) or _is_matching_diffusers_cls(self.user_cls) + is_hf_model_output = self.is_matching_cls_hf(self.user_cls) # If the user class is a ModelOutput, then wrap the instance creation in # torch._dynamo.disable(). 
Even though we mark the __post_init__ as skip @@ -848,21 +753,34 @@ def call_method( ): # for python dict method without overridden return super().call_method(tx, name, args, kwargs) - elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"): + elif name in ( + "__getitem__", + "to_tuple", + "__setitem__", + "__setattr__", + "__post_init__", + ): # for user overridden method return tx.inline_user_function_return( variables.UserFunctionVariable(fn, source=source), [self] + list(args), kwargs, ) + elif fn is getattr(collections.OrderedDict, name, None): + return super().call_method(tx, name, args, kwargs) - unimplemented("custom dict: call_method unimplemented name=%s", name) + unimplemented(f"custom dict: call_method unimplemented name={name}") def var_getattr(self, tx, name: str) -> "VariableTracker": name_vt = ConstantVariable.create(name) if name_vt in self: return self.call_method(tx, "__getitem__", [name_vt], {}) - super().var_getattr(tx, name) + if dataclasses.is_dataclass(self.user_cls): + defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} + if name in defaults: + assert variables.ConstantVariable.is_literal(defaults[name]) + return variables.ConstantVariable.create(defaults[name]) + return super().var_getattr(tx, name) call_hasattr = _call_hasattr_customobj diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 3fab4413cb0f..88bc94165349 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -6,6 +6,7 @@ import inspect import itertools import types +import warnings from typing import Dict, List, Optional, TYPE_CHECKING, Union import torch @@ -338,6 +339,9 @@ def call_function( return self.obj.call_method( tx, self.fn.__name__, args, kwargs, constant=self.is_constant ) + if self.is_constant: + fn = getattr(self.obj.value, self.fn.__name__) + return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) return super().call_function(tx, args, kwargs) def inspect_parameter_names(self): @@ -634,9 +638,30 @@ def wraps(fn): else: try: path = inspect.getfile(self.value) + msg = f"'skip function {self.value.__qualname__} in file {path}'" except TypeError: - path = f"Builtin {self.value.__name__}" - msg = f"'skip function {self.value.__qualname__} in file {path}'" + known_python_builtin_modules = {"_abc", "_warnings"} + if self.value.__module__ in known_python_builtin_modules: + msg = ( + f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. " + f"Please file an issue on GitHub " + f"so the PyTorch team can add support for it. " + ) + else: + msg = ( + f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. " + f"This function is either a Python builtin (e.g. _warnings.warn) " + f"or a third-party C/C++ Python extension (perhaps created with pybind). " + f"If it is a Python builtin, please file an issue on GitHub " + f"so the PyTorch team can add support for it and see the next case for a workaround. " + f"If it is a third-party C/C++ Python extension, please " + f"either wrap it into a PyTorch-understood custom operator " + f"(see https://pytorch.org/docs/main/notes/custom_operators.html " + f"for more details) or, if it is traceable, use " + f"torch.compiler.allow_in_graph." 
+ ) + # also warn on it because most users won't see the graph break message + warnings.warn(msg) msg += f"', {self.reason}'" if self.reason else "" unimplemented(msg) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 00932f984f38..59f8c26ce62d 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -12,7 +12,7 @@ import torch.fx import torch.nn import torch.onnx.operators -from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value +from torch._dynamo.utils import get_fake_value from torch._dynamo.variables import ConstantVariable from torch._dynamo.variables.base import VariableTracker from torch._dynamo.variables.builtin import BuiltinVariable @@ -1149,17 +1149,15 @@ def call_function( p_args = tuple(arg.as_proxy() for arg in args[1:]) real_sub_args = pytree.tree_map_only( - torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args + torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args ) - example_res = lowered_module.original_module.module()(*real_sub_args) + example_value = lowered_module.original_module.module()(*real_sub_args) # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: # executorch modules promise not to alias inputs and outputs. # Thus, output FakeTensors will correctly not alias input FakeTensors. - _assert_tensors_nonaliasing(real_sub_args, example_res) - - example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) + _assert_tensors_nonaliasing(real_sub_args, example_value) p_args = (lowered_node,) + p_args diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 9dc5bc52ae76..179bb9a52bf9 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -14,8 +14,10 @@ import torch.utils._pytree as pytree from .. 
import config, variables from ..bytecode_transformation import create_call_function, create_instruction +from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_unspec_or_constant_args, @@ -121,7 +123,6 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) - if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -133,12 +134,10 @@ def call_method( and isinstance(objvar.mutable_local, AttributeMutationNew) and not (args or kwargs) ): - tx.output.side_effects.store_attr( - objvar, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return variables.ConstantVariable.create(None) + with do_not_convert_to_tracable_parameter(): + return variables.UserFunctionVariable( + unpatched_nn_module_init, source=source + ).call_function(tx, [self.objvar] + args, kwargs) else: unimplemented("super() nn.Module.__init__") elif isinstance(inner_fn, types.FunctionType): @@ -171,14 +170,45 @@ def call_method( return super(variables.CustomizedDictVariable, self.objvar).call_method( tx, "__setitem__", args, kwargs ) + elif inner_fn is collections.OrderedDict.__getitem__ and isinstance( + self.objvar, variables.CustomizedDictVariable + ): + return super(variables.CustomizedDictVariable, self.objvar).call_method( + tx, "__getitem__", args, kwargs + ) elif is_standard_setattr(inner_fn) and isinstance( self.objvar, UserDefinedObjectVariable ): return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError: + unimplemented(f"non-const delattr attr: {attr}") + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented(f"delattr({self.objvar}, {attr}, ...)") + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) unimplemented(f"non-function or method super: {inner_fn}") +class ExceptionVariable(VariableTracker): + def __init__(self, exc_type, args, **kwargs): + super().__init__(**kwargs) + self.exc_type = exc_type + self.args = args + + def reconstruct(self, codegen): + codegen.load_import_from("builtins", self.exc_type.__name__) + codegen.foreach(self.args) + codegen.call_function(len(self.args), True) + + class UnknownVariable(VariableTracker): """ It could be anything! 
@@ -679,11 +709,13 @@ def call_method( and self.name == "__dict__" and not kwargs and args[0].is_python_constant() - and isinstance(self.obj, variables.UserDefinedObjectVariable) + and isinstance( + self.obj, + (variables.UserDefinedObjectVariable, variables.NNModuleVariable), + ) ): obj = self.obj key = args[0].as_python_constant() - obj._check_for_getattribute() if obj.has_key_in_generic_dict(tx, key): # redirect to var_getattr on the original obj return obj.var_getattr(tx, key) @@ -701,11 +733,13 @@ def call_method( and len(args) == 1 and args[0].is_python_constant() and not kwargs - and isinstance(self.obj, variables.UserDefinedObjectVariable) + and isinstance( + self.obj, + (variables.UserDefinedObjectVariable, variables.NNModuleVariable), + ) ): obj = self.obj key = args[0].as_python_constant() - obj._check_for_getattribute() if obj.has_key_in_generic_dict(tx, key): return variables.ConstantVariable(True) else: @@ -847,25 +881,34 @@ def can_constant_fold_through(cls, fn): assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] return fn in cls.constant_fold_functions + @classmethod + def get_constant_collection_for_func(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return np_constant_collections_map.get(fn, None) + def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": if not config.trace_numpy: unimplemented(f"numpy.{self.value}()") - import numpy as np - from ..utils import numpy_to_tensor_wrapper from .tensor import NumpyNdarrayVariable - # lookup method name in tnp. Things like np.dtype(float) are not supported yet. - if self.value.__name__ == "dtype": + func = get_np_to_tnp_map().get(self.value) + if func is None: unimplemented( - f"numpy dtype function is not supported yet. Got type {type(self.value)}." + f"Can't find numpy function {self.value} in torch._numpy. " + " Please file an issue to request support for this function." ) - elif self.value in (np.iinfo, np.finfo): + + # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) + if ( + collection_variable_typ := self.get_constant_collection_for_func(func) + ) is not None: try: - return NumpyTypeInfoVariable( + return collection_variable_typ( self.value( *[x.as_python_constant() for x in args], **{k: v.as_python_constant() for k, v in kwargs.items()}, @@ -875,14 +918,7 @@ def call_function( unimplemented( f"{self.value.__name__} with non-const args: {args} {kwargs}" ) - else: # We are dealing with a callable. - func = get_np_to_tnp_map().get(self.value) - if func is None: - unimplemented( - f"Can't find numpy function {self.value} in torch._numpy. " - " Please file an issue to request support for this function." 
- ) - + else: if ( func.__module__ == "torch._numpy.random" and config.use_numpy_random_stream @@ -1091,9 +1127,14 @@ class ConstantLikeVariable(VariableTracker): _error_prefix = "ConstantLikeVariable" try: - from numpy import floating as np_floating + from numpy import ( + dtype as np_dtype, + floating as np_floating, + generic as np_generic, + ) except ImportError: np_floating = type("invalid_type", (), {}) + np_dtype = type("invalid_type", (), {}) def __init__(self, value, **kwargs): super().__init__(**kwargs) @@ -1132,6 +1173,11 @@ def var_getattr(self, tx, name: str) -> VariableTracker: result = getattr(self.value, name) if isinstance(result, self.np_floating): result = float(result) + if isinstance(result, self.np_dtype): + return NumpyDTypeVariable(result) + if isinstance(result, type) and issubclass(result, self.np_generic): + # things like x.dtype.type + return NumpyVariable(result) if variables.ConstantVariable.is_literal(result): return variables.ConstantVariable.create(result) return GetAttrVariable(self, name) @@ -1156,3 +1202,22 @@ def __init__(self, **kwargs): class NumpyTypeInfoVariable(ConstantLikeVariable): _error_prefix = "np.iinfo/np.finfo" + + +class NumpyDTypeVariable(ConstantLikeVariable): + _error_prefix = "np.dtype[...]" + + def as_proxy(self): + """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: + + np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. + This also handles unsupported things nicely (i.e. structured arrays and object arrays). + """ + return self.value.type.__name__ + + +np_constant_collections_map = { + tnp.finfo: NumpyTypeInfoVariable, + tnp.iinfo: NumpyTypeInfoVariable, + tnp.dtype: NumpyDTypeVariable, +} diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f71767a7b7cb..d3f7052a9445 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -115,6 +115,7 @@ class NNModuleVariable(VariableTracker): "module_type", "module_key", "module", + "nn_module_stack_source", *VariableTracker._nonvar_fields, } @@ -126,6 +127,13 @@ def __init__( self.module_key = module_key self.module = module assert self.source + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source def python_type(self): return self.module_type @@ -189,12 +197,25 @@ def convert_to_unspecialized(self, tx): GenerationTracker.mark_class_dynamic(type(mod)) raise UnspecializeRestartAnalysis + def has_key_in_generic_dict(self, tx, key): + base = tx.output.get_submodule(self.module_key) + + if object_has_getattribute(base): + unimplemented("NNModuleVariable with custom __getattribute__") + + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + base_dict = object.__getattribute__(base, "__dict__") + return key in base_dict + def _custom_getattr_fallback(self, base, tx, name, options): """Check for a __getattr__ and handle it specially if it is implemented""" if object_has_getattribute(base): unimplemented("torch.nn.Module with a custom __getattribute__ defined") - getattr_fn = get_custom_getattr(base) + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) if getattr_fn is None: return None @@ 
-223,6 +244,9 @@ def var_getattr(self, tx, name): if not self.source: unimplemented("GETATTR with no source") + if name == "__dict__": + return variables.GetAttrVariable(self, name, source=source) + if name in base_dict: subobj = base_dict[name] elif ( @@ -256,7 +280,17 @@ def var_getattr(self, tx, name): return variables.UserDefinedClassVariable(base.__class__, source=source) if object_member: - return VariableBuilder(tx, NNModuleSource(source))(subobj) + out = VariableBuilder(tx, NNModuleSource(source))(subobj) + + if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + else: if istype(subobj, property): if self.source: @@ -298,7 +332,9 @@ def call_function( ) -> "VariableTracker": mod = tx.output.get_submodule(self.module_key) - with record_nn_module_stack(self.module_key, self.source, tx, mod): + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, mod + ): is_lazy = is_lazy_module(mod) if ( isinstance(mod, torch.nn.Sequential) @@ -442,7 +478,9 @@ def generic_call_method_helper(name): # Example: `self.layer.forward(x)` # This is used for explicit calling `forward` in a forward function. # Dynamo puts `call_method` node in FX, doesn't trigger hooks. - with record_nn_module_stack(self.module_key, self.source, tx, module): + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, module + ): return generic_call_method_helper(name) if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( @@ -627,7 +665,6 @@ def gen_source(source, name): if isinstance(args[0], SliceVariable): # Build a TupleVariable of NNModules result = [] - submods = [] # Turn the slice into the list of integers keys = list(range(len(module)))[args[0].as_python_constant()] @@ -641,9 +678,8 @@ def gen_source(source, name): source=src, ) ) - submods.append(submod) - new_module = torch.nn.Sequential(*submods) + new_module = module[args[0].as_python_constant()] new_module_variable = tx.output.register_attr_or_module( new_module, f"{self}.__getitem__(slice)", @@ -657,8 +693,10 @@ def gen_source(source, name): if isinstance(args[0], SymNodeVariable): key = args[0].evaluate_expr(tx.output) - else: + elif args[0].is_python_constant(): key = args[0].as_python_constant() + else: + unimplemented(f"getitem on NNModuleVariable with key {args[0]}") submod = module[key] return tx.output.register_attr_or_module( @@ -705,6 +743,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): _nonvar_fields = { "value_type", "is_state_mutated", + "nn_module_stack_source", *UserDefinedObjectVariable._nonvar_fields, } @@ -733,12 +772,25 @@ def __init__(self, value, **kwargs): super().__init__(value=value, **kwargs) self.is_state_mutated = False + # nn_module_stack_source is used to ensure BC for nn_module_stack. + # Downstream users prefer mod.linear instead of mod._modules['linear'] + # as the module stack. When Dynamo inlines the __getattr__ method, we + # cannot use self.source for nn_module_stack because it will be similar + # to mod._modules['linear']. In these cases, we set the + # nn_module_stack_source appropriately to resemble mod.linear. 
+ self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source @staticmethod @functools.lru_cache(None) def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler - supported = {torch.nn.Module.__setattr__} + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -746,8 +798,6 @@ def _nn_module_method_ids(): } def unpack_var_sequence(self, tx): - from .builder import VariableBuilder - try: fn = inspect.getattr_static(self.value_type, "__iter__") except AttributeError as e: @@ -758,11 +808,16 @@ def unpack_var_sequence(self, tx): torch.nn.ParameterList.__iter__, torch.nn.Sequential.__iter__, ): - assert self.source - return [ - VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) - for idx, item in enumerate(self.value) - ] + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) return super().unpack_var_sequence(tx) @@ -785,7 +840,9 @@ def call_function( guard_to_detect_forward_monkeypatching(self.source, mod) ctx = ( - record_nn_module_stack(str(id(mod)), self.source, tx, mod) + record_nn_module_stack( + str(id(mod)), self.get_nn_module_stack_source(), tx, mod + ) if self.source else nullcontext() ) @@ -889,6 +946,17 @@ def call_method( # Handle submodules self.is_state_mutated = True + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + return tx.inline_user_function_return( + variables.UserFunctionVariable(torch.nn.Module.__delattr__), + [self, args[0]], + kwargs, + ) + return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 70354e28bb3d..923437193640 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Dict diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index fbfabb5fdf06..934e9a316a4b 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import logging @@ -17,7 +18,11 @@ from ..._guards import TracingContext from .. 
import config, polyfill, variables from ..codegen import PyCodegen -from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard @@ -870,6 +875,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") + try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() @@ -896,6 +904,13 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) assert isinstance(result, variables.TensorVariable) result.class_type = torch.nn.Parameter + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_parameter. It does not seem to use the right + # grad_enabled. Since this is a parameter, we can just override the + # has_grad_fn field to False to work around the issue. + result.has_grad_fn = False + # In reconstruct() should use the original parameter. The one returned by the graph will be an alias. result.source = placeholder.source @@ -919,6 +934,12 @@ def _nn_param_via_prefix_insert(tx, data, requires_grad): cg.store(varname) tx.output.pregraph_bytecode.extend(cg.get_instructions()) + data_node = data.as_proxy().node + if data_node.op not in ("placeholder", "get_attr"): + unimplemented( + "Unexpected type of data placeholder op for parameter construction" + ) + # add the newly constructed nn.Parameter as a graph input source = SyntheticLocalSource(varname) example_value = torch.nn.Parameter( diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 0674b8cfd146..6f210d498ce0 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -48,6 +48,33 @@ ] + +def _get_all_args(args, kwargs): + return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) + + +def _flatten_vts(vts): + from collections import deque + + from .dicts import ConstDictVariable + from .lazy import LazyVariableTracker + from .lists import ListVariable + + vts = deque(vts) + output = [] + + while vts: + vt = vts.pop() + LazyVariableTracker.realize_all(vt) + if isinstance(vt, ListVariable): + vts.extend(vt.items) + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + else: + output.append(vt) + + return output + + def _get_subclass_type(var): assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) return var.python_type() @@ -109,17 +136,15 @@ def build_torch_function_fn(tx, value, source): def can_dispatch_torch_function(tx, args, kwargs): - if tx.output.torch_function_enabled: - all_args = pytree.arg_tree_leaves(*args, **kwargs) - return any(has_torch_function(arg) for arg in all_args) - else: - return False + return tx.output.torch_function_enabled and any( + has_torch_function(arg) for arg in _get_all_args(args, kwargs) + ) def dispatch_torch_function(tx, fn, args, kwargs): """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" - all_args = pytree.arg_tree_leaves(*args, **kwargs) + all_args = _get_all_args(args, kwargs) overloaded_args =
_get_overloaded_args( [arg for arg in all_args if has_torch_function(arg)], _get_subclass_type, diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ca913060abf9..6c6d3182b660 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -34,7 +34,8 @@ from torch._guards import TracingContext from .. import variables -from ..exc import unimplemented +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ObservedException, unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource from ..utils import ( @@ -57,10 +58,7 @@ def is_standard_setattr(val): - return val in ( - object.__setattr__, - torch.nn.Module.__setattr__, - ) + return val in (object.__setattr__,) class UserDefinedVariable(VariableTracker): @@ -378,17 +376,7 @@ def call_function( else UserDefinedObjectVariable, {}, ) - if ( - inspect.getattr_static(self.value, "__init__", None) - is torch.nn.Module.__init__ - ): - tx.output.side_effects.store_attr( - var, - "__call_nn_module_init", - variables.ConstantVariable.create(True), - ) - return var - else: + with do_not_convert_to_tracable_parameter(): var.call_method(tx, "__init__", args, kwargs) return var elif variables.CustomizedDictVariable.is_matching_cls(self.value): @@ -396,9 +384,6 @@ def call_function( return variables.CustomizedDictVariable.create( self.value, args, kwargs, options ) - elif variables.DataClassVariable.is_matching_cls(self.value): - options = {"mutable_local": MutableLocal()} - return variables.DataClassVariable.create(self.value, args, kwargs, options) elif ( variables.RestrictedListSubclassVariable.is_matching_cls(self.value) and self.source @@ -638,6 +623,10 @@ def call_method( else AttrSource(AttrSource(self.source, "__class__"), name) ) # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init return UserMethodVariable(method, self, source=source).call_function( tx, args, kwargs ) @@ -799,7 +788,7 @@ def _check_for_getattr(self): def _getattr_static(self, name): if ( - isinstance(self.value, (torch.nn.Module, PyTreeSpec)) + isinstance(self.value, PyTreeSpec) or "__slots__" in self.value.__class__.__dict__ or type(self.value) == threading.local ): @@ -812,13 +801,13 @@ def _getattr_static(self, name): return cls_var except AttributeError: pass # __slots__ - # this might call torch.nn.Module.__getattr__ subobj = getattr(self.value, name) else: subobj = inspect.getattr_static(self.value, name) return subobj def has_key_in_generic_dict(self, tx, key): + self._check_for_getattribute() if tx.output.side_effects.has_pending_mutation_of_attr(self, key): mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) return not isinstance(mutated_attr, variables.DeletedVariable) @@ -852,9 +841,26 @@ def var_getattr(self, tx, name): new_source = None if self.source: new_source = AttrSource(self.source, "__getattr__") - return variables.UserMethodVariable( + out = variables.UserMethodVariable( getattr_fn, self, source=new_source ).call_function(tx, [ConstantVariable.create(name)], {}) + + if self.source and getattr_fn is torch.nn.Module.__getattr__: + if isinstance( + out, + ( + variables.UnspecializedNNModuleVariable, + variables.NNModuleVariable, + ), + ): + # nn_module_stack source is BC surface area. 
Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + elif getattr_fn is not None: unimplemented("UserDefined with non-function __getattr__") @@ -1000,14 +1006,35 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute() or self._check_for_getattr(): - unimplemented("hasattr with custom __getattr__") + if self._check_for_getattribute(): + unimplemented("hasattr with custom __getattribute__") try: self._getattr_static(name) return variables.ConstantVariable.create(True) except AttributeError: - return variables.ConstantVariable.create(False) + # Now check in __getattr__ function + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + try: + result = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [variables.ConstantVariable.create(name)], {}) + + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except ObservedException: + return variables.ConstantVariable.create(False) + elif getattr_fn is None: + return variables.ConstantVariable.create(False) + else: + unimplemented("UserDefined with non-function __getattr__") def odict_getitem(self, tx, key): from .builder import VariableBuilder @@ -1074,6 +1101,12 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + class RemovableHandleVariable(VariableTracker): REMOVED = -1 @@ -1104,3 +1137,6 @@ def reconstruct(self, codegen): return # unreachable due to codegen.add_cache() when the hook is installed super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index d41ff4b53af0..d9a514232569 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import functools diff --git a/torch/_export/converter.py b/torch/_export/converter.py index d902c5f1ac55..e2b108a658e0 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs +import operator + from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -5,6 +8,7 @@ from torch.export.exported_program import ExportedProgram from torch.export.graph_signature import ( + ConstantArgument, InputKind, InputSpec, OutputKind, @@ -33,54 +37,193 @@ def replacement(im, dim, scale): replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement) - print(replaced_patterns) - def normalize_name(name: str) -> str: return name.replace(".", "_") +def ir_name_to_func_name(name: str) -> str: + """prim::If -> convert_prim_If""" + name_list = name.split("::") + return "convert_" + "_".join(name_list) + + +def get_node_for_param_and_buffer(fx_graph, name, is_top_level_graph): + if is_top_level_graph: + return fx_graph.get_attr(name) + return fx_graph.placeholder(name) + + +# Those operators will be automatically 
populated to a instance method +# of TS2FXGraphConverter with name convert__(). +# Please check __init__ for method population implementations. +kind_to_standard_operators = { + "prim::TupleIndex": operator.getitem, + "aten::__is__": operator.is_, + "aten::__isnot__": operator.is_not, + "aten::__not__": operator.not_, + "aten::__contains__": operator.contains, +} + + +def get_ir_value_parent_name_and_attr_name(node): + irv_parent_name, irv_name = node.input().debugName(), node.output().debugName() + attr_name = node.s("name") + return irv_name, irv_parent_name, attr_name + + +def construct_fqn(ir, ref_map, name_map): + name_list = [] + while ir in ref_map: + name_list.append(name_map[ir]) + ir = ref_map[ir] + return ".".join(reversed(name_list)) + + +def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]: + """ + Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. + When a graph has control flow, the graph will be divided into multiple blocks. We want to convert + each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model + parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model, + we will run this pass which will: + 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls. + 2. Process the graph bottom up to find the lifted attributes of each block by taking the union + of the attributes used in the current block, and the lifted attributes of all its child blocks. + + Returns: + A mapping of blocks to a set of FQNs of its lifted attributes. + """ + + # A map from a block to its expected to be lifted arguments. + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = dict() + + # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a + # GetAttr node. By traversing this reference map, we can figure out the + # full IR aliasing pass and figure out the FQN of an attribute. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" + node_to_parent_map: Dict[str, str] = dict() + + # Used for reconstructing the FQN of an attribute based on the reference map. + # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR + # This name map stores which attribute name is called for a src IR --> dest IR action. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" + node_to_attr_name: Dict[str, str] = dict() + + def _dfs_get_attr_dependency(entry): + """ + First DFS path to construct reference map and name map. + """ + for node in entry.nodes(): + if node.kind() == "prim::GetAttr": + ( + irv_name, + irv_parent_name, + attr_name, + ) = get_ir_value_parent_name_and_attr_name(node) + node_to_parent_map[irv_name] = irv_parent_name + node_to_attr_name[irv_name] = attr_name + for block in node.blocks(): + _dfs_get_attr_dependency(block) + + def _map_blocks_to_lifted_attrs(entry): + """ + Walk the graph in a bottom-up fashion to build the expected to be + lifted arguments for each block. + """ + arguments: Set[str] = set() + for node in entry.nodes(): + for block in node.blocks(): + # Recursively build. + arguments = arguments.union(_map_blocks_to_lifted_attrs(block)) + if node.kind() == "prim::GetAttr": + irv_name = node.output().debugName() + # Skip for intermediate GetAttr, which will anyway not result a FQN. 
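# --- Illustrative aside (editor's sketch, not part of the diff) -----------------
# How construct_fqn() above rebuilds an attribute FQN from the two maps produced
# by the first DFS pass, using the same hypothetical "%N" value names as the
# example in the comments just below:
_ref_map = {"%3": "%2", "%2": "%1"}                            # output IR -> input IR
_name_map = {"%3": "weight", "%2": "linear", "%1": "self"}     # output IR -> attr name
assert construct_fqn("%3", _ref_map, _name_map) == "linear.weight"
# "%2" also appears as the parent of another GetAttr output, so the check below
# treats it as an intermediate and does not record "linear" as a lifted attribute.
# --- end of aside ----------------------------------------------------------------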
+ # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"} + # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"} + # There is only one FQN %3-->%2-->%1: self.linear.weight + # %2-->%1 is not a FQN: self.linear + if irv_name not in set(node_to_parent_map.values()): + arguments.add( + construct_fqn(irv_name, node_to_parent_map, node_to_attr_name) + ) + if not isinstance(entry, torch._C.Graph): # Skip the top level. + blocks_to_lifted_attrs[entry] = arguments + return arguments + + _dfs_get_attr_dependency(graph) + _map_blocks_to_lifted_attrs(graph) + + return blocks_to_lifted_attrs + + def get_op_overload(node: torch._C.Node): schema_str = node.schema() schema = FunctionSchema.parse(schema_str) ns, op_name = str(schema.name.name).split("::") override = schema.name.overload_name - op_overload_packet = getattr(torch.ops.aten, op_name) - if override: - op_overload = getattr(op_overload_packet, override) - else: - op_overload = op_overload_packet.default + try: + op_overload_mod = getattr(torch.ops, ns) + op_overload_packet = getattr(op_overload_mod, op_name) + if override: + op_overload = getattr(op_overload_packet, override) + else: + op_overload = op_overload_packet.default + except Exception as e: + raise RuntimeError( + f"Unable to find operator {node.kind()} with schema {node.schema}" + ) from e return op_overload -class TS2EPConverter: - # TorchScript model to ExportedProgram converter +class TS2FXGraphConverter: def __init__( self, - ts_model, - sample_args: Tuple[Any, ...], - sample_kwargs: Optional[Dict[str, Any]] = None, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param_map: Dict[str, torch.Tensor], + name_to_buffer_map: Dict[str, torch.Tensor], + blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], ): - self.ts_model = ts_model - self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) - - self.sample_args = sample_args - self.sample_kwargs = sample_kwargs - - self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()} - self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()} + self.ts_graph = ts_graph + self.name_to_param_map = name_to_param_map + self.name_to_buffer_map = name_to_buffer_map self.fx_graph: torch.fx.Graph = torch.fx.Graph() self.input_specs: List[InputSpec] = [] self.output_specs: List[OutputSpec] = [] - self.name_to_node: Dict[str, Union[torch.fx.Node, List[torch.fx.Node]]] = {} + self.name_to_node: Dict[ + str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]] + ] = {} self.constant_map: Dict[str, Any] = {} self.attribute_map: Dict[str, Any] = {} self.tensor_constants: Dict[str, torch.Tensor] = {} + self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + + self.blocks_to_lifted_attrs = blocks_to_lifted_attrs + + # Populate methods for the standard operators. 
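# --- Illustrative aside (editor's sketch, not part of the diff) -----------------
# What the loop below installs, taking "prim::TupleIndex" as an example: an
# instance attribute named convert_<namespace>_<op> that forwards to
# _convert_standard_operators().  Note the lambda deliberately ignores `k`:
# _convert_standard_operators() re-derives its target from node.kind() via
# kind_to_standard_operators, so the classic late-binding pitfall of lambdas
# created in a loop does not apply here.
assert ir_name_to_func_name("prim::TupleIndex") == "convert_prim_TupleIndex"
# --- end of aside ----------------------------------------------------------------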
+ for k in kind_to_standard_operators.keys(): + handler_func_name = ir_name_to_func_name(k) + # Create an indirect function call: + # convert__ --> lambda node: _convert_standard_operator(node) + setattr( + self, + handler_func_name, + lambda node: self._convert_standard_operators(node), + ) + + def is_top_level_graph(self): + return isinstance(self.ts_graph, torch._C.Graph) + + def add_subgraph(self, subgraph) -> str: + name = f"subgraph_{len(self.subgraphs)}" + self.subgraphs[name] = subgraph + return name + def get_args_kwargs(self, node: torch._C.Node, schema): args = [] kwargs = {} @@ -109,7 +252,7 @@ def get_fx_value(self, value: torch._C.Value): else: raise ValueError(f"Input {value_name} not found") - def convert(self) -> ExportedProgram: + def convert(self) -> torch.fx.GraphModule: self.convert_graph_inputs() for node in self.ts_graph.nodes(): @@ -117,28 +260,24 @@ def convert(self) -> ExportedProgram: self.convert_graph_outputs() - gm = torch.fx.GraphModule({}, self.fx_graph) + # Pass parameter and buffer to the root for lookup. + gm = torch.fx.GraphModule( + {**self.subgraphs, **self.name_to_param_map, **self.name_to_buffer_map}, + self.fx_graph, + ) inplace_optimize_sym_size_div(gm) gm.graph.lint() - ep = self.retrace_as_exported_program(gm) - return ep + return gm def convert_graph_inputs(self): for graph_input in self.ts_graph.inputs(): name = graph_input.debugName() normalized_name = normalize_name(name) - fx_node = self.fx_graph.placeholder(normalized_name) - - # fx_node.meta["val"] = FakeTensor() - # TODO: set fx_node.meta["val"] - - self.name_to_node[name] = fx_node - - if name in self.param_names: + if name in self.name_to_param_map: self.input_specs.append( InputSpec( InputKind.PARAMETER, @@ -146,7 +285,10 @@ def convert_graph_inputs(self): target=name, ) ) - elif name in self.buffer_names: + fx_node = get_node_for_param_and_buffer( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_buffer_map: self.input_specs.append( InputSpec( InputKind.BUFFER, @@ -155,6 +297,9 @@ def convert_graph_inputs(self): persistent=True, ) ) + fx_node = get_node_for_param_and_buffer( + self.fx_graph, name, self.is_top_level_graph() + ) else: self.input_specs.append( InputSpec( @@ -163,6 +308,9 @@ def convert_graph_inputs(self): target=name, ) ) + fx_node = self.fx_graph.placeholder(normalized_name) + + self.name_to_node[name] = fx_node def convert_prim_Constant(self, node: torch._C.Node): name = node.output().debugName() @@ -201,6 +349,20 @@ def convert_prim_Constant(self, node: torch._C.Node): self.constant_map[name] = value + def convert_prim_device(self, node: torch._C.Node): + input_type = node.input().type() + if input_type.isSubtypeOf(torch._C.TensorType.get()): + device = input_type.device() # type: ignore[attr-defined] + output_name = node.output().debugName() + self.constant_map[output_name] = device + else: + raise ValueError(f"Unsupported JitType ({input_type}) when get device") + + def convert_prim_dtype(self, node: torch._C.Node): + dtype = node.input().type().dtype() + output_name = node.output().debugName() + self.constant_map[output_name] = dtype + def convert_prim_GetAttr(self, node: torch._C.Node): def get_attr(name: str): if name in self.attribute_map: @@ -218,7 +380,7 @@ def get_attr(name: str): f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name ) - def convert_aten_op(self, node: torch._C.Node): + def convert_call_function_op(self, node: torch._C.Node): target = get_op_overload(node) if target is torch.ops.aten.size.int: @@ 
-234,14 +396,57 @@ def convert_aten_op(self, node: torch._C.Node): output_name = node.output().debugName() self.name_to_node[output_name] = fx_node + def convert_prim_TupleConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + def convert_prim_ListConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def _convert_prim_iterator(self, node: torch._C.Node): output_list = [] - for input in node.inputs(): - output_list.append(self.get_fx_value(input)) + for inp in node.inputs(): + output_list.append(self.get_fx_value(inp)) output_name = node.output().debugName() self.name_to_node[output_name] = output_list + def convert_prim_DictConstruct(self, node: torch._C.Node): + output_dict = {} + k, v = None, None + for i, inp in enumerate(node.inputs()): + # We assume key value are stored in pair in the DictConstruct. + # The first element is the key and the following is the value. + if i % 2 == 0: + k = self.get_fx_value(inp) + else: + v = self.get_fx_value(inp) + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_dict + + def convert_prim_ListUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def convert_prim_TupleUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def _convert_prim_unpack_iterator(self, node: torch._C.Node): + # Single input and multiple outputs for unpacking. + for i, outp in enumerate(node.outputs()): + outp_name = outp.debugName() + inp = self.get_fx_value(node.input()) + fx_node = self.fx_graph.call_function(operator.getitem, (inp, i)) + self.name_to_node[outp_name] = fx_node + def convert_aten_Int(self, node: torch._C.Node): # converts aten::Int as aten._to_copy + aten::_local_scalar_dense target = torch.ops.aten._to_copy.default @@ -312,30 +517,123 @@ def convert_aten_div(self, node: torch._C.Node): self.name_to_node[output_name] = fx_node return - self.convert_aten_op(node) + self.convert_call_function_op(node) + + def convert_aten___getitem__(self, node: torch._C.Node): + input_container, index = tuple( + self.get_fx_value(input) for input in node.inputs() + ) + fx_node = self.fx_graph.call_function( + operator.getitem, (input_container, index) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_If(self, node: torch._C.Node): + inputs = list(node.inputs()) + assert len(inputs) == 1 + predicate = self.get_fx_value(inputs[0]) + + # Get union of inputs to blocks + arguments = set() + for block in node.blocks(): + block_args = set() + + # TODO: block.inputs(), not sure what theyre used for + + for block_node in block.nodes(): + for block_node_in in block_node.inputs(): + if block_node_in.debugName() in self.name_to_node: + block_args.add(block_node_in.debugName()) + + arguments.update(block_args) + + # Lift parameters as inputs. 
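# --- Illustrative aside (editor's sketch, not part of the diff) -----------------
# The node built at the end of this method follows the calling convention
#     torch.cond(pred, true_graph, false_graph, operands)
# and, per get_block_to_lifted_attrs() above, params/buffers read inside the
# blocks are expected to be passed as explicit operands rather than closed over,
# hence the lifting below.  A minimal standalone sketch of that convention
# (branch functions and input tensor are made up):
import torch

def _true_branch(x):
    return x + 1

def _false_branch(x):
    return x - 1

_x = torch.randn(3)
_out = torch.cond(_x.sum() > 0, _true_branch, _false_branch, (_x,))
# --- end of aside ----------------------------------------------------------------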
+ for block in node.blocks(): + arguments = arguments.union(self.blocks_to_lifted_attrs[block]) + + arguments = list(arguments) + + # Convert blocks to subgraphs + subgraph_nodes = [] + for block in node.blocks(): + subgraph_converter = TS2FXGraphConverter( + block, dict(), dict(), self.blocks_to_lifted_attrs + ) + subgraph_converter.constant_map = self.constant_map + subgraph_converter.attribute_map = self.attribute_map + + for block_arg in arguments: + normalized_block_arg_name = normalize_name(block_arg) + placeholder_node = subgraph_converter.fx_graph.placeholder( + normalized_block_arg_name + ) + subgraph_converter.name_to_node[block_arg] = placeholder_node + + subgraph = subgraph_converter.convert() + subgraph_name = self.add_subgraph(subgraph) + subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name)) + + assert len(subgraph_nodes) == 2 + + fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments] + args = ( + predicate, + subgraph_nodes[0], + subgraph_nodes[1], + tuple(fx_block_args), + ) + + cond_node = self.fx_graph.call_function(torch.cond, args, {}) + + output_name = node.output().debugName() + self.name_to_node[output_name] = cond_node + + def convert_aten_Bool(self, node: torch._C.Node): + self._convert_as_noop(node) + + def _convert_as_noop(self, node: torch._C.Node): + # Converts the node as a no-op by mapping its output node as arg[0] + + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + output_name = node.output().debugName() + self.name_to_node[output_name] = args[0] + + def convert_profiler__record_function_enter_new(self, node: torch._C.Node): + target = torch.ops.profiler._record_function_enter_new + args = tuple(self.get_fx_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_profiler__record_function_exit(self, node: torch._C.Node): + # _record_function_exit has side effect so we keep it in fx.graph + # currently, _record_function_enter_new and _record_function_exit are + # discarded during `retrace_as_exported_program`. + target = torch.ops.profiler._record_function_exit + args = tuple(self.get_fx_value(input) for input in node.inputs()) + self.fx_graph.call_function(target, args) + + def _convert_standard_operators(self, node: torch._C.Node): + target = kind_to_standard_operators[node.kind()] + args = tuple(self.get_fx_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node def convert_node(self, node: torch._C.Node): node_kind = node.kind() - if node_kind == "prim::CreateObject": - self.convert_prim_CreateObject(node) - elif node_kind == "prim::Constant": - self.convert_prim_Constant(node) - elif node_kind == "prim::GetAttr": - self.convert_prim_GetAttr(node) - elif node_kind == "prim::NumToTensor": - self.convert_prim_NumToTensor(node) - elif node_kind == "prim::ListConstruct": - self.convert_prim_ListConstruct(node) - # elif node_kind == "aten::Int": - # convert_aten_Int(node) - elif node_kind == "aten::_convolution": - self.convert_aten__convolution(node) - elif node_kind == "aten::div": - self.convert_aten_div(node) - elif node_kind.startswith("aten::"): - self.convert_aten_op(node) - else: - raise ValueError(f"Unsupported node kind: {node_kind}") + + # Get handler based on namespace and operator name. 
+ # Provide a default node handler as well in case we don't find + # matching converter for that. + handler_func_name = ir_name_to_func_name(node_kind) + handler_func = getattr(self, handler_func_name, self.convert_call_function_op) + handler_func(node) def convert_graph_outputs(self): args = [] @@ -343,26 +641,75 @@ def convert_graph_outputs(self): output_name = graph_output.debugName() if output_name in self.name_to_node: args.append(self.name_to_node[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=TensorArgument(name=output_name), + target=output_name, + ) + ) + elif output_name in self.constant_map: + args.append(self.constant_map[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=ConstantArgument( + name=output_name, value=self.constant_map[output_name] + ), + target=output_name, + ) + ) else: raise ValueError(f"Output {output_name} not found") - self.output_specs.append( - OutputSpec( - OutputKind.USER_OUTPUT, - arg=TensorArgument(name=output_name), - target=output_name, - ) - ) + self.fx_graph.output( + args[0] + ) # Get rid of an extra list wrapped around final output. - self.fx_graph.output(args) - def retrace_as_exported_program(self, gm: torch.fx.GraphModule): - # TODO: adjust input orders to match GraphSignature convention - inputs = [*self.sample_args, *self.params, *self.tensor_constants.values()] +class TS2EPConverter: + # TorchScript model to ExportedProgram converter + def __init__( + self, + ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], + sample_args: Tuple[Any, ...], + sample_kwargs: Optional[Dict[str, Any]] = None, + ): + self.ts_model = ts_model + self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) + + self.sample_args = sample_args + self.sample_kwargs = sample_kwargs + self.name_to_param_map: Dict[str, torch.Tensor] = ( + dict(ts_model.named_parameters()) + if isinstance(ts_model, torch.jit.ScriptModule) + else dict() + ) + self.name_to_buffer_map: Dict[str, torch.Tensor] = ( + dict(ts_model.named_buffers()) + if isinstance(ts_model, torch.jit.ScriptModule) + else dict() + ) + + def convert(self) -> ExportedProgram: + blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph) + + graph_converter = TS2FXGraphConverter( + self.ts_graph, + self.name_to_param_map, + self.name_to_buffer_map, + blocks_to_lifted_attrs, + ) + gm = graph_converter.convert() + ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants) + return ep + + def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants): + # TODO: adjust input orders to match GraphSignature convention ep = torch.export._trace._export( gm, - tuple(inputs), + self.sample_args, strict=False, pre_dispatch=True, ) diff --git a/torch/_export/db/case.py b/torch/_export/db/case.py index 6c4c03572e3a..21b456fbe029 100644 --- a/torch/_export/db/case.py +++ b/torch/_export/db/case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import re import string diff --git a/torch/_export/db/examples/__init__.py b/torch/_export/db/examples/__init__.py index d737548c3d48..2e93d4b80824 100644 --- a/torch/_export/db/examples/__init__.py +++ b/torch/_export/db/examples/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import glob import importlib from os.path import basename, dirname, isfile, join diff --git a/torch/_export/db/examples/assume_constant_result.py b/torch/_export/db/examples/assume_constant_result.py index 0078200bc0f0..1503e0c91134 100644 --- 
a/torch/_export/db/examples/assume_constant_result.py +++ b/torch/_export/db/examples/assume_constant_result.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._dynamo as torchdynamo diff --git a/torch/_export/db/examples/autograd_function.py b/torch/_export/db/examples/autograd_function.py index 9c8aeadc45ae..3c9099b0cdb8 100644 --- a/torch/_export/db/examples/autograd_function.py +++ b/torch/_export/db/examples/autograd_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/class_method.py b/torch/_export/db/examples/class_method.py index 838a0a1cdb67..831339372274 100644 --- a/torch/_export/db/examples/class_method.py +++ b/torch/_export/db/examples/class_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_class_method.py b/torch/_export/db/examples/cond_branch_class_method.py index 40430d23c0f2..21fe1d25516a 100644 --- a/torch/_export/db/examples/cond_branch_class_method.py +++ b/torch/_export/db/examples/cond_branch_class_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_nested_function.py b/torch/_export/db/examples/cond_branch_nested_function.py index 00bce0b580a1..03639c0a207d 100644 --- a/torch/_export/db/examples/cond_branch_nested_function.py +++ b/torch/_export/db/examples/cond_branch_nested_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/torch/_export/db/examples/cond_branch_nonlocal_variables.py index 2db6192117df..676e7d21ffd2 100644 --- a/torch/_export/db/examples/cond_branch_nonlocal_variables.py +++ b/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_closed_over_variable.py b/torch/_export/db/examples/cond_closed_over_variable.py index 226576cc83f7..cf4787f481c4 100644 --- a/torch/_export/db/examples/cond_closed_over_variable.py +++ b/torch/_export/db/examples/cond_closed_over_variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_operands.py b/torch/_export/db/examples/cond_operands.py index 1a0db6a110d3..03fd467959a2 100644 --- a/torch/_export/db/examples/cond_operands.py +++ b/torch/_export/db/examples/cond_operands.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/cond_predicate.py b/torch/_export/db/examples/cond_predicate.py index c72c11e32f57..fa3cdeaf3b05 100644 --- a/torch/_export/db/examples/cond_predicate.py +++ b/torch/_export/db/examples/cond_predicate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/constrain_as_size_example.py b/torch/_export/db/examples/constrain_as_size_example.py index 16d646252414..a3664b7e80f1 100644 --- a/torch/_export/db/examples/constrain_as_size_example.py +++ b/torch/_export/db/examples/constrain_as_size_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import 
export_case diff --git a/torch/_export/db/examples/constrain_as_value_example.py b/torch/_export/db/examples/constrain_as_value_example.py index 1de266c689c4..b1b412d41391 100644 --- a/torch/_export/db/examples/constrain_as_value_example.py +++ b/torch/_export/db/examples/constrain_as_value_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/decorator.py b/torch/_export/db/examples/decorator.py index fbc95182e60e..da963ce7da01 100644 --- a/torch/_export/db/examples/decorator.py +++ b/torch/_export/db/examples/decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_export/db/examples/dictionary.py b/torch/_export/db/examples/dictionary.py index 5a210906e680..19f138e6f4d1 100644 --- a/torch/_export/db/examples/dictionary.py +++ b/torch/_export/db/examples/dictionary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_assert.py b/torch/_export/db/examples/dynamic_shape_assert.py index 52cc43a21049..57ba98552e0c 100644 --- a/torch/_export/db/examples/dynamic_shape_assert.py +++ b/torch/_export/db/examples/dynamic_shape_assert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_constructor.py b/torch/_export/db/examples/dynamic_shape_constructor.py index 599747f7968a..5ce7fdda2877 100644 --- a/torch/_export/db/examples/dynamic_shape_constructor.py +++ b/torch/_export/db/examples/dynamic_shape_constructor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_if_guard.py b/torch/_export/db/examples/dynamic_shape_if_guard.py index 2120ec0145fe..9350c6d992f5 100644 --- a/torch/_export/db/examples/dynamic_shape_if_guard.py +++ b/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_map.py b/torch/_export/db/examples/dynamic_shape_map.py index 5607c2796d68..421d4b355efb 100644 --- a/torch/_export/db/examples/dynamic_shape_map.py +++ b/torch/_export/db/examples/dynamic_shape_map.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_round.py b/torch/_export/db/examples/dynamic_shape_round.py index d581d6d839bc..57a1e07dab97 100644 --- a/torch/_export/db/examples/dynamic_shape_round.py +++ b/torch/_export/db/examples/dynamic_shape_round.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/dynamic_shape_slicing.py b/torch/_export/db/examples/dynamic_shape_slicing.py index eb237876f4e6..ddc2f86f774c 100644 --- a/torch/_export/db/examples/dynamic_shape_slicing.py +++ b/torch/_export/db/examples/dynamic_shape_slicing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/dynamic_shape_view.py b/torch/_export/db/examples/dynamic_shape_view.py index bcedd04cf36f..666da36ad2a8 100644 --- a/torch/_export/db/examples/dynamic_shape_view.py +++ b/torch/_export/db/examples/dynamic_shape_view.py @@ -1,3 +1,4 
@@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/fn_with_kwargs.py b/torch/_export/db/examples/fn_with_kwargs.py index 6182a7479555..d5a9a23415d9 100644 --- a/torch/_export/db/examples/fn_with_kwargs.py +++ b/torch/_export/db/examples/fn_with_kwargs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, ExportArgs, SupportLevel diff --git a/torch/_export/db/examples/list_contains.py b/torch/_export/db/examples/list_contains.py index d25d815cde1a..6105220c09b9 100644 --- a/torch/_export/db/examples/list_contains.py +++ b/torch/_export/db/examples/list_contains.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/list_unpack.py b/torch/_export/db/examples/list_unpack.py index 2251c6eb360d..66b4fe456a0d 100644 --- a/torch/_export/db/examples/list_unpack.py +++ b/torch/_export/db/examples/list_unpack.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List import torch diff --git a/torch/_export/db/examples/model_attr_mutation.py b/torch/_export/db/examples/model_attr_mutation.py index 409a0c0f6c03..4c2a03d4e77b 100644 --- a/torch/_export/db/examples/model_attr_mutation.py +++ b/torch/_export/db/examples/model_attr_mutation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/nested_function.py b/torch/_export/db/examples/nested_function.py index 608ef39d5187..cc668ee561a6 100644 --- a/torch/_export/db/examples/nested_function.py +++ b/torch/_export/db/examples/nested_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/null_context_manager.py b/torch/_export/db/examples/null_context_manager.py index da759b0980fa..ff4b94e6bf44 100644 --- a/torch/_export/db/examples/null_context_manager.py +++ b/torch/_export/db/examples/null_context_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_export/db/examples/optional_input.py b/torch/_export/db/examples/optional_input.py index 47bb5e1bab8d..dfc256d6a5ce 100644 --- a/torch/_export/db/examples/optional_input.py +++ b/torch/_export/db/examples/optional_input.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/pytree_flatten.py b/torch/_export/db/examples/pytree_flatten.py index 0d799b2a609a..9c91cc21df3c 100644 --- a/torch/_export/db/examples/pytree_flatten.py +++ b/torch/_export/db/examples/pytree_flatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/scalar_output.py b/torch/_export/db/examples/scalar_output.py index 86217847bff8..46e03c1f7e94 100644 --- a/torch/_export/db/examples/scalar_output.py +++ b/torch/_export/db/examples/scalar_output.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/specialized_attribute.py b/torch/_export/db/examples/specialized_attribute.py index 3f8f09c4128d..a53ad213c63f 100644 --- a/torch/_export/db/examples/specialized_attribute.py +++ b/torch/_export/db/examples/specialized_attribute.py @@ -1,3 +1,4 @@ +# mypy: 
allow-untyped-defs from enum import Enum import torch diff --git a/torch/_export/db/examples/static_for_loop.py b/torch/_export/db/examples/static_for_loop.py index af14f6fe8ae1..4ad60737ff5d 100644 --- a/torch/_export/db/examples/static_for_loop.py +++ b/torch/_export/db/examples/static_for_loop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/static_if.py b/torch/_export/db/examples/static_if.py index 048bf20ce8bf..bc5dce9f0667 100644 --- a/torch/_export/db/examples/static_if.py +++ b/torch/_export/db/examples/static_if.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case diff --git a/torch/_export/db/examples/tensor_setattr.py b/torch/_export/db/examples/tensor_setattr.py index fae18fb1cf93..201dca37c81a 100644 --- a/torch/_export/db/examples/tensor_setattr.py +++ b/torch/_export/db/examples/tensor_setattr.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/torch_sym_min.py b/torch/_export/db/examples/torch_sym_min.py index f7edc7003f14..a8fe560773a4 100644 --- a/torch/_export/db/examples/torch_sym_min.py +++ b/torch/_export/db/examples/torch_sym_min.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/examples/type_reflection_method.py b/torch/_export/db/examples/type_reflection_method.py index 869fb4cadd65..5d6570ca0cb9 100644 --- a/torch/_export/db/examples/type_reflection_method.py +++ b/torch/_export/db/examples/type_reflection_method.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel, export_rewrite_case diff --git a/torch/_export/db/examples/user_input_mutation.py b/torch/_export/db/examples/user_input_mutation.py index 01c5d775a264..b60036257617 100644 --- a/torch/_export/db/examples/user_input_mutation.py +++ b/torch/_export/db/examples/user_input_mutation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._export.db.case import export_case, SupportLevel diff --git a/torch/_export/db/logging.py b/torch/_export/db/logging.py index fc412b8c5082..8cd0827d3893 100644 --- a/torch/_export/db/logging.py +++ b/torch/_export/db/logging.py @@ -1,2 +1,3 @@ +# mypy: allow-untyped-defs def exportdb_error_message(case_name: str): return "" diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py index 5d28ea315490..49dfd0cf996e 100644 --- a/torch/_export/exported_program.py +++ b/torch/_export/exported_program.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 8f67a3cd258e..9db3653de1e2 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import inspect from collections import defaultdict @@ -111,7 +112,12 @@ def make_fake_params_buffers( def make_fake_inputs( - nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=False + nn_module, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, + _allow_complex_guards_as_runtime_asserts=False, ): """ Given an nn module, example inputs, and constraints, return a new fake mode, @@ -156,13 +162,22 @@ def make_fake_inputs( "co_firstlineno": code.co_firstlineno, } fake_mode = 
FakeTensorMode( - shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), + shape_env=ShapeEnv( + tracked_fakes=[], + co_fields=co_fields, + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, + ), allow_non_fake_inputs=True, export=True, ) else: fake_mode = FakeTensorMode( - shape_env=ShapeEnv(tracked_fakes=[]), + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, + ), allow_non_fake_inputs=True, ) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 1cf7e75ad5f9..840fc663f3ea 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import traceback import typing @@ -32,6 +33,7 @@ _TORCH_SYM_OPS: Set[Callable] = { torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, diff --git a/torch/_export/pass_infra/proxy_value.py b/torch/_export/pass_infra/proxy_value.py index 66592d48a45e..07d888b30656 100644 --- a/torch/_export/pass_infra/proxy_value.py +++ b/torch/_export/pass_infra/proxy_value.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # pyre-strict from typing import Union diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index e04059a9114a..3dd87b546da8 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index 5a2a8b5874bf..44f0ea270212 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import operator import traceback diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py index ca8eaf30be59..8d65a720b9d7 100644 --- a/torch/_export/passes/collect_tracepoints_pass.py +++ b/torch/_export/passes/collect_tracepoints_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py new file mode 100644 index 000000000000..684fe07b0ec3 --- /dev/null +++ b/torch/_export/passes/constant_folding.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs +import collections +from collections import defaultdict +from typing import Any, Callable, Dict, Optional + +import torch +import torch.utils._pytree as pytree + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
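# --- Illustrative aside (editor's sketch, not part of the diff) -----------------
# Rough usage of the constant_fold() entry point defined later in this new file,
# with a made-up toy module whose `self.w * 2.0` subexpression is foldable; the
# import path is the one this diff creates:
import torch
from torch._export.passes.constant_folding import constant_fold

class _ToyMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(3))

    def forward(self, x):
        return x + self.w * 2.0  # `self.w * 2.0` does not depend on `x`

_gm = torch.fx.symbolic_trace(_ToyMod())
constant_fold(_gm)  # `mul` becomes a frozen get_attr; the now-unused `w` is erased
# --- end of aside ----------------------------------------------------------------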
+# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm, + skip_constructors=False, + ): + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.node.Node): + if ( + node.target == torch.ops.prims.convert_element_type.default + and node.args[0].op == "get_attr" # type: ignore[union-attr] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error 
if + # the types are different. + if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self): + env = {} + for n in self.module.graph.find_nodes(op="placeholder"): + env[n] = self.unknown_value + return super().run(initial_env=env) + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. 
+ + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_graph_tag(gm: torch.fx.GraphModule): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.find_nodes(op="get_attr"): + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 83914fb828c5..d9cd62ffc928 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from typing import Any, Dict, List, Union diff --git a/torch/_export/passes/remove_runtime_assertions.py b/torch/_export/passes/remove_runtime_assertions.py index adcc708e5548..a80b62d2765a 100644 --- a/torch/_export/passes/remove_runtime_assertions.py +++ b/torch/_export/passes/remove_runtime_assertions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch/_export/passes/replace_set_grad_with_hop_pass.py b/torch/_export/passes/replace_set_grad_with_hop_pass.py index 91104c17c38d..0b0bef582e45 100644 --- a/torch/_export/passes/replace_set_grad_with_hop_pass.py +++ b/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy diff --git 
a/torch/_export/passes/replace_sym_size_ops_pass.py b/torch/_export/passes/replace_sym_size_ops_pass.py index 109a96d7b4bd..29d594d41f06 100644 --- a/torch/_export/passes/replace_sym_size_ops_pass.py +++ b/torch/_export/passes/replace_sym_size_ops_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py index f32b442733eb..edc249b572b5 100644 --- a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +++ b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Optional, Set import torch diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index cde4cf1ada27..b22b9778819e 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import hashlib import re diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 38ef1da7d5d4..f8fdc1011b52 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import base64 import copy import copyreg @@ -170,6 +171,7 @@ def _reverse_map(d: Dict[Any, Enum]): operator.floordiv, operator.mod, torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, @@ -387,14 +389,6 @@ def _is_single_tensor_list_return(target: Any) -> bool: return_type.getElementType(), torch.TensorType ) -def _output_node_at_index(node, index): - for user in node.users: - assert user.target is operator.getitem, f"{user} is not a getitem node" - if index == user.args[1]: - return user - return None - - @dataclass class GraphState: @@ -427,6 +421,7 @@ def __init__( self.graph_signature = graph_signature self.module_call_graph = module_call_graph self.custom_objs: Dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: Dict[str, str] = {} @contextmanager def save_graph_state(self): @@ -552,6 +547,19 @@ def handle_call_function(self, node: torch.fx.Node): def handle_get_attr(self, node): pass + def _output_node_at_index(self, node, index): + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: ret = {} if stack_trace := node.meta.get("stack_trace"): @@ -705,13 +713,16 @@ def serialize_input( return Argument.create( as_sym_bool=SymBoolArgument.create(as_name=arg.name) ) - else: - if isinstance(arg.meta["val"], ep.CustomObjArgument): - return Argument.create( - as_custom_obj=CustomObjArgument( - name=arg.name, class_fqn=arg.meta["val"].class_fqn - ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: return Argument.create(as_tensor=TensorArgument(name=arg.name)) elif isinstance(arg, 
inductor_tensor_buffers): # Other branches are for arguments in fx node. @@ -1121,7 +1132,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: # e.g "-> Tensor[]" tensor_args = [] for idx, meta in enumerate(meta_val): - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1151,7 +1162,7 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: output_arguments.append(Argument.create(as_none=())) elif isinstance(meta, FakeTensor): assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType)) - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1165,20 +1176,20 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: ) and isinstance( return_schema.real_type.getElementType(), torch.TensorType ) - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) assert user_node is not None args = [] for i, m in enumerate(meta): if m is None: continue - sub_user_node = _output_node_at_index(user_node, i) + sub_user_node = self._output_node_at_index(user_node, i) assert sub_user_node is not None, f"No user found at index {i}" args.append(self.serialize_tensor_output(sub_user_node.name, m)) output_arguments.append(Argument.create(as_tensors=args)) elif isinstance(meta, (int, SymInt)): - user_node = _output_node_at_index(node, idx) + user_node = self._output_node_at_index(node, idx) name = ( user_node.name if user_node is not None @@ -1208,7 +1219,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: if len(meta_val) == 1: assert isinstance(meta_val[0], torch.Tensor) - user_node = _output_node_at_index(node, 0) + user_node = self._output_node_at_index(node, 0) name = ( user_node.name if user_node is not None @@ -1218,7 +1229,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: outputs = [] for i, element_meta_val in enumerate(meta_val): - user_node = _output_node_at_index(node, i) + user_node = self._output_node_at_index(node, i) if isinstance(element_meta_val, list): # e.g "-> Tensor[]" assert user_node is not None @@ -1228,7 +1239,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: if not isinstance(m, torch.Tensor): raise SerializeError(f"Serialize list output with type {type(m)} nyi") - sub_user_node = _output_node_at_index(user_node, j) + sub_user_node = self._output_node_at_index(user_node, j) name = ( sub_user_node.name if sub_user_node is not None @@ -1465,10 +1476,15 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. 
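# --- Illustrative aside (editor's sketch, not part of the diff) -----------------
# The point of the integer-ification below, shown in plain sympy:
import sympy

_s = sympy.Symbol("s")                                 # no integrality assumption
assert sympy.solve(sympy.Eq(2 * _s, 9), _s) == [sympy.Rational(9, 2)]
_s_int = sympy.Symbol("s", integer=True)
_sols = sympy.solve(sympy.Eq(2 * _s_int, 9), _s_int)   # expected: [] -- 9/2 is not an integer
# --- end of aside ----------------------------------------------------------------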
+ # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - if isinstance(sym, sympy.Symbol): + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1487,7 +1503,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index 8dfce61f0ab2..b129e8dd9a89 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from dataclasses import fields from typing import Hashable, Set diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py index d35fe7e1586c..c427a4030c9c 100644 --- a/torch/_export/serde/upgrade.py +++ b/torch/_export/serde/upgrade.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class GraphModuleOpUpgrader: diff --git a/torch/_export/tools.py b/torch/_export/tools.py index d76392993bd2..23fae4a9196c 100644 --- a/torch/_export/tools.py +++ b/torch/_export/tools.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import warnings from typing import Any, Dict, Iterable, Optional, Tuple diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 772bd3e124b7..1cec59aaaa0c 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import dataclasses import inspect diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 3f89324642eb..8ee7c8926834 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import math import operator @@ -175,6 +176,7 @@ def _allowed_op_types() -> Tuple[Type[Any], ...]: _allowed_torch_functions = ( torch.autograd.grad_mode.set_grad_enabled, torch.sym_int, + torch.sym_float, torch.sym_ite, torch.sym_max, torch.sym_min, diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index 5ca2375ec124..c18ed34a395c 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 07885c136c7d..8144a47f057a 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Utils for caching the outputs of AOTAutograd """ @@ -6,6 +7,7 @@ import functools import logging import os +from typing import TYPE_CHECKING import torch from torch._functorch import config @@ -16,10 +18,12 @@ FxGraphHashDetails, get_code_hash, ) -from torch.fx.node import Node from .schemas import AOTConfig # noqa: F401 +if TYPE_CHECKING: + from torch.fx.node import Node + log = 
logging.getLogger(__name__) @@ -118,7 +122,7 @@ def __init__( self.code_hash = get_autograd_code_hash() self.autograd_config = config.save_config() try: - super().__init__(gm, example_inputs, {}) + super().__init__(gm, example_inputs, {}, []) except BypassFxGraphCache as e: # Sometimes inductor configs are unpickleable and can fail raise BypassAOTAutogradCache from e diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index e01f6df6957d..44301291a91f 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is one of the analysis modules - it takes as input a function or graph and some preexisting properties, and returns some data that is useful for deciding @@ -665,6 +666,15 @@ def view_avoid_dupes_with_primals(t): ) user_outs = pytree.tree_map(from_fun, f_output_tangents) + if torch._dynamo.config.inline_inbuilt_nn_modules: + static_parameter_input_indices = [ + i + for i, arg in enumerate(flat_args) + if isinstance(arg, torch.nn.Parameter) + ] + else: + static_parameter_input_indices = [] + f_mutated_inputs = [ inp for inp, info in zip(flat_f_args, input_info) @@ -716,6 +726,7 @@ def view_avoid_dupes_with_primals(t): subclass_tangent_meta=create_subclass_meta(traced_tangents), is_train=is_train, grad_enabled_mutation=grad_enabled_mutation, + static_parameter_indices=static_parameter_input_indices, tokens=mode._tokens, ) return metadata diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 1a6f1c7dce1e..c38a98366cb3 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module dispatches the graphs to either the forward-only or joint compilation pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. @@ -186,11 +187,20 @@ def _map_assigned_buffer_to_proxy(_mod, name, buffer): if aot_config.enable_log: aot_graphs_log.info( - "%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id) + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_forward_graph", - payload_fn=lambda: fw_module.print_readable(print_output=False), + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) # TODO: should factor this into a separate function for export that always only returns just the graph. diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 2e0a7d322f6f..02cf2ab0f428 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains utilities related to functionalization in AOTAutograd: 1. converting to/from functional tensors @@ -5,7 +6,9 @@ 3. regenerating/replaying views from their base 4. checking if a graph is functional i.e. 
whether it contains any mutation ops """ +from __future__ import annotations +from typing import Optional import torch from torch import Tensor @@ -219,10 +222,7 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - # Actual type: Optional[FunctionalTensorMetadataEq] - # Can't use it here because it lives inside schemas.py. Importing that class would lead - # to an error due to an import cycle. - target_functional_tensor=None, + target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -244,9 +244,6 @@ def patch_requires_grad(out): and target_functional_tensor is not None and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) ): - from .schemas import FunctionalTensorMetadataEq - - assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq) functional_tensor = target_functional_tensor.tensor out = torch._functionalize_apply_view_metas( @@ -321,6 +318,27 @@ def has_same_metadata(t1, t2): ) + +# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata +# after applying all the ViewMeta operations. +class FunctionalTensorMetadataEq: + def __init__(self, tensor: torch.Tensor) -> None: + assert torch._is_functional_tensor(tensor) + self.tensor = tensor + + def __eq__(self, other: object) -> bool: + # If other is None, then it probably means that we weren't able to recreate + # the FunctionalTensorMetadataEq. One of these cases is when we update the + # view metadata by calling: create_synthetic_base_metadata. + if other is None: + return True + + # Comparison against any other type is not implemented. + if not isinstance(other, FunctionalTensorMetadataEq): + return NotImplemented + + return has_same_metadata(self.tensor, other.tensor) + + # new_arg and arg here are either: # (1) both a FakeTensor # (2) both a traceable tensor subclass that holds a FakeTensor diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 9a02dffb3d1b..29a32ee03078 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is one of the analysis modules - it takes as input a function or graph and some preexisting properties, and returns some data that is useful for deciding diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 9eff7b20c04b..5eb681889d8a 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Functions in this module do most of the "work" of AOTAutograd.
An aot_dispatch_* function: @@ -247,11 +248,20 @@ def aot_dispatch_autograd( if aot_config.enable_log: aot_joint_log.info( - "%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id) + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_joint_graph", - payload_fn=lambda: fx_g.print_readable(print_output=False), + payload_fn=lambda: fx_g.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) with torch.no_grad(): @@ -389,19 +399,35 @@ def aot_dispatch_autograd( if aot_config.enable_log: aot_graphs_log.info( "%s", - lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id), + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) aot_graphs_log.info( "%s", - lazy_format_graph_code("Backward graph", bw_module, aot_config.aot_id), + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + ), ) trace_structured( "aot_forward_graph", - payload_fn=lambda: fw_module.print_readable(print_output=False), + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) trace_structured( "aot_backward_graph", - payload_fn=lambda: bw_module.print_readable(print_output=False), + payload_fn=lambda: bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) with track_graph_compiling(aot_config, "forward"): diff --git a/torch/_functorch/_aot_autograd/logging_utils.py b/torch/_functorch/_aot_autograd/logging_utils.py index 414166cbdd2f..c961f74dc6c1 100644 --- a/torch/_functorch/_aot_autograd/logging_utils.py +++ b/torch/_functorch/_aot_autograd/logging_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Contains utils for logging in AOTAutograd, including managing the names of the graphs under compilation, capturing user-friendly tracebacks, and debug messages. diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index a450f401f9e2..0afa24ce4ee8 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module defines runtime wrappers, which, based on previous analysis attempts to: 1. process the inputs and outputs @@ -7,7 +8,6 @@ """ import collections import pprint -import time from contextlib import nullcontext from dataclasses import dataclass, field from functools import wraps @@ -24,7 +24,6 @@ tracing, TracingContext, ) -from torch._logging import trace_structured from torch._prims_common import CUDARngStateHelper from torch._subclasses import FakeTensor @@ -201,24 +200,31 @@ def runtime_wrapper(args: List[Any]): for idx in indices_of_inps_to_detach: if isinstance(args_[idx], torch.Tensor): args_[idx] = args_[idx].detach() - with torch.autograd._force_original_view_tracking(True): + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. 
+ with torch.autograd._force_original_view_tracking( + True + ), torch.enable_grad(): all_outs = call_func_at_runtime_with_args( compiled_fn, args_, disable_amp=disable_amp, steal_args=True ) else: - # When we have an inference graph, we run with torch.no_grad. + # When we have an inference graph, we run with grad disabled. # It's possible to get an inference graph with inputs that require grad, # in which case we want to make sure autograd is disabled # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) - if torch.is_grad_enabled(): - with torch.no_grad(): - all_outs = call_func_at_runtime_with_args( - compiled_fn, args, disable_amp=disable_amp, steal_args=True - ) - else: + # NOTE: We use _set_grad_enabled directly to reduce runtime overhead + grad_enabled = torch.is_grad_enabled() + try: + if grad_enabled: + torch._C._set_grad_enabled(False) all_outs = call_func_at_runtime_with_args( compiled_fn, args, disable_amp=disable_amp, steal_args=True ) + finally: + if grad_enabled: + torch._C._set_grad_enabled(True) del args num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices @@ -391,7 +397,7 @@ def runtime_wrapper(args: List[Any]): else: t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy() if runtime_metadata.grad_enabled_mutation is not None: - torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation) + torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) return ret_outs return runtime_wrapper @@ -1801,41 +1807,9 @@ def call_compiled_backward(): with tracing(saved_context), compile_context( saved_compile_context ), context(), track_graph_compiling(aot_config, "backward"): - fail_type: Optional[str] = None - fail_reason: Optional[str] = None - start_time = time.time() - try: - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, placeholder_list - ) - except Exception as e: - fail_type = str(type(e)) - fail_reason = str(e) - if saved_compile_context is not None: - e.compile_id = saved_compile_context.compile_id # type: ignore[attr-defined] - raise - finally: - # TODO: Similar to CompilationMetrics, we would - # like to report inductor_compile_time, but we - # cannot conveniently do so because these are - # keyed on utils.frame, and frame key is not - # incremented on backwards compilations. Maybe - # should just bump the frame key here too? - end_time = time.time() - # TODO: Put this in scuba? But CompilationMetrics - # is kind of not a great match, because there's no - # interaction with Dynamo, so a lot of Dynamo only - # events don't exist anymore. So we need a new - # scuba table. Lazy lazy... - trace_structured( - "aot_autograd_backward_compilation_metrics", - lambda: { - "start_time": start_time, - "elapsed_time": time.time() - start_time, - "fail_type": fail_type, - "fail_reason": fail_reason, - }, - ) + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list + ) out = call_func_at_runtime_with_args( CompiledFunction.compiled_bw, diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 982fcb9e6464..d5588a6e912c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes input/output types, metadata, config, function signatures etc. @@ -17,7 +18,10 @@ from .. 
import config -from .functional_utils import _check_if_mutation_can_be_in_graph, has_same_metadata +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + FunctionalTensorMetadataEq, +) from .utils import strict_zip zip = strict_zip @@ -54,27 +58,6 @@ ) -# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata -# after applying all the ViewMeta operations. -class FunctionalTensorMetadataEq: - def __init__(self, tensor: torch.Tensor) -> None: - assert torch._is_functional_tensor(tensor) - self.tensor = tensor - - def __eq__(self, other: object) -> bool: - # If other is None, then it probably means that we weren't able to recreate - # the FunctionalTensorMetadataEq. One of this cases is when we update the - # view metadata by calling: create_synthetic_base_metadata. - if other is None: - return True - - # Comparison agains any other type is not implemented. - if not isinstance(other, FunctionalTensorMetadataEq): - return NotImplemented - - return has_same_metadata(self.tensor, other.tensor) - - # This class stores info about every user output. @dataclass(frozen=True) class OutputAliasInfo: @@ -304,6 +287,9 @@ class ViewAndMutationMeta: # raised deterministic: Optional[bool] = None + # Keeps track of which input indices store parameters (which we will treat as static) + static_parameter_indices: List[int] = field(default_factory=list) + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are # side-effectful operators, FunctionalTensorMode will populate this # dictionary telling us how many tokens we will need during tracing. diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index cee3cf6e4eda..98f08bb786c4 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes. AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher, diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index c673acdabe12..fa33d9fd79c4 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module is responsible for transforming functions to be traced into a form that is easier for the downstream infra (e.g. 
Autograd, FX, AOTAutograd analysis) @@ -546,9 +547,14 @@ def _functionalized_f_helper(*args): and meta.input_info[i].mutations_hidden_from_autograd ): # Hidden from autograd = run under no_grad, **and** don't bump VC - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - inpt_old - ): + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: inpt_old.copy_(inpt_new) elif ( meta.input_info[i].mutates_data diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index e23a32f10cc4..3d577d2b37b5 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Contains various utils for AOTAutograd, including those for handling collections. """ @@ -25,6 +26,7 @@ type(None), *py_sym_types, FakeScriptObject, + torch.ScriptObject, ] original_zip = zip diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index f7724a6add60..c52a9cde0d55 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -493,6 +493,12 @@ def convert(idx, x): return shape_env.create_symintnode( shape_env.create_symbol(x, source), hint=x, source=source ) + if isinstance( + x, torch.ScriptObject + ) and torch._library.fake_class_registry.has_fake_class( + x._type().qualified_name() + ): + return torch._library.fake_class_registry.to_fake_obj(fake_mode, x) if not isinstance(x, torch.Tensor): return x if isinstance(x, FakeTensor): @@ -509,10 +515,14 @@ def convert(idx, x): # see note [Tensor Fakification and Symbol Caching] symbolic_context = None source = None + trace = True if tracing_context := torch._guards.TracingContext.try_get(): if x in tracing_context.tensor_to_context: symbolic_context = tracing_context.tensor_to_context[x] source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False if ( idx < aot_config.num_params_buffers and config.static_weight_shapes @@ -527,15 +537,15 @@ def convert(idx, x): static_shapes=False, symbolic_context=symbolic_context, source=source, + trace=trace, ) return [convert(idx, x) for idx, x in enumerate(flat_args)] fake_flat_args = process_inputs(flat_args) - needs_autograd = ( - any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)) - and torch.is_grad_enabled() + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) ) with enable_python_dispatcher(): @@ -559,10 +569,19 @@ def convert(idx, x): fake_flat_args, fw_metadata ) - if needs_autograd and not any( + output_and_mutation_safe = not any( x.requires_grad for x in fw_metadata.output_info - ): + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) + + if needs_autograd and output_and_mutation_safe: # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. # so we actually have an inference graph. 
needs_autograd = False # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 477a01583b3d..1b755550a8bf 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can # trace through functorch transforms. # Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing @@ -188,7 +189,7 @@ def vmap( vmap does not provide general autobatching or handle variable-length sequences out of the box. """ - from torch.compiler import is_compiling + from torch._dynamo import is_compiling _check_randomness_arg(randomness) if not (chunk_size is None or chunk_size > 0): @@ -390,7 +391,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla """ # To avoid cyclical dependency. import torch._functorch.eager_transforms as eager_transforms - from torch.compiler import is_compiling + from torch._dynamo import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) @@ -432,8 +433,8 @@ def grad_and_value( See :func:`grad` for examples """ + from torch._dynamo import is_compiling from torch._functorch import eager_transforms - from torch.compiler import is_compiling def wrapper(*args, **kwargs): return eager_transforms.grad_and_value_impl( diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 03bfd710ae34..b827fb20424c 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, NamedTuple, Tuple import torch diff --git a/torch/_functorch/batch_norm_replacement.py b/torch/_functorch/batch_norm_replacement.py index a2df284138e7..672a8ce76955 100644 --- a/torch/_functorch/batch_norm_replacement.py +++ b/torch/_functorch/batch_norm_replacement.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn from torch._functorch.utils import exposed_in diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index c559951f3809..60bbf1f21c66 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -88,6 +88,39 @@ # a fusion can be expensive. ban_recompute_reductions = True +# By default, the partitioner is purely trying to optimize for runtime (although +# it should always use less memory than eager) +# This knob controls the partitioner to make that tradeoff for you, choosing the +# fastest option that saves less activations than the memory budget. +# Specifically, 0.0 corresponds to the activation memory from applying +# activation checkpointing to the full compiled region, and 1.0 corresponds to +# the activation memory from the default runtime-optimized strategy. So, 0.4 +# would result in a strategy that saves 40% of the activations compared to the +# default strategy. +# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below +# the activation memory budget. +# NOTE: This *cannot* be treated as +activation_memory_budget = 1.0 + +# This controls how we estimate the runtime when deciding what the cheapest +# operators to recompute are. 
The 3 options are +# "flops": Bases it off of the flop count provided by torch.utils.flop_counter +# "profile": Benchmarks each operator to come up with a runtime +# "testing": Returns 1 for everything +activation_memory_budget_runtime_estimator = "flops" + +# This controls the solver used for the 0-1 knapsack. By default we use a +# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" +# (which has a scipy dependency). +activation_memory_budget_solver = "dp" + +# This dumps out a png visualization of the expected runtime vs. activation +# memory tradeoffs for all memory budget values from 0 to 1 in increments of +# 0.5. See an example here: +# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 +visualize_memory_budget_pareto = ( + os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" +) # Sets all of the ban_recompute heuristics to False except ban_recompute_reductions # Generally, this will probably result in some memory improvement, but at the diff --git a/torch/_functorch/deprecated.py b/torch/_functorch/deprecated.py index bf080fcc3165..ebb930e8ecb7 100644 --- a/torch/_functorch/deprecated.py +++ b/torch/_functorch/deprecated.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The APIs in this file are exposed as `functorch.*`. They are thin wrappers around the torch.func.* APIs that have deprecation warnings -- we're trying @@ -37,7 +38,7 @@ def get_warning(api, new_api=None, replace_newlines=False): def warn_deprecated(api, new_api=None): warning = get_warning(api, new_api, replace_newlines=True) - warnings.warn(warning, FutureWarning, stacklevel=2) + warnings.warn(warning, FutureWarning, stacklevel=3) def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index 80751c9694fd..fbea5164014b 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -765,8 +765,10 @@ def compute_jacobian_preallocate_and_copy(): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) + else: + wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn @@ -1346,8 +1348,10 @@ def push_jvp(basis): # Dynamo does not support HOP composition if their inner function is # annotated with @functools.wraps(...). We circumvent this issue by applying # wraps only if we're not tracing with dynamo. - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) + else: + wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 7533811ed235..5552036e8ddf 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import Counter from typing import Any, Dict, List, Optional, Sequence, Tuple, Union diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py index 711be174d827..8932f750551c 100644 --- a/torch/_functorch/make_functional.py +++ b/torch/_functorch/make_functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. 
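As a rough illustration of how the activation-memory-budget knobs added to torch/_functorch/config.py above are meant to be driven, here is a minimal usage sketch; the model, shapes, and the 0.5 budget are illustrative assumptions, not values taken from this PR:

import torch
import torch._functorch.config as functorch_config

# Save roughly half of the activations that the default runtime-optimized
# strategy would save; the knapsack solver decides what to recompute to stay
# under this budget.
functorch_config.activation_memory_budget = 0.5
# Runtime estimator for recompute candidates: "flops" (default), "profile", or "testing".
functorch_config.activation_memory_budget_runtime_estimator = "flops"
# Knapsack solver: "dp" (default), "greedy", or "ilp" (requires scipy).
functorch_config.activation_memory_budget_solver = "dp"

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)
compiled = torch.compile(model)
out = compiled(torch.randn(8, 1024, requires_grad=True))
out.sum().backward()

Setting the PARTITIONER_MEMORY_BUDGET_PARETO environment variable to 1 additionally dumps the Pareto-frontier plot produced by the partitioner changes later in this diff.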
# diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 0956ee7e367c..8e954f910ba4 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import heapq @@ -8,9 +9,7 @@ import os from collections import defaultdict from dataclasses import dataclass, replace -from typing import Callable, Dict, List, Optional, Set, Tuple, Union - -import sympy +from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import torch import torch._inductor.inductor_prims @@ -27,8 +26,12 @@ ) from torch.fx.passes import graph_drawer from . import config +from ._aot_autograd.logging_utils import get_aot_graph_name from .compile_utils import fx_graph_cse, get_aten_target +if TYPE_CHECKING: + import sympy + AOT_PARTITIONER_DEBUG = config.debug_partitioner log = logging.getLogger(__name__) @@ -450,14 +453,16 @@ def _size_of(node: fx.Node) -> int: # layering violation) elif isinstance(val, (list, tuple)): return sum( - _tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) + _tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) for n in val if isinstance(n, torch.Tensor) ) elif isinstance(val, torch.Tensor): - return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) + return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") + if node.op == "get_attr": + return 0 raise RuntimeError("We should always have `val` metadata on the nodes") @@ -531,25 +536,22 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: for idx, node in enumerate(gm.graph.nodes): order[node] = idx - # Populate depth for the nodes. Depth is the distance from the inputs. - depths = {} - output_node = next(iter(gm.graph.find_nodes(op="output"))) - for node in gm.graph.nodes: - if node.op == "placeholder": - depths[node] = 0 - else: - depths[node] = max([depths[arg] for arg in node.all_input_nodes], default=0) - def insert_node_in_graph(node): - if node in env: - return env[node] + cur_nodes = [node] + insertable_nodes = set() + while len(cur_nodes) > 0: + node = cur_nodes.pop() + if node in insertable_nodes or node in env: + continue + insertable_nodes.add(node) - # Bias traversal towards the nodes that have higher depth - prioritizes - # critical path first. - for arg, _ in sort_depths(node.all_input_nodes, depths): - env[arg] = insert_node_in_graph(arg) - env[node] = new_graph.node_copy(node, lambda x: env[x]) - return env[node] + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. 
+ cur_nodes += node.all_input_nodes + + insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) + for node in insertable_nodes: + env[node] = new_graph.node_copy(node, lambda x: env[x]) # Find first bwd node in the graph tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) @@ -749,7 +751,7 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: return joint_module -def get_saved_values( +def solve_min_cut( joint_graph: fx.Graph, node_info: NodeInfo, min_cut_options: MinCutOptions, @@ -876,7 +878,6 @@ def ban_recomputation_if_allowed(node): return False if node in dont_ban: return False - # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1267,9 +1268,197 @@ def get_name_to_node(graph: fx.Graph): return name_to_node +def greedy_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: List[float], runtimes: List[float], max_memory: float +) -> Tuple[float, List[int], List[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = int(round(max_memory * S)) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = 
runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] + j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items + + +def _optimize_runtime_with_given_memory( + memory: List[float], + runtimes: List[float], + max_memory: float, +) -> Tuple[float, List[int], List[int]]: + SOLVER = config.activation_memory_budget_solver + if SOLVER == "greedy": + return greedy_knapsack(memory, runtimes, max_memory) + elif SOLVER == "ilp": + return ilp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp": + return dp_knapsack(memory, runtimes, max_memory) + else: + raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") + + +from torch.utils._mode_utils import no_dispatch + + +def estimate_runtime(node): + RUNTIME_MODE = config.activation_memory_budget_runtime_estimator + + def materialize_arg(x): + if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): + shape = list(x.meta["val"].shape) + + def realize_symbol(d): + return hint_int(d, fallback=4096) + + shape = [realize_symbol(s) for s in shape] + return x.meta["val"].new_zeros(shape) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): + return hint_int(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): + return 1.0 + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): + return True + else: + return x + + if RUNTIME_MODE == "testing": + return 1 + + elif RUNTIME_MODE == "profile": + from triton.testing import do_bench + + with no_dispatch(): + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + ms = do_bench(lambda: node.target(*args, **kwargs)) + return ms + + elif RUNTIME_MODE == "flops": + # todo(chilli): Normalize this to also return ms + from torch.utils.flop_counter import FlopCounterMode + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + with FlopCounterMode(display=False) as mode: + node.target(*args, **kwargs) + counted_flops = mode.get_total_flops() + return max(counted_flops, 1) + else: + raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") + + def choose_saved_values_set( joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 ) -> List[fx.Node]: + if memory_budget > 1 or memory_budget < 0: + raise RuntimeError( + f"The valid ranges for memory budget are 0 <= m <= 1. 
The provided value is {memory_budget}" + ) min_cut_options = MinCutOptions( ban_if_used_far_apart=config.ban_recompute_used_far_apart, ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, @@ -1286,16 +1475,164 @@ def choose_saved_values_set( ban_if_materialized_backward=False, ban_if_not_in_allowlist=False, ) - if memory_budget == 0: return node_info.inputs - runtime_optimized_saved_values, _ = get_saved_values( + runtime_optimized_saved_values, _ = solve_min_cut( joint_graph, node_info, min_cut_options, ) - return runtime_optimized_saved_values + # return runtime_optimized_saved_values + if memory_budget == 1: + return runtime_optimized_saved_values + + def estimate_activations_size(saved_values: List[fx.Node]) -> float: + return sum([_size_of(i) for i in saved_values]) / 1e9 + + min_act_size = estimate_activations_size(node_info.inputs) + max_act_size = estimate_activations_size(runtime_optimized_saved_values) + # The optimized choice is smaller than the inputs anyways + if max_act_size <= min_act_size: + return runtime_optimized_saved_values + + def get_normalized_size(sz): + return (sz / 1e9) / (max_act_size - min_act_size) + + def get_mem_ratio(activations: List[fx.Node]): + return (estimate_activations_size(activations) - min_act_size) / ( + max_act_size - min_act_size + ) + + more_aggressive_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ) + more_aggressive_saved_values, _ = solve_min_cut( + joint_graph, node_info, more_aggressive_options + ) + if get_mem_ratio(more_aggressive_saved_values) < memory_budget: + return more_aggressive_saved_values + + aggressive_options = replace( + more_aggressive_options, + ban_if_not_in_allowlist=False, + ) + aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( + joint_graph, node_info, aggressive_options + ) + + if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: + return aggressive_recomputation_saved_values + + from torch._inductor.fx_utils import get_node_storage + + input_storages = {get_node_storage(node) for node in node_info.inputs} + + def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: + return [ + i + for i in banned_nodes + if ( + # Only allow recomputing nodes that are actually required for BW + i.dist_from_bw < int(1e9) # type: ignore[attr-defined] + and get_node_storage(i) not in input_storages + ) + ] + + recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) + + # default: runtime_optimized_saved_values + # more aggressive: more_aggressive_saved_values + # full aggressive: aggressive_recomputation_saved_values + + all_recomputable_banned_nodes = sorted( + recomputable_banned_nodes, key=_size_of, reverse=True + ) + if len(all_recomputable_banned_nodes) == 0: + return node_info.inputs + memories_banned_nodes = [ + get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes + ] + runtimes_banned_nodes = [ + estimate_runtime(node) for node in all_recomputable_banned_nodes + ] + from torch.utils._mode_utils import no_dispatch + + def get_saved_values_knapsack(memory_budget): + with no_dispatch(): + ( + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + ) = _optimize_runtime_with_given_memory( + memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) + ) + dont_ban = set() + for idx in recomputable_node_idxs: + dont_ban.add(all_recomputable_banned_nodes[idx]) + assert 
dont_ban.issubset(all_recomputable_banned_nodes) + + saved_values, _ = solve_min_cut( + joint_graph, + node_info, + aggressive_options, + dont_ban, + ) + return saved_values, expected_runtime + + if config.visualize_memory_budget_pareto: + options = [] + for sweep_memory_budget in range(100, -1, -5): + saved_values, expected_runtime = get_saved_values_knapsack( + sweep_memory_budget / 100 + ) + options.append( + ( + sweep_memory_budget, + sum(runtimes_banned_nodes) - expected_runtime, + get_mem_ratio(saved_values), + ) + ) + + import matplotlib.pyplot as plt + + x_values = [item[2] for item in options] + y_values = [item[1] for item in options] + + # Plotting the values with updated axis labels and chart title + plt.figure(figsize=(10, 6)) + plt.plot(x_values, y_values, marker="o") + + # Adding labels for each point + for i, txt in enumerate(x_values): + plt.annotate( + f"{txt:.2f}", + (x_values[i], y_values[i]), + textcoords="offset points", + xytext=(0, 10), + ha="center", + ) + + plt.xlabel("Memory Budget") + plt.ylabel("Runtime of Recomputed Components") + plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") + plt.grid(True) + fig = plt.gcf() + plt.show() + fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" + fig.savefig(fig_name) + log.warning("Generated Pareto frontier curve at %s", fig_name) + + # todo(chilli): Estimated doesn't align exactly with actual - actual is + # usually less memory than estimated. i'm guessing (actually quite + # unsure about this) that's because estimated is just only including + # tensors we actually banned from recompute, but there may be other + # tensors that we choose to save. + + return get_saved_values_knapsack(memory_budget=memory_budget)[0] def min_cut_rematerialization_partition( @@ -1411,7 +1748,15 @@ def classify_nodes(joint_module): for user in node.users: node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) + memory_budget = config.activation_memory_budget + for node in joint_graph.nodes: + if isinstance(node.meta.get("memory_budget", None), float): + memory_budget = node.meta["memory_budget"] + break + # print("Memory Budget: ", memory_budget) + saved_values = choose_saved_values_set( + joint_graph, node_info, memory_budget=memory_budget + ) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py index 5a78facf08c0..fb2aae84c0b9 100644 --- a/torch/_functorch/pyfunctorch.py +++ b/torch/_functorch/pyfunctorch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from abc import ABC, abstractmethod from typing import Any, List, Tuple diff --git a/torch/_functorch/utils.py b/torch/_functorch/utils.py index 303ebbc45d63..5e88b8462c5f 100644 --- a/torch/_functorch/utils.py +++ b/torch/_functorch/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Tuple, Union diff --git a/torch/_guards.py b/torch/_guards.py index 4dccd4aa84e6..92041700f0b0 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index e0e22eb4202f..0d88aa0db2c6 100644 --- 
a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools from typing import Callable, List @@ -76,7 +77,7 @@ def add(x: torch.Tensor, y: torch.Tensor): assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" - if not torch.compiler.is_compiling(): + if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( combine_fn, input, dim diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 89263bd65e7a..189f746b77a0 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Tuple, Union import torch diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 359feb192ae5..f4fe64d67f0b 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index f76596a3c6f3..a8da01fe06ec 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import Any, Dict, Optional, Tuple diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index f4586a0a57b0..c2efa3b48b7f 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, Tuple, Union import torch diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 2bf88ea19565..f5bf1d43c19f 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree from torch._C import DispatchKey diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index f675519ee182..a3f5e2115aee 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.utils._pytree as pytree diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index 81c20bc3462b..d781248a19c9 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._subclasses.functional_tensor diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index 235dfe6ec416..744e559e65d0 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from contextlib import contextmanager diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index a99afaaa9547..5552ef1ff8b2 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import inspect import logging 
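The knapsack helpers added to torch/_functorch/partitioners.py earlier in this diff can also be exercised on their own. A small illustrative call, assuming they remain importable module-level functions as in the hunk; the memory and runtime values below are made up:

from torch._functorch.partitioners import dp_knapsack, greedy_knapsack

# Normalized activation sizes and estimated recompute runtimes per candidate node.
memory = [0.4, 0.3, 0.2, 0.1]
runtimes = [5.0, 1.0, 3.0, 2.0]
budget = 0.5

# Each solver returns (total runtime of the saved items, indices to save,
# indices to allow recomputing).
dp_total, dp_saved, dp_recompute = dp_knapsack(memory, runtimes, budget)
greedy_total, greedy_saved, greedy_recompute = greedy_knapsack(memory, runtimes, budget)
print(dp_total, dp_saved, dp_recompute)
print(greedy_total, greedy_saved, greedy_recompute)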
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 0fcf22bcc338..f4b393e7c234 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import contextmanager from dataclasses import dataclass @@ -95,13 +96,21 @@ def wrapped(*args): @contextmanager def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag + _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules try: # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False + + # TODO(anijain2305, export-team) For non-strict export with module + # stack info, the codepatch forces the nn module __getattr__ to + # ProxyAttr __getattr__ downstream. To circumvent the issue for now, + # skip inlining inbuilt nn modules for cond. + torch._dynamo.config.inline_inbuilt_nn_modules = False yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing + torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index b0ab00bdfac4..4577036b731f 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Tuple, Union import torch diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index f288c350f0ee..6d83a44e752a 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import logging diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 0d7cd8cece49..9d9445c5de3f 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional, Tuple import torch.fx diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py new file mode 100644 index 000000000000..496a7a5ad841 --- /dev/null +++ b/torch/_inductor/async_compile.py @@ -0,0 +1,259 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from functools import partial +from time import time +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._inductor import config +from torch._inductor.codecache import ( + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + TritonCodeCache, + TritonFuture, +) +from torch._inductor.compile_worker.subproc_pool import ( + _warm_process_pool, + AnyPool, + SubprocPool, +) +from torch._inductor.compile_worker.watchdog import _async_compile_initializer + +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) + +from torch.hub import _Faketqdm, tqdm + +if TYPE_CHECKING: + from torch._inductor.runtime.hints import HalideMeta + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: 
Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ModuleNotFoundError: + # Might not be installed. + pass + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[AnyPool] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exist on windows + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert config.compile_threads > 1 + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor that forks in a new process we control + pool = SubprocPool(config.compile_threads) + else: + pre_fork_setup() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers...
+ multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + _warm_process_pool(cls.process_pool(), config.compile_threads) + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + _set_triton_ptxas_path() + + kernel = TritonCodeCache.load(kernel_name, source_code) + if config.compile_threads > 1: + return TritonFuture( + kernel, + self.process_pool().submit( + _worker_compile_triton, + kernel._reload_in_subproc, + ), + ) + else: + kernel.precompile() + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: List[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if config.compile_threads <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if config.compile_threads <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, CodeCacheFuture)): + scope[key] = result.result() + pbar.update(1) + + _compile_end() + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" +): + pass +else: + AsyncCompile.warm_pool() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 7db4d2a3291c..71171b3a4c32 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib @@ -11,7 +12,6 @@ import warnings from concurrent.futures import 
ThreadPoolExecutor from ctypes import byref, c_size_t, c_void_p, CDLL -from types import ModuleType from typing import ( Any, Callable, @@ -25,6 +25,7 @@ ) import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch import multiprocessing from torch._dynamo.testing import rand_strided @@ -40,6 +41,7 @@ if TYPE_CHECKING: from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue + from types import ModuleType from torch._inductor.select_algorithm import TritonTemplateCaller diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 4640ec4dce6b..8c62ef2ba3c9 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,3 +1,5 @@ +# mypy: allow-untyped-defs +import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -11,6 +13,9 @@ from .virtualized import V +log = logging.getLogger(__name__) + + class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -55,6 +60,7 @@ def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c2e6b4f0d95d..574511d004a4 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import base64 @@ -9,7 +10,6 @@ import io import json import logging -import multiprocessing import os import pickle import pkgutil @@ -26,7 +26,6 @@ import threading import warnings from bisect import bisect_right -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from copy import copy from ctypes import c_void_p, cdll, CDLL from functools import partial @@ -49,26 +48,16 @@ ) import torch -from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env -from torch._inductor.compile_worker.subproc_pool import ( - _warm_process_pool, - AnyPool, - SubprocPool, -) -from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import ( _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, - _set_triton_ptxas_path, - _worker_compile_triton, ) -from torch._inductor.runtime.hints import HalideMeta from torch._inductor.runtime.runtime_utils import cache_dir -from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux +from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux from torch._logging import trace_structured from torch._subclasses.fake_tensor import ( @@ -79,15 +68,19 @@ from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv if TYPE_CHECKING: + from concurrent.futures import Future + from torch._inductor.graph import GraphLowering from torch._inductor.ir import ChoiceCaller + from torch._inductor.runtime.hints import HalideMeta -from torch.hub import _Faketqdm, tqdm _HERE = os.path.abspath(__file__) _TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) _LINKER_SCRIPT = os.path.join(_TORCH_PATH, 
"_inductor/script.ld") +_IS_WINDOWS = sys.platform == "win32" + if config.is_fbcode(): from triton.fb import build_paths from triton.fb.build import _run_build_command @@ -114,31 +107,11 @@ def use_global_cache() -> bool: output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") -kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") LOCK_TIMEOUT = 600 _IS_WINDOWS = sys.platform == "win32" -# timing metrics for time spent in the compilation -_cumulative_compile_time = 0.0 -_t0: Optional[float] = None - - -def _compile_start() -> None: - global _t0 - if _t0 is None: - _t0 = time() - - -def _compile_end() -> None: - global _cumulative_compile_time, _t0 - if _t0 is not None: - t1 = time() - _cumulative_compile_time += t1 - _t0 - _t0 = None - # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) - log = logging.getLogger(__name__) @@ -258,7 +231,7 @@ def set_value(self, *keys: str, value: Any) -> None: class PersistentCache(CacheBase): - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def get_global_cache(self): global_cache_path = self.get_global_cache_path() if global_cache_path is None or not global_cache_path.is_file(): @@ -446,11 +419,22 @@ def _ident(x: Any) -> Any: return x +def extract_tensor_metadata_for_cache_key(t): + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + return meta + + def _reduce_fake_tensor(t): """ See FxGraphCachePickler. Custom reducer to pickle FakeTensors. """ - metadata = extract_tensor_metadata(t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (metadata,)) @@ -481,7 +465,7 @@ def _reduce_tensor(t): f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue." 
) - metadata = extract_tensor_metadata(t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (TensorMetadataAndValues(metadata, values),)) @@ -554,7 +538,7 @@ def debug_str(cls, inp: Any) -> str: def get_str(obj) -> str: if isinstance(obj, torch.Tensor): - return str(extract_tensor_metadata(obj)) + return str(extract_tensor_metadata_for_cache_key(obj)) elif isinstance(obj, bytes): return "" else: @@ -576,17 +560,29 @@ def get_str(obj) -> str: return "\n".join(lines) -def get_code_hash(roots): - contents: Dict[str, bytes] = {torch.__version__: b""} - for lib in pkgutil.iter_modules(roots): +def build_code_hash(roots, prefix, hasher): + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): spec = lib.module_finder.find_spec(lib.name, None) assert spec is not None module = spec.origin assert module is not None with open(module, "rb") as f: - contents[module] = f.read() + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) - return hashlib.sha256(pickle.dumps(contents)).digest() + +def get_code_hash(roots, extra_files=()): + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash(roots, "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() @functools.lru_cache(None) @@ -596,7 +592,15 @@ def torch_key(): """ if not config.is_fbcode(): inductor_root = os.path.dirname(__file__) - return get_code_hash([inductor_root]) + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "codegen/aoti_runtime/implementation.cpp", + "codegen/cpp_prefix.h", + "script.ld", + ) + return get_code_hash( + [inductor_root], [os.path.join(inductor_root, x) for x in extra_files] + ) from libfb.py import parutil @@ -639,6 +643,7 @@ def __init__( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], ): self.gm = gm self.example_inputs = example_inputs @@ -654,6 +659,9 @@ def __init__( else: self.fx_kwargs[k] = fx_kwargs[k] + # Alignment checks + self.inputs_to_check = inputs_to_check + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. self.deterministic_algorithms_settings = ( torch.are_deterministic_algorithms_enabled(), @@ -686,11 +694,12 @@ def compiled_fx_graph_hash( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], ) -> str: """ Generate a unique hash of the FX graph for caching. """ - details = FxGraphHashDetails(gm, example_inputs, fx_kwargs) + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) # The prefix distinguishes among the other kinds of objects we # cache in this module. 
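build_code_hash/get_code_hash above switch from pickling a dict of file contents to streaming module names and source bytes into a single sha256 hasher, recursing into subpackages and sorting for determinism. A simplified, non-recursive sketch of the same idea, hashing the stdlib json package purely as a demo:

import hashlib
import json
import pkgutil

def hash_package(roots, prefix=""):
    hasher = hashlib.sha256()
    # Sort so the digest does not depend on filesystem iteration order.
    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda m: m.name):
        spec = lib.module_finder.find_spec(lib.name, None)
        if spec is None or spec.origin is None:
            continue
        hasher.update(spec.name.encode("utf-8"))
        with open(spec.origin, "rb") as f:
            hasher.update(f.read())
    return hasher.hexdigest()

print(hash_package(list(json.__path__), "json."))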
key = "f" + FxGraphCachePickler.get_hash(details) @@ -919,7 +928,10 @@ def _save_graph( shape_env = FxGraphCache._get_shape_env() assert shape_env is not None symints = FxGraphCache._filter_backed_symints(example_inputs) - disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints) + guards = shape_env.get_pruned_guards(symints) + disk_compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) try: content = pickle.dumps(disk_compiled_graph) @@ -990,6 +1002,7 @@ def load( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], local: bool, remote: bool, ): @@ -1001,22 +1014,22 @@ def load( compiled_graph = None try: FxGraphCache._check_can_cache(gm) - key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs) + key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs, inputs_to_check) remote_cache = None if remote: cache_id = "fx-graph-v1" try: - import triton - if config.is_fbcode(): - remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend( - cache_id + from triton.fb.fb_memcache import ( + FbMemcacheRemoteFxGraphCacheBackend, ) + + remote_cache = FbMemcacheRemoteFxGraphCacheBackend(cache_id) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend( - cache_id - ) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(cache_id) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) @@ -1237,7 +1250,7 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: class VecISA: _bit_width: int - _macro: str + _macro: List[str] _arch_flags: str _dtype_nelements: Dict[torch.dtype, int] @@ -1283,7 +1296,7 @@ def bit_width(self) -> int: def nelements(self, dtype: torch.dtype = torch.float) -> int: return self._dtype_nelements[dtype] - def build_macro(self) -> str: + def build_macro(self) -> List[str]: return self._macro def build_arch_flags(self) -> str: @@ -1292,8 +1305,10 @@ def build_arch_flags(self) -> str: def __hash__(self) -> int: return hash(str(self)) - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def __bool__(self) -> bool: + from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions + if config.cpp.vec_isa_ok is not None: return config.cpp.vec_isa_ok @@ -1310,16 +1325,21 @@ def __bool__(self) -> bool: lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: - output_path = input_path[:-3] + "so" - build_cmd = shlex.split( - cpp_compile_command( - input_path, output_path, warning_all=False, vec_isa=self - ) + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, ) try: # Check if the output file exist, and compile when not. 
+ output_path = x86_isa_help_builder.get_target_file_path() if not os.path.isfile(output_path): - compile_file(input_path, output_path, build_cmd) + status, target_file = x86_isa_help_builder.build() + if status: + return False # Check build result subprocess.check_call( @@ -1340,7 +1360,9 @@ def __bool__(self) -> bool: @dataclasses.dataclass class VecNEON(VecISA): _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h - _macro = "-DCPU_CAPABILITY_NEON" + _macro = ["CPU_CAPABILITY_NEON"] + if sys.platform == "darwin" and platform.processor() == "arm": + _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF") _arch_flags = "" # Unused _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} @@ -1353,8 +1375,12 @@ def __str__(self) -> str: @dataclasses.dataclass class VecAVX512(VecISA): _bit_width = 512 - _macro = "-DCPU_CAPABILITY_AVX512" - _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + _macro = ["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} def __str__(self) -> str: @@ -1366,8 +1392,10 @@ def __str__(self) -> str: @dataclasses.dataclass class VecAVX2(VecISA): _bit_width = 256 - _macro = "-DCPU_CAPABILITY_AVX2" - _arch_flags = "-mavx2 -mfma" + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} def __str__(self) -> str: @@ -1379,7 +1407,11 @@ def __str__(self) -> str: @dataclasses.dataclass class VecZVECTOR(VecISA): _bit_width = 256 - _macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION" + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] _arch_flags = "-mvx -mzvector" _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} @@ -1391,7 +1423,7 @@ def __str__(self) -> str: class InvalidVecISA(VecISA): _bit_width = 0 - _macro = "" + _macro = [""] _arch_flags = "" _dtype_nelements = {} @@ -1404,6 +1436,31 @@ def __bool__(self) -> bool: # type: ignore[override] __hash__: Callable[[VecISA], Any] = VecISA.__hash__ +def x86_isa_checker() -> List[str]: + supported_isa: List[str] = [] + + def _check_and_append_supported_isa( + dest: List[str], isa_supported: bool, isa_name: str + ): + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. 
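The ISA classes above now keep _macro as a list of bare names and switch _arch_flags per platform; presumably this lets the builder render preprocessor defines with whichever syntax the toolchain expects instead of baking "-D" into the class. A small illustrative sketch (the /D rendering for MSVC is an assumption, not taken from cpp_builder):

def render_macros(macros, msvc=False):
    # gcc/clang spell defines "-D NAME"; MSVC conventionally uses "/D NAME".
    flag = "/D" if msvc else "-D"
    return " ".join(f"{flag} {name}" for name in macros)

print(render_macros(["CPU_CAPABILITY_AVX2"]))             # -D CPU_CAPABILITY_AVX2
print(render_macros(["CPU_CAPABILITY_AVX2"], msvc=True))  # /D CPU_CAPABILITY_AVX2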
+ """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_cpu_support_avx2() + avx512 = torch.cpu._is_cpu_support_avx512() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + + return supported_isa + + invalid_vec_isa = InvalidVecISA() supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()] @@ -1416,7 +1473,8 @@ def valid_vec_isa_list() -> List[VecISA]: if sys.platform == "darwin" and platform.processor() == "arm": return [VecNEON()] - if sys.platform != "linux": + cur_os = sys.platform + if cur_os != "linux" and cur_os != "win32": return [] if platform.machine() == "s390x": @@ -1434,12 +1492,11 @@ def valid_vec_isa_list() -> List[VecISA]: return [] isa_list = [] - with open("/proc/cpuinfo") as _cpu_info: - _cpu_info_content = _cpu_info.read() - for isa in supported_vec_isa_list: - if str(isa) in _cpu_info_content and isa: - isa_list.append(isa) - return isa_list + _cpu_supported_isa = x86_isa_checker() + for isa in supported_vec_isa_list: + if str(isa) in _cpu_supported_isa and isa: + isa_list.append(isa) + return isa_list def pick_vec_isa() -> VecISA: @@ -1494,7 +1551,7 @@ def cpp_flags() -> str: def cpp_wrapper_flags() -> str: - return "-DTORCH_INDUCTOR_CPP_WRAPPER" + return "-D TORCH_INDUCTOR_CPP_WRAPPER" def optimization_flags() -> str: @@ -1636,7 +1693,14 @@ def get_include_and_linking_paths( _set_gpu_runtime_env() from torch.utils import cpp_extension - macros = vec_isa.build_macro() if vec_isa != invalid_vec_isa else "" + # Remove below in the further + # macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else "" + macros = "" + if vec_isa != invalid_vec_isa: + for x in vec_isa.build_macro(): + macros_def = f"-D {x} " + macros += macros_def + build_arch_flags = "" if sys.platform == "linux" and ( include_pytorch @@ -1748,6 +1812,11 @@ def get_include_and_linking_paths( else: libs = ["omp"] if config.is_fbcode() else ["gomp"] + # For AOT mode, the produced library relies on torch cpu to set grad mode + # like aoti_torch_grad_mode_set_enabled + if aot_mode and sys.platform == "linux" and not config.is_fbcode(): + libs += ["torch", "torch_cpu"] + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 if not config.abi_compatible: libs += ["c10"] @@ -1848,7 +1917,7 @@ def cpp_compile_command( {get_glibcxx_abi_build_flags()} {ipaths_str} {lpaths} {libs} {build_arch_flags} {macros} {linker_paths} {clang_flags} - {optimization_flags()} + {optimization_flags()} {cpp_wrapper_flags()} {use_custom_generated_macros()} {use_fb_internal_macros()} {use_standard_sys_dir_headers()} @@ -1991,10 +2060,14 @@ def _compile_consts_linux(consts: bytes) -> str: # as read-only (i.e. .lrodata) which could accomodate larger size of data # to be linked. 
rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" + + assert ( + ALIGN_BYTES & (ALIGN_BYTES - 1) + ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64" cmd = ( f"{objcopy_command} --rename-section" f"{rename_data}" - " --set-section-alignment .data=64" # following the gAlignment of CPU in c10/core/alignment.h + f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h f" {consts_o} {consts_o}" ) log.debug("aot constant rename section command: %s", cmd) @@ -2118,7 +2191,14 @@ def _compile_consts_darwin(consts: bytes) -> str: else: run_command_and_check(compile_cmd) - def _to_bytes(t: torch.Tensor) -> bytes: + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def _pad_to_alignment(raw_bytes): + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + # This serializes the tensor's untyped_storage to bytes by accessing # the raw data of the underlying structure. import ctypes @@ -2127,22 +2207,27 @@ def _to_bytes(t: torch.Tensor) -> bytes: return b"" if t.is_mkldnn: - raw_array = ctypes.cast( - torch.ops.mkldnn.data_ptr(t), - ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)), - ) - return bytes(raw_array.contents) + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() - t_cpu = t.untyped_storage().cpu() raw_array = ctypes.cast( - t_cpu.data_ptr(), - ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()), + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), ) + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) - return bytes(raw_array.contents) - + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) serialized_weights = b"".join( - _to_bytes(graph.get_original_value_of_constant(name)) + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) for name in graph.constants.keys() if name not in graph.folded_constants ) @@ -2353,8 +2438,21 @@ def load_async(cls, source_code: str, cuda=False, submit_fn=None, extra_flags=() "vec_isa": pick_vec_isa(), "extra_flags": extra_flags, } - cpp_command = repr(cpp_compile_command("i", "o", **compile_command)) - key, input_path = write(source_code, "cpp", extra=cpp_command) + + _set_gpu_runtime_env() # cpp_extension consults the env + + from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions + + dummy_builder = CppBuilder( + name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. 
+ dummy_cmd = repr(dummy_builder.get_command_line()) + key, input_path = write(source_code, "cpp", extra=dummy_cmd) if key not in cls.cache: from filelock import FileLock @@ -2569,7 +2667,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache): cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache_clear = staticmethod(cache.clear) cpp_compile_command_flags = { - "include_pytorch": not config.abi_compatible, + "include_pytorch": True, "shared": True, } entry_function = "inductor_entry_cpp" @@ -2627,6 +2725,101 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache): ) +# TODO: Will remove the temp code after switch to new cpp_builder +def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]): + new_diff: List[str] = [x for x in new_cmd if x not in old_cmd] + old_diff: List[str] = [y for y in old_cmd if y not in new_cmd] + + if new_diff or old_diff: + print("!!! new_cmd: ", new_cmd) + print("!!! old_cmd: ", old_cmd) + print("!!! new_diff: ", new_diff) + print("!!! old_diff: ", old_diff) + raise RuntimeError("Error in new and old command different.") + + +def _do_validate_cpp_commands( + include_pytorch: bool, + cuda: bool, + compile_only: bool, + mmap_weights: bool, + use_absolute_path: bool, +): + # PreCI will failed if test machine can't run cuda. + temp_dir = tempfile.TemporaryDirectory() + test_dir_path = temp_dir.name + test_cuda = torch.cuda.is_available() and cuda + input_path = os.path.join(test_dir_path, "dummy_input.cpp") + output_path = os.path.join(test_dir_path, "dummy_output.so") + extra_flags = ["-D TEST_EXTRA_FLAGS"] + if compile_only: + output_path = os.path.join(test_dir_path, "dummy_output.o") + picked_isa = pick_vec_isa() + + old_cmd = cpp_compile_command( + input=input_path, + output=output_path, + include_pytorch=include_pytorch, + vec_isa=picked_isa, + cuda=test_cuda, + aot_mode=False, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=mmap_weights, + extra_flags=extra_flags, + ).split(" ") + + from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions + + dummy_build_option = CppTorchCudaOptions( + vec_isa=picked_isa, + include_pytorch=include_pytorch, + cuda=test_cuda, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=mmap_weights, + extra_flags=extra_flags, + ) + + dummy_builder = CppBuilder( + name="dummy_output", + sources=input_path, + BuildOption=dummy_build_option, + output_dir=test_dir_path, + ) + new_cmd = dummy_builder.get_command_line().split(" ") + + _temp_validate_new_and_old_command(new_cmd, old_cmd) + + temp_dir.cleanup() + + +# TODO: Will remove the temp code after switch to new cpp_builder +# It could help on sync new cpp_builder generate same command line as the old one. +def validate_new_cpp_commands(): + cuda = [True, False] + use_mmap_weights = [True, False] + compile_only = [True, False] + include_pytorch = [True, False] + use_absolute_path = [True, False] + + for x in cuda: + for y in use_mmap_weights: + for z in compile_only: + for m in include_pytorch: + for n in use_absolute_path: + print( + f"!!! 
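_temp_validate_new_and_old_command above compares the legacy cpp_compile_command output against the new CppBuilder command token by token; the check reduces to a two-sided difference, sketched here standalone:

def diff_commands(new_cmd, old_cmd):
    # Tokens present in only one of the two command lines indicate a flag the
    # new builder drops or invents.
    new_only = [tok for tok in new_cmd if tok not in old_cmd]
    old_only = [tok for tok in old_cmd if tok not in new_cmd]
    return new_only, old_only

new_only, old_only = diff_commands(
    "g++ -O3 -shared a.cpp -o a.so".split(),
    "g++ -O3 -fPIC -shared a.cpp -o a.so".split(),
)
assert new_only == [] and old_only == ["-fPIC"]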
cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}, use_absolute_path:{n}" + ) + _do_validate_cpp_commands( + include_pytorch=m, + cuda=x, + mmap_weights=y, + compile_only=z, + use_absolute_path=n, + ) + + @clear_on_fresh_inductor_cache class HalideCodeCache(CppPythonBindingsCodeCache): cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} @@ -3205,12 +3398,6 @@ def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: return (DLLWrapper(dst_file_path), hash_key, source_code_path) -def caching_device_properties(): - for _, device_interface in get_registered_device_interfaces(): - if device_interface.is_available(): - device_interface.Worker.get_device_properties() - - class CodeCacheFuture: def result(self): raise NotImplementedError @@ -3244,171 +3431,3 @@ def __init__(self, result_fn): def result(self): return self.result_fn() - - -# Used to keep track of all process pools invoked so far. -_pool_set: Set[AnyPool] = set() - - -def shutdown_compile_workers() -> None: - """Shut down all outstanding compile-worker pools.""" - for pool in _pool_set: - pool.shutdown() - after_fork() - - -def after_fork(): - """Reset pools to initial state without shutting them down""" - _pool_set.clear() - AsyncCompile.process_pool.cache_clear() - - -try: - os.register_at_fork(after_in_child=after_fork) -except AttributeError: - pass # register_at_fork does not exists on windows - - -class AsyncCompile: - def __init__(self) -> None: - pass - - @staticmethod - @functools.lru_cache(1) - def pool() -> ThreadPoolExecutor: - assert config.compile_threads > 1 - return ThreadPoolExecutor(config.compile_threads) - - @staticmethod - @functools.lru_cache(1) - def process_pool() -> AnyPool: - assert config.compile_threads > 1 - pool: AnyPool - if config.worker_start_method == "subprocess": - # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(config.compile_threads) - else: - # ensure properties have been calculated before processes - # are forked - caching_device_properties() - ctx = multiprocessing.get_context(config.worker_start_method) - pool = ProcessPoolExecutor( - config.compile_threads, - mp_context=ctx, - initializer=partial(_async_compile_initializer, os.getpid()), - ) - # when this pool is created in a subprocess object, the normal exit handler - # doesn't run, and we need to register our own handler. - # exitpriority has to be high, because another one of the finalizers will - # kill the worker thread that sends the shutdown message to the workers... 
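validate_new_cpp_commands above sweeps all 2**5 = 32 boolean flag combinations with nested loops; the same sweep could be expressed with itertools.product. Shown only to make the combination count explicit, not as a suggested change to the PR:

import itertools

combos = list(itertools.product([True, False], repeat=5))
assert len(combos) == 32
for cuda, mmap_weights, compile_only, include_pytorch, use_absolute_path in combos[:2]:
    print(cuda, mmap_weights, compile_only, include_pytorch, use_absolute_path)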
- multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) - - _pool_set.add(pool) - return pool - - @classmethod - def warm_pool(cls) -> None: - if config.compile_threads <= 1: - return - _compile_start() - _warm_process_pool(cls.process_pool(), config.compile_threads) - _compile_end() - - @classmethod - def submit(cls, task: Callable[..., Any]) -> Any: - if config.compile_threads <= 1: - return task() - return cls.pool().submit(task) - - def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): - kernel_code_log.info("Triton Kernel:\n%s", source_code) - _compile_start() - _set_triton_ptxas_path() - - kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: - return TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - ), - ) - else: - kernel.precompile() - return kernel - - def multi_kernel(self, *args, **kwargs) -> Any: - from torch._inductor.codegen.multi_kernel import MultiKernelCall - - # no need to call this in parallel since the sub-kernels are already parallel tasks - return MultiKernelCall(*args, **kwargs) - - def cpp(self, source_code: str): - kernel_code_log.info("CPP Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppCodeCache.load(source_code).kernel - else: - get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) - return LambdaFuture(lambda: get_result().kernel) - - def cpp_pybinding(self, argtypes: List[str], source_code: str): - kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) - if config.compile_threads <= 1: - return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) - else: - get_result = CppPythonBindingsCodeCache.load_pybinding_async( - argtypes, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def cuda(self, source_code, dst_file_ext): - kernel_code_log.info("CUDA Kernel:\n%s", source_code) - - def task(): - return CUDACodeCache.load(source_code, dst_file_ext)[0] - - return self.submit(task) - - def halide(self, meta: HalideMeta, source_code: str): - kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) - if config.compile_threads <= 1: - return HalideCodeCache.generate_halide(meta, source_code) - else: - get_result = HalideCodeCache.generate_halide_async( - meta, source_code, submit_fn=self.submit - ) - return LambdaFuture(get_result) - - def wait(self, scope: Dict[str, Any]) -> None: - num_kernels = len( - [ - value - for key, value in scope.items() - if isinstance(value, (Future, CodeCacheFuture)) - ] - ) - pbar = tqdm( - total=num_kernels, - desc="Inductor Compilation", - disable=config.disable_progress, - delay=0, - ) - if config.compile_threads > 1: - for key, result in scope.items(): - if config.verbose_progress and not isinstance(pbar, _Faketqdm): - pbar.set_postfix_str(key) - if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() - pbar.update(1) - - _compile_end() - - -if ( - os.environ.get("TORCH_TNT_IN_USE", "0") == "1" - or os.environ.get("TORCH_WARM_POOL", "1") != "1" -): - pass -else: - AsyncCompile.warm_pool() diff --git a/torch/_inductor/codegen/aoti_hipify_utils.py b/torch/_inductor/codegen/aoti_hipify_utils.py index a86ef2d29761..9edfe839946d 100644 --- a/torch/_inductor/codegen/aoti_hipify_utils.py +++ b/torch/_inductor/codegen/aoti_hipify_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils.hipify.hipify_python import PYTORCH_MAP, RE_PYTORCH_PREPROCESSOR 
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e471fefabe1a..8ca6dc2b9153 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import functools @@ -340,6 +341,8 @@ def propagate_scheduler_node(cls, node): DataTypePropagation.propagate_loopbody(node._body) +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -369,12 +372,6 @@ def all_in_parens(string): return string return f"({string})" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -384,11 +381,14 @@ def _print_Mul(self, expr): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -399,12 +399,96 @@ def _print_GreaterThan(self, expr): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. 
+ + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + def doprint(self, expr, *, simplify: bool = True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -414,56 +498,72 @@ def _print_ModularIndexing(self, expr): x = f"({x} // {div})" return f"{x} % {mod}" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_Pow(self, expr): - # Pow() confuses triton + def _print_FloatPow(self, expr): base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. 
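_print_Pow in ExprPrinter above expands non-negative integer powers into repeated multiplication so that no floating-point pow ever reaches the generated code. A tiny standalone demonstration on top of sympy's StrPrinter (the subclass name is made up for this sketch):

import sympy
from sympy.printing.str import StrPrinter

class RepeatedMulPrinter(StrPrinter):
    def _print_Pow(self, expr):
        base, exp = expr.args
        assert exp == int(exp) and int(exp) >= 0, "only non-negative integer powers"
        if int(exp) == 0:
            return "1"
        return "*".join([f"({self._print(base)})"] * int(exp))

x = sympy.Symbol("x")
# sympy folds x*x*x into Pow(x, 3); the printer expands it back out.
print(RepeatedMulPrinter().doprint(x * x * x))  # (x)*(x)*(x)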
Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"math.ceil({self._print(expr.args[0])})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -508,7 +608,7 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -647,6 +747,29 @@ def remainder(a, b): ) return ops.where(cond, ops.add(r, b), r) + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) @@ -1637,10 +1760,7 @@ def indirect_indexing( pos = var.bounds & ValueRanges(0, sympy.oo) new_bounds = new_bounds | pos - new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) - # Propagate the mask as mask propagation when using where is not correct - new_var.update_on_args("index_wrap", (var,), {}) - var = new_var + var = self.cse.generate(self.compute, stm, bounds=new_bounds) sympy_var = parent_handler.indirect_indexing(var, size, check) if generate_assert(check): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 92a9c285d2b1..3370001aa429 100644 --- a/torch/_inductor/codegen/cpp.py +++ 
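The 2**53 caveat attached to IntTrueDiv and int_truediv above is real: Python's int/int division is correctly rounded, while dividing after converting both operands to float (which is effectively what the lowering does) can differ once the integers exceed 2**53. A concrete pair where the two disagree:

a = 2**53 + 2
b = 2**53 + 1
print(a / b)                         # 1.0 (int/int is correctly rounded)
print(float(a) / float(b))           # 1.0000000000000002 (float(b) already rounded b down)
print(a / b == float(a) / float(b))  # False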
b/torch/_inductor/codegen/cpp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import functools @@ -275,11 +276,11 @@ def visit_modular_indexing(divisor, modulus): original_index = index - div = sympy.Wild("divisor") + div = sympy.Wild("divisor", integer=True) if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus") + mod = sympy.Wild("modulus", integer=True) if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) @@ -1236,8 +1237,8 @@ def log2(x): return f"{x}.log2()" @staticmethod - def nextafter(x): - return f"{x}.nextafter()" + def nextafter(x, y): + return f"{x}.nextafter({y})" @staticmethod def copysign(a, b): diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 18d6301d57a6..cc8fcb699691 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional import torch diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index c5e989eb2eed..47d6e87e5a70 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Dict, List, Optional, Type diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index aeebd2698aa5..e46465178840 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index f1d4fbaaac33..04bc8f1ec3d9 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index dbe3daf1c45c..9534ff8e5d09 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import math @@ -105,11 +106,54 @@ def _print_floor(self, expr): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 - r = f"std::trunc({self._print(expr.args[0])})" + r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). 
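The separate TruncToInt and FloorToInt printer rules above matter because truncation and flooring only coincide for non-negative values:

import math

print(math.trunc(-2.5), math.floor(-2.5))  # -2 -3: they differ for negatives
print(math.trunc(2.5), math.floor(2.5))    # 2 2: they agree for non-negatives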
+ def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + def _print_Pow(self, expr): # Uses float constants to perform FP div base, exp = expr.args @@ -144,6 +188,11 @@ def _print_ceiling(self, expr): r = f"std::ceil({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: @@ -205,8 +254,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1259418fc09e..65ff4ebf4e69 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import os @@ -9,12 +10,13 @@ from sympy import Expr import torch + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey from .. 
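The PythonMod/CMod split above exists because C++'s % truncates toward zero while Python's % floors, so the two disagree for negative operands; emitting a bare % into C++ is only safe when the inputs are provably non-negative, as the TODO notes. math.fmod follows the C sign convention and makes the difference easy to see:

import math

print(-7 % 3)                   # 2    (Python: floored division, sign follows the divisor)
print(math.fmod(-7, 3))         # -1.0 (C convention: truncated division, sign follows the dividend)
print(7 % 3, math.fmod(7, 3))   # 1 1.0 -- identical when both operands are non-negative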
import config, ir - from ..codecache import CudaKernelParamCache -from ..utils import cache_on_self, sympy_product +from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer @@ -238,8 +240,6 @@ class RAIIPyObject { """ ) - from .memory_planning import ALIGN_BYTES - # Round up to the nearest multiple of ALIGN_BYTES # ALIGN_BYTES must be a power of 2 self.header.splice( @@ -720,6 +720,11 @@ def codegen_model_constructor(self): ), f"input {name=} cannot be symbolic" self.write_input_output_info("inputs_info_", idx, name) + all_cuda = all( + V.graph.get_original_value_of_constant(name).is_cuda + for name in V.graph.constants.keys() + if name not in V.graph.folded_constants + ) for idx, name in enumerate(V.graph.constants.keys()): tensor = V.graph.get_original_value_of_constant(name) assert isinstance(tensor, torch.Tensor) @@ -730,14 +735,19 @@ def codegen_model_constructor(self): self.prefix.writeline( f"constants_info_[{idx}].offset = {tensor.storage_offset()};" ) - if tensor.is_mkldnn: - self.prefix.writeline( - f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};" - ) - else: - self.prefix.writeline( - f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" - ) + + # If constants to serialize contain cpu tensors, we always align data_size it to 64. + # When loading the constants, the valid data will depends on the size + # not the data_size so there won't be correctness issue. + data_size = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};" + ) + from_folded = "true" if name in V.graph.folded_constants else "false" self.prefix.writeline( f"constants_info_[{idx}].from_folded = {from_folded};" @@ -1503,7 +1513,7 @@ def generate_inf_and_nan_checker(self, nodes): for buf in nodes.get_names(): # TODO: Add buf name directly into check_inf_and_nan. self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));" + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" ) def codegen_device(self, device): @@ -1533,15 +1543,15 @@ def codegen_layout(self, layout): else: return LAYOUT_TO_ATEN[layout] - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def codegen_int_array_var( self, int_array: str, writer=None, known_statically=False, graph=None, # for per-graph caching - is_bool=False, ): + # This is used for size/stride declaration # Because the memory planning is done in two passes (see the implementation # of self.generate), the writeline behavior is different in the two passes. 
# As a result, the emitted int array declarations may appear in a later @@ -1552,7 +1562,7 @@ def codegen_int_array_var( writer = self var = f"int_array_{next(self.int_array_id)}" - ctype = "int32_t" if is_bool else "int64_t" + ctype = "int64_t" if var not in self.declared_int_array_vars: self.declared_int_array_vars.add(var) if known_statically: @@ -1561,43 +1571,6 @@ def codegen_int_array_var( writer.writeline(f"const {ctype} {var}[] = {int_array};") return var - @functools.lru_cache(None) - def codegen_var_array( - self, - var_array: str, - writer=None, - known_statically=False, - graph=None, # for per-graph caching - type_hint=None, # ['int64_t', 'tensor', 'bool'] - ): - # Because the memory planning is done in two passes (see the implementation - # of self.generate), the writeline behavior is different in the two passes. - # As a result, the emitted int array declarations may appear in a later - # position of the generated code, so the second pass codegen should not - # reuse int array declarations generated in the first pass - if writer is None: - # The first pass codegen uses `self` as the writer - writer = self - if not type_hint or type_hint in ["bool", "int64_t"]: - return self.codegen_int_array_var( - var_array, - writer, - known_statically, - graph, - is_bool=type_hint == "bool", - ) - - var = f"var_array_{next(self.var_array_id)}" - assert type_hint == "tensor" - ctype = "AtenTensorHandle*" - if var not in self.declared_var_array_vars: - self.declared_var_array_vars.add(var) - if known_statically: - writer.writeline(f"static constexpr {ctype} {var}[] = {var_array};") - else: - writer.writeline(f"const {ctype} {var}[] = {var_array};") - return var - def make_buffer_allocation(self, buffer): return self.make_allocation( buffer.get_name(), @@ -2109,7 +2082,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( def extract_output_name(out): assert out is not None, "None, i.e. 
optional output is not supported" - if isinstance(out, ir.MultiOutput): + if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() elif isinstance(out, (list, tuple)): return type(out)(extract_output_name(o) for o in out) @@ -2190,17 +2163,19 @@ def generate_py_arg_inner(raw_arg, arg_type): return f"PyCapsule_New(reinterpret_cast({raw_arg.codegen_reference()}.get()), NULL, NULL)" elif isinstance(arg_type, torch.IntType): # int - return f"PyInt_FromLong({raw_arg})" + return f"PyLong_FromLongLong({raw_arg})" elif isinstance(arg_type, torch.SymIntType): # SymInt expr = ( raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg ) - return f"PyInt_FromLong({self.expr_printer(expr)})" + return f"PyLong_FromLongLong({self.expr_printer(expr)})" elif isinstance(arg_type, torch.FloatType): return f"PyFloat_FromDouble({raw_arg})" elif isinstance(arg_type, torch.BoolType): - return f"PyBool_FromBool({raw_arg})" + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' else: raise NotImplementedError( f"arg type {arg_type} is not yet supported by custom_op_wrapper" @@ -2334,65 +2309,28 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def val_to_cpp_arg_str(self, val, type_) -> str: - if config.abi_compatible and isinstance(type_, torch.OptionalType): - if val is None: - return "0" # nullptr is not available in C - if not isinstance(type_.getElementType(), torch.TensorType): - var_name = f"var_{next(self.arg_var_id)}" - if isinstance( - type_.getElementType(), - (torch.ListType, torch.TupleType, torch.DeviceObjType), - ): - arg_str = self.val_to_arg_str(val) - if val is None: - return "{arg_str}, 0" - else: - # For datatypes with auxiliary info, we need to hoist out the extra arguments. - # NOTE: This only works if there is one additional argument, though it can easily be generalized. - main_value, aux = arg_str.rsplit(", ") - self.writeline(f"auto {var_name} = {main_value};") - return f"&{var_name}, {aux}" - else: - self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};") - return f"&{var_name}" - elif config.c_shim_version == "2": - # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim - base_handle = self.val_to_arg_str(val) - if "wrap_with_raii_handle_if_needed" in base_handle: - # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to - # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. 
- tmp_var_name = f"var_{next(self.arg_var_id)}" - self.writeline( - f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" - ) - base_handle = tmp_var_name - var_name = f"var_{next(self.arg_var_id)}" - self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") - return f"&{var_name}" - - return self.val_to_arg_str(val, type_) + def c_type_for_prim_type(self, type_) -> str: + assert ( + config.abi_compatible + ), "c_type_for_prim_type is only used in ABI compatible mode" + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("ScalarType", "Layout"): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") - def val_to_arg_str(self, val, type_=None) -> str: - if val is None: - # When None is passed as an argument, it represents an optional that does not contain a value. - if config.abi_compatible: - if type_ is None or isinstance(type_, torch.OptionalType): - return "0" # nullptr is not available in C - elif isinstance(type_, torch.TensorType): - var_name = f"var_{next(self.arg_var_id)}" - self.writeline(f"AtenTensorHandle {var_name}_handle;") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" - ) - self.writeline( - f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" - ) - return var_name - else: - raise AssertionError("Can not map None to a known data type") - return "c10::nullopt" - elif isinstance(val, bool): + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. + if isinstance(val, bool): if config.abi_compatible: return "1" if val else "0" else: @@ -2416,34 +2354,104 @@ def val_to_arg_str(self, val, type_=None) -> str: else: return "-std::numeric_limits::infinity()" elif isinstance(val, (list, tuple)): - # FIXME handle embedded optional types? - result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}" + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. 
It either represent nullopt or an empty tensor if config.abi_compatible: - assert len(val) > 0, "Empty array is not supported in C" - static = self.is_statically_known_list_of_ints(val) - type_hint = "bool" if isinstance(val[0], bool) else "int64_t" - if ( - type_ is not None - and isinstance(type_, torch._C.ListType) - and isinstance(type_.getElementType(), torch._C.OptionalType) - and isinstance( - type_.getElementType().getElementType(), torch._C.TensorType + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.ListType, + torch.TupleType, + torch.DeviceObjType, + ), + ): + return "0, 0" + else: + return "0" # nullptr is not available in C + elif isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" ) - ): - type_hint = "tensor" - tmp_arg_list = "" - for x in val: - tmp_arg_list += f"&{x}_handle, " - result = f"{{{tmp_arg_list}}}" - # Need to pass the array length because we can't use std::vector - var_array = self.codegen_var_array( - result, - known_statically=static, - graph=self.get_codegened_graph(), - type_hint=type_hint, + self.writeline( + f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" + ) + return var_name + else: + raise AssertionError("Can not map None to a known data type") + else: + if isinstance(type_, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"at::Tensor {var_name} = at::Tensor();") + return var_name + else: + return "std::nullopt" + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + if config.abi_compatible: + if not isinstance(element_type, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance( + element_type, + (torch.ListType, torch.TupleType, torch.DeviceObjType), + ): + # type_ is something like Optional[List] or Optional[Device] + arg_str = self.val_to_arg_str(val, element_type) + # For datatypes with auxiliary info, we need to hoist out the extra arguments. + # NOTE: This only works if there is one additional argument, though it can easily be generalized. + main_value, aux = arg_str.rsplit(", ") + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + else: + self.writeline( + f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + ) + return f"&{var_name}" + elif config.c_shim_version == "2": + # type_ is Optional[Tensor] + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val, element_type) + if "wrap_with_raii_handle_if_needed" in base_handle: + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. 
+ tmp_var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" + ) + base_handle = tmp_var_name + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"AtenTensorHandle {var_name} = {base_handle}.get();" + ) + return f"&{var_name}" + else: + return self.val_to_arg_str(val, element_type) + + elif isinstance(type_, torch.ListType): + assert isinstance( + val, (list, tuple) + ), f"{val} does not match with arg type {type_}" + element_type = type_.getElementType() + if config.abi_compatible: + assert len(val) > 0, "Empty array is not supported in C" + var_name = f"var_array_{next(self.var_array_id)}" + result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + self.writeline( + f"const {self.c_type_for_prim_type(element_type)} {var_name}[] = {result};" ) - return f"{var_array}, {len(val)}" + # Need to pass the array length because we can't use std::vector + return f"{var_name}, {len(val)}" else: - return result - else: - return repr(val) + return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + + return self.val_to_arg_str_for_prim_type(val, type_) diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py index e77277d75621..ad8c8eafbbd1 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import os from itertools import chain, count @@ -76,7 +77,7 @@ def generate(self, is_inference): self.prefix.writeline("\n") return super().generate(is_inference) - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def generate_load_kernel_once( self, name: str, diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 5c91736e9abd..0b91219d8f03 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import cast, Sequence diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 8cad41082d64..12b7b21de61e 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 871c8b388494..24a02efe3805 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging diff --git a/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py index d8bf408dc28a..11258382ad21 100644 --- a/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py +++ b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List from unittest.mock import patch diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index 2a386a114e86..4ee8af3949ae 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ 
b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..cutlass_utils import try_import_cutlass if try_import_cutlass(): diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 789a2e44152c..04866fe4deb1 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import os diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 93a8c08b6a0f..7ff99b871c82 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..common import DeviceOpOverrides, register_device_op_overrides diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 89c326cef546..3a7dccf7442b 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import enum import logging diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index f7be73c247fd..0b5b9d795202 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Sequence, Union from ..scheduler import ( diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 2aade2a297df..435bd2d895ce 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections @@ -10,7 +11,7 @@ import torch from .. 
import config, ir -from ..utils import cache_on_self, CachedMethod, IndentedBuffer +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer from ..virtualized import V from .wrapper import ( @@ -22,36 +23,6 @@ ) -ALIGN_BYTES = 64 -assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" - - -def _align(nbytes): - """Round up to the nearest multiple of ALIGN_BYTES""" - return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES - - -def _is_aligned(v: sympy.Expr): - """v can be statically proven to be a multiple of ALIGN_BYTES""" - if isinstance(v, (sympy.Add, sympy.Max)): - return all(map(_is_aligned, v.args)) - return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES - - -class align(sympy.Function): - """Symbolically round up to the nearest multiple of ALIGN_BYTES""" - - nargs = (1,) - is_integer = True - - @classmethod - def eval(cls, value): - if isinstance(value, (int, sympy.Integer)): - return _align(int(value)) - if _is_aligned(value): - return value - - @dataclasses.dataclass class LiveRange: """ diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 8b4dbb179016..84279191ceac 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import os from typing import Any, List diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 23c602c10e5d..2063a183385b 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections @@ -20,7 +21,6 @@ Sequence, Set, Tuple, - TYPE_CHECKING, Union, ) @@ -54,8 +54,6 @@ from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter from .multi_kernel import MultiKernel -if TYPE_CHECKING: - pass log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -341,7 +339,8 @@ def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: index = self.combine_contiguous_dims(index, tree) - return index + + return self.combine_modular_indexing_pairs(index) self.simplify_indexing = simplify_indexing self.initialize_range_tree(pid_cache) @@ -425,7 +424,23 @@ def dense_size_str(self): sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" + def combine_modular_indexing_pairs(self, index): + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + return V.graph.sizevars.combine_modular_indexing_pairs(new_index) + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): """ More aggressive simplification to merge contiguous dims """ @@ -1400,8 +1415,9 @@ def codegen_template( for node in [template_node, *epilogue_nodes]: node.mark_run() partial_code = render() - for node in epilogue_nodes: - node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + with 
kernel.set_subgraph_body(""): + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) if not isinstance(partial_code, str): partial_code.finalize_hook("") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d83680198e7d..9b6184f7e185 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses @@ -7,7 +8,18 @@ import os import textwrap from functools import lru_cache -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import sympy @@ -23,7 +35,6 @@ from .. import config, ir from ..codecache import code_hash, get_path, PyCodeCache -from ..ir import IRNode from ..metrics import is_metric_table_enabled, log_kernel_metadata from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2 @@ -52,6 +63,9 @@ from .simd import constant_repr, IterationRangesEntry, pexpr, SIMDKernel, SIMDScheduling from .triton_utils import config_of, signature_of, signature_to_meta +if TYPE_CHECKING: + from ..ir import IRNode + log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") @@ -272,23 +286,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). 
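+ # (Concretely: Python evaluates -7 % 3 to 2, taking the sign of the divisor,
+ # whereas C-style truncating semantics, which Triton follows, give -1.)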
If you are trying to hit this, maybe try something like + # torch.arange(n, device="cuda") - 1 and then do a modulus on it + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # TODO: This is wrong, see + # https://github.com/triton-lang/triton/issues/955 + # But for Sympy expressions, things will /mostly/ work out because we + # don't usually deal with negative numbers in the division + def _print_FloorDiv(self, expr): + assert expr.is_integer + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype def _print_floor(self, expr): assert len(expr.args) == 1 return ( f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _helper_sqrt(self, expr): return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" @@ -359,20 +418,9 @@ def _print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_FloorDiv(self, expr): - if expr.is_integer: - return super()._print_FloorDiv(expr) - - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" - - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 - return ( - f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) + return f"libdevice.llrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 @@ -432,12 +480,6 @@ def __init__(self, name, bounds: ValueRanges[Any]): self.mask_vars: Set[str] = set() def update_on_args(self, name, args, kwargs): - # When making a variable that is going to be used in indirect indexing - # if a where clause is used it should mean that the result is always a - # valid index, so you shouldn't include any of the dependent variables - # in the resulting load mask - if name == "where": - return for arg in args: if isinstance(arg, TritonCSEVariable): self.mask_vars.update(arg.mask_vars) @@ -889,7 +931,9 @@ def masked(mask, body, other): f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), ) - return ops.where(new_mask, result, other) + ret = ops.where(new_mask, result, other) + ret.mask_vars.discard(new_mask) + return ret @staticmethod def load_seed(name, offset): @@ -1295,14 +1339,7 @@ def load(self, name: str, index: sympy.Expr): ep = ", eviction_policy='evict_first'" else: ep = "" - # "other" below is a workaround for 
https://github.com/openai/triton/issues/737 - # for bool, even though it's likely subject to the same bug, setting `other` leads - # to LLVM errors so we are skipping it for now - if ( - (has_tmpmask or has_rindex) - and V.graph.get_dtype(name) != torch.bool - and indexing.has_mask() - ): + if (has_tmpmask or has_rindex) and indexing.has_mask(): other = ", other=0.0" else: other = "" @@ -2313,10 +2350,18 @@ def codegen_nan_check(self): _, call_args, arg_types, _ = self.args.python_argdefs() for arg, arg_type in zip(call_args, arg_types): if isinstance(arg_type, TensorArg): - line = f"assert not {arg}.isnan().any().item()" - wrapper.writeline(line) - line = f"assert not {arg}.isinf().any().item()" - wrapper.writeline(line) + if V.graph.cpp_wrapper: + if config.abi_compatible: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});') + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) def create_cse_var(self, *args, **kwargs): return TritonCSEVariable(*args, **kwargs) @@ -2352,7 +2397,10 @@ def iteration_ranges_get_pid(self, entry): and not entry.has_zdim and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid()) ): - key = f"{key} * (tl.program_id({entry.grid_dim + 1}) + 1)" + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" pid = entry.pid_cache.get(key, key) if self.index_dtype != "tl.int32": return f"{pid}.to({self.index_dtype})" diff --git a/torch/_inductor/codegen/triton_foreach.py b/torch/_inductor/codegen/triton_foreach.py index 8ed909ec823a..4a909a6025d5 100644 --- a/torch/_inductor/codegen/triton_foreach.py +++ b/torch/_inductor/codegen/triton_foreach.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from collections import defaultdict from dataclasses import dataclass diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 6df3f39a9724..1e0475ffd0f9 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Optional, Set diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index ea6f25ae2c0a..2e4107f85916 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional import sympy diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 02f4fee19bb0..092dfd4e0b9c 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import dataclasses @@ -28,17 +29,12 @@ from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.codegen.multi_kernel import MultiKernelState -from torch.fx.experimental.symbolic_shapes import ( - ConvertIntKey, - DivideByKey, - free_unbacked_symbols, - SymTypes, -) +from torch.fx.experimental.symbolic_shapes import 
ConvertIntKey, DivideByKey, SymTypes from torch.fx.node import _get_qualified_name from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import codecache, config, ir +from .. import async_compile, config, ir from ..ir import ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -444,7 +440,7 @@ def __init__(self): self.stride = "stride()" self.last_seen_device_guard_index: Optional[int] = None self.supports_intermediate_hooks = True - self.expr_printer = pexpr + self.expr_printer: Callable[[Any], str] = pexpr self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol self.allow_stack_allocation: Optional[bool] = None @@ -506,7 +502,7 @@ def write_header(self) -> None: from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided - from {codecache.__name__} import AsyncCompile + from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall @@ -516,8 +512,8 @@ def write_header(self) -> None: assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool - reinterpret_tensor = torch.ops.inductor._reinterpret_tensor async_compile = AsyncCompile() """ @@ -730,10 +726,6 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( ): self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})") - def generate_inf_and_nan_checker(self, node): - # TODO: Add check for python too. 
- pass - @dynamo_timed def generate(self, is_inference): if config.profile_bandwidth: @@ -864,7 +856,7 @@ def strideof(name): return f"{name}_stride" # Assign all symbolic shapes needed to local variables - needed = V.graph.sizevars.free_symbols() + bound_vars: Set[sympy.Symbol] = set() def is_expr(x): return isinstance(x[1], sympy.Expr) @@ -874,37 +866,28 @@ def is_expr(x): filter(lambda x: not is_expr(x), graph_inputs.items()) ) - def is_unbacked_symbol(s): - return isinstance(s, sympy.Symbol) and free_unbacked_symbols(s) - for name, shape in graph_inputs_expr: - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline(f"{self.declare}{shape} = {name}{self.ending}") + bound_vars.add(shape) for name, value in graph_inputs_tensors: shapes = value.get_size() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline( f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" ) + bound_vars.add(shape) for name, value in graph_inputs_tensors: shapes = value.get_stride() for dim, shape in enumerate(shapes): - shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] - if (b := shape in needed) or is_unbacked_symbol(shape): - if b: - needed.remove(shape) # type: ignore[arg-type] + if isinstance(shape, sympy.Symbol) and shape not in bound_vars: code.writeline( f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" ) + bound_vars.add(shape) def ensure_size_computed(self, sym: sympy.Symbol): if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): @@ -920,10 +903,7 @@ def finalize_prefix(self): pass def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: - if simplify: - return pexpr(V.graph.sizevars.simplify(x)) - else: - return pexpr(x) + return pexpr(x, simplify=simplify) def codegen_sizevar(self, x: Expr) -> str: return self.codegen_python_sizevar(x) @@ -1224,7 +1204,9 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): # Also include any possible kernel being called indirectly from triton import JITFunction + from triton.language import constexpr + # global constexpr vars handled above symbols_included = {original_name} def traverse(cur_kernel): @@ -1237,6 +1219,7 @@ def traverse(cur_kernel): for inst in dis.Bytecode(cur_kernel.fn) if inst.opname == "LOAD_GLOBAL" } + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) for symbol_name in cur_kernel.fn.__code__.co_names: if symbol_name in symbols_included: continue @@ -1248,9 +1231,25 @@ def traverse(cur_kernel): compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) - elif isinstance(symbol, (int, str, bool)): + elif isinstance(symbol, (int, str, bool, constexpr)): compile_wrapper.newline() - compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + annotion_code = "" + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + 
else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") symbols_included.add(symbol_name) elif ( symbol_name in unqualified_loads @@ -1411,9 +1410,6 @@ def writelines(self, lines): def enter_context(self, ctx): self.lines.append(LineContext(ctx)) - def val_to_cpp_arg_str(self, val, type_) -> str: - raise NotImplementedError - def val_to_arg_str(self, s, type_=None): from torch.utils._triton import dtype_to_string, has_triton_package diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 1f1258898290..6eec71344ae8 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ..common import DeviceOpOverrides, register_device_op_overrides diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 3d9233b34370..71e8740a5fd7 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -1,3 +1,4 @@ +import functools import math from enum import IntEnum @@ -22,6 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum): HOPPER = 2 +@functools.lru_cache def get_gpu_type() -> NVIDIA_GPU_TYPE: gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" if "V100" in gpu_info: @@ -59,7 +61,7 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int: # For ease of testing numel = int(numel) else: - numel = V.graph.sizevars.size_hint(numel) + numel = V.graph.sizevars.size_hint(numel, fallback=0) sz_bytes += numel * get_dtype_size(inp.layout.dtype) return sz_bytes diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index a1fe0e1cdceb..9f95f7354437 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # pyre-strict from typing import List diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 7eca31da87b4..f29e05a723ec 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import itertools @@ -11,6 +12,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from unittest import mock +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools + import torch.fx import torch.utils._pytree as pytree @@ -22,6 +25,7 @@ utils as dynamo_utils, ) from torch._dynamo.utils import ( + counters, detect_fake_mode, flatten_graph_inputs, lazy_format_graph_code, @@ -118,6 +122,19 @@ def complex_memory_overlap(t: torch.Tensor) -> bool: return False +def get_static_input_idxs(num_fixed): + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. 
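+ # For example, with num_fixed = 2 and fw_metadata.static_parameter_indices = [5, 7],
+ # this returns [0, 1, 5, 7]: the leading fixed arguments plus the parameter inputs
+ # are all treated as cudagraph-static.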
+ context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return fixed + context.fw_metadata.static_parameter_indices + + @functools.lru_cache(None) def _step_logger(): return dynamo_logging.get_step_logger(log) @@ -375,7 +392,7 @@ def should_use_remote_fx_graph_cache(): return False try: - from triton.runtime.fb_memcache import MEMCACHE_VERSION + from triton.fb.fb_memcache import MEMCACHE_VERSION except ModuleNotFoundError: return False @@ -408,12 +425,12 @@ def with_fresh_cache_if_config(f): # the backward graph as well. @_use_lazy_graph_module(dynamo_config.use_lazy_graph_module) @with_fresh_cache_if_config -@dynamo_utils.dynamo_timed(phase_name="inductor_compile") +@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False) def compile_fx_inner( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], cudagraphs: Optional[BoxedBool] = None, - num_fixed: int = 0, + static_input_idxs: Optional[List[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -438,6 +455,9 @@ def compile_fx_inner( _LazyGraphModule.force_recompile(gm) return make_boxed_func(gm.forward) + if static_input_idxs is None: + static_input_idxs = [] + assert isinstance( next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" @@ -447,7 +467,7 @@ def compile_fx_inner( gm, example_inputs, cudagraphs=cudagraphs, - num_fixed=num_fixed, + static_input_idxs=static_input_idxs, is_backward=is_backward, graph_id=graph_id, cpp_wrapper=cpp_wrapper, @@ -466,7 +486,7 @@ def compile_fx_inner( # of fx_codegen_and_compile changes, the dict should be updated accordingly graph_kwargs = { "cudagraphs": cudagraphs, - "num_fixed": num_fixed, + "static_input_idxs": static_input_idxs, "is_backward": is_backward, "graph_id": graph_id, "cpp_wrapper": cpp_wrapper, @@ -480,16 +500,26 @@ def compile_fx_inner( start = time.time() fx_graph_remote_cache = should_use_remote_fx_graph_cache() + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) if ( not config.force_disable_caches and (config.fx_graph_cache or fx_graph_remote_cache) and not aot_mode ): + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and input.device.type == "cuda" + and i in static_input_idxs + ): + input._is_inductor_static = True # type: ignore[attr-defined] + compiled_graph = FxGraphCache.load( fx_codegen_and_compile, gm, example_inputs, graph_kwargs, + inputs_to_check, local=config.fx_graph_cache, remote=fx_graph_remote_cache, ) @@ -506,6 +536,8 @@ def compile_fx_inner( log_cudagraph_skip_and_bump_counter( f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" ) + else: + counters["inductor"]["cudagraph_skips"] += 1 BoxedBool.disable(cudagraphs) # Return the output strides to the caller via TracingContext @@ -539,7 +571,7 @@ def compile_fx_inner( ) has_mutation_str = check_for_mutation_ignore_cuda_graph_managed_tensor( - gm, compiled_graph, num_fixed + gm, compiled_graph, static_input_idxs ) has_mutation = has_mutation_str is not None @@ -579,7 +611,7 @@ def compile_fx_inner( compiled_graph.current_callable = cudagraphify( compiled_graph.current_callable, example_inputs, - static_input_idxs=range(num_fixed), + static_input_idxs=static_input_idxs, device_index=next(iter(compiled_graph.device_idxs)), stack_traces=stack_traces, is_backward=is_backward, @@ 
-625,8 +657,8 @@ def compiled_artifact(new_inputs): # cudagraphs does its own aligning of inputs if not cudagraphs: - new_callable = align_inputs( - compiled_graph.current_callable, example_inputs, range(num_fixed) + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check ) if new_callable is not compiled_graph.current_callable: compiled_graph.current_callable = new_callable @@ -648,7 +680,7 @@ def fx_codegen_and_compile( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], cudagraphs: Optional[BoxedBool] = None, - num_fixed: int = 0, + static_input_idxs: Optional[List[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -713,10 +745,17 @@ def fx_codegen_and_compile( # has some issues with memory in training _recursive_post_grad_passes(gm, is_inference=is_inference) V.debug.fx_graph_transformed(gm, example_inputs) - post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm)) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", gm, include_stride=True, include_device=True + ), + ) trace_structured( "inductor_post_grad_graph", - payload_fn=lambda: gm.print_readable(print_output=False), + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), ) if config.is_fbcode(): log_optimus_to_scuba( @@ -737,7 +776,6 @@ def fx_codegen_and_compile( const_gm, example_inputs=[], shape_env=shape_env, - num_static_inputs=num_fixed, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, @@ -759,7 +797,6 @@ def fx_codegen_and_compile( # we currently use fake tensors and defake them later. example_inputs=example_inputs, shape_env=shape_env, - num_static_inputs=num_fixed, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, @@ -908,15 +945,6 @@ def run(new_inputs): return run -def align_inputs( - model: Callable[[List[torch.Tensor]], Any], - inputs: List[torch.Tensor], - static_input_idxs: Sequence[int] = (), -): - inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs) - return align_inputs_from_check_idxs(model, inputs_to_check) - - @dynamo_utils.dynamo_timed def cudagraphify( model: torch.fx.GraphModule, @@ -1177,6 +1205,7 @@ def fw_compiler_freezing( n.name for n in model_outputs if isinstance(n, torch.fx.Node) ) + static_input_idxs = list(range(num_fixed)) # constant params will be real tensors, not fake tracing_context = torch._guards.TracingContext.try_get() if tracing_context is not None: @@ -1186,11 +1215,14 @@ def fw_compiler_freezing( if i not in preserved_arg_indices: params_flat[i] = None + if tracing_context.fw_metadata: + static_input_idxs += tracing_context.fw_metadata.static_parameter_indices + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): optimized_function = inner_compile( opt_model, aot_example_inputs, - num_fixed=num_fixed, + static_input_idxs=static_input_idxs, cudagraphs=cudagraphs, graph_id=graph_id, is_inference=True, @@ -1321,6 +1353,7 @@ def fw_compiler_base( fixed = torch._inductor.utils.num_fw_fixed_arguments( num_example_inputs, len(example_inputs) ) + user_visible_outputs = {} if config.keep_output_stride: @@ -1376,7 +1409,7 @@ def fw_compiler_base( return inner_compile( model, example_inputs, - num_fixed=fixed, + static_input_idxs=get_static_input_idxs(fixed), cudagraphs=cudagraphs, graph_id=graph_id, is_inference=is_inference, @@ -1420,7 +1453,7 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) 
return inner_compile( model, example_inputs, - num_fixed=fixed, + static_input_idxs=list(range(fixed)), cudagraphs=cudagraphs, is_backward=True, graph_id=graph_id, diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index 6cd1d1e600ac..7f0965415bbf 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -1,9 +1,10 @@ +# mypy: allow-untyped-defs import argparse import os import sys import typing -from torch._inductor.codecache import caching_device_properties +from torch._inductor.async_compile import pre_fork_setup from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path @@ -34,9 +35,7 @@ def main(): # redirect output of workers to stderr os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) - # ensure properties have been calculated before processes - # are forked - caching_device_properties() + pre_fork_setup() _async_compile_initializer(args.parent) SubprocMain(args.workers, read_fd, write_fd).main() diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index f3f8e7b3b3ef..5aba18707b41 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging @@ -12,6 +13,7 @@ from concurrent.futures import Future, ProcessPoolExecutor from typing import Any, Callable, Dict +from torch._inductor import config from torch._inductor.compile_worker.watchdog import _async_compile_initializer log = logging.getLogger(__name__) @@ -58,6 +60,19 @@ def _recv_msg(read_pipe): return job_id, data +def _get_ld_library_path(): + path = os.environ.get("LD_LIBRARY_PATH", "") + if config.is_fbcode(): + from libfb.py.parutil import get_runtime_path + + runtime_path = get_runtime_path() + if runtime_path: + lib_path = os.path.join(runtime_path, "runtime", "lib") + path = os.pathsep.join([lib_path, path]) if path else lib_path + + return path + + class SubprocPool: """ Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in @@ -84,13 +99,14 @@ def __init__(self, nprocs: int): # torch._inductor.codecache since the warming process is what # creates the SubprocPool in the first place. "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": _get_ld_library_path(), }, ) self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin) self.write_lock = threading.Lock() self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout) self.read_thread = threading.Thread(target=self._read_thread, daemon=True) - self.read_thread.start() self.futures_lock = threading.Lock() self.pending_futures: Dict[int, Future[Any]] = {} @@ -98,6 +114,10 @@ def __init__(self, nprocs: int): self.running = True + # Start thread last to ensure all member variables are initialized + # before any access. 
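+ # (Previously the reader thread was started right after it was created, before
+ # self.futures_lock, self.pending_futures, and self.running were assigned, so it
+ # could observe a partially initialized pool.)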
+ self.read_thread.start() + def submit(self, job_fn: Callable[..., Any], *args): if args: job_fn = functools.partial(job_fn, *args) @@ -106,11 +126,11 @@ def submit(self, job_fn: Callable[..., Any], *args): with self.futures_lock: job_id = next(self.job_id_count) self.pending_futures[job_id] = future = Future() + future.set_running_or_notify_cancel() with self.write_lock: if not self.running: raise RuntimeError("submit() on closed pool") _send_msg(self.write_pipe, job_id, job_data) - future.set_running_or_notify_cancel() return future def _read_thread(self): diff --git a/torch/_inductor/compile_worker/watchdog.py b/torch/_inductor/compile_worker/watchdog.py index c91c9efb492c..f3956e1272e9 100644 --- a/torch/_inductor/compile_worker/watchdog.py +++ b/torch/_inductor/compile_worker/watchdog.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import signal from threading import Thread diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 960a3567f8c5..2ea60000d265 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os # noqa: C101 import sys from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union @@ -12,9 +13,6 @@ def is_fbcode(): # add some debug printouts debug = False -# add inf and NaN checkers -debug_check_inf_and_nan = False - # Whether to disable a progress bar for autotuning disable_progress = True @@ -387,7 +385,9 @@ def is_fbcode(): # The multiprocessing start method to use for inductor workers in the codecache. # "subprocess", "fork", or "spawn" def decide_worker_start_method(): - start_method = os.environ.get("TORCHINDUCTOR_WORKER_START", "fork") + start_method = os.environ.get( + "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" + ) assert start_method in [ "subprocess", "fork", @@ -421,6 +421,8 @@ def decide_worker_start_method(): "schedule_comm_wait", ] +_micro_pipeline_tp: bool = False + def decide_compile_threads(): """ @@ -804,8 +806,8 @@ class cuda: # Path to CUDA NVCC. # NVCC search order: # 1) cuda_cxx set in this config - # 2)CUDACXX environment variable - # 3)CUDA_HOME environment variable + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable # 4) default system search PATH. cuda_cxx: Optional[str] = None @@ -880,6 +882,12 @@ class trace: # to workaround the above failure. dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + # Store cProfile (see snakeviz to view) compile_profile = False diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 5f5cc12be872..523aac95d354 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from typing import Any, Callable, Dict, Optional diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py new file mode 100644 index 000000000000..413270edc314 --- /dev/null +++ b/torch/_inductor/cpp_builder.py @@ -0,0 +1,1179 @@ +# mypy: allow-untyped-defs +# This CPP JIT builder is designed to support both Windows and Linux OS. 
+# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import errno +import functools +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import warnings +from pathlib import Path +from typing import List, Sequence, Tuple, Union + +import torch +from torch._inductor import config, exc +from torch._inductor.codecache import ( + _get_python_include_dirs, + _LINKER_SCRIPT, + _transform_cuda_paths, + get_lock_dir, + invalid_vec_isa, + LOCK_TIMEOUT, + VecISA, +) +from torch._inductor.runtime.runtime_utils import cache_dir + +if config.is_fbcode(): + from triton.fb import build_paths # noqa: F401 + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args, **kwargs): + pass + + def log_global_cache_stats(*args, **kwargs): + pass + + def log_global_cache_vals(*args, **kwargs): + pass + + def use_global_cache() -> bool: + return False + + +# Windows need setup a temp dir to store .obj files. +_BUILD_TEMP_DIR = "CxxBuild" + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + + +log = logging.getLogger(__name__) + + +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler() # noqa: RSE102 + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +def _get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + else: + if config.is_fbcode(): + return build_paths.cc() + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +def _is_gcc(cpp_compiler) -> bool: + return bool(re.search(r"(gcc|g\+\+)", cpp_compiler)) + + +def is_gcc() -> bool: + return _is_gcc(_get_cpp_compiler()) + + +def _is_clang(cpp_compiler) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. 
+ if sys.platform == "darwin": + return is_apple_clang(cpp_compiler) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +def is_clang() -> bool: + compiler = _get_cpp_compiler() + return _is_clang(compiler) + + +@functools.lru_cache(None) +def is_apple_clang(cpp_compiler) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +def _append_list(dest_list: List[str], src_list: List[str]): + for item in src_list: + dest_list.append(copy.deepcopy(item)) + + +def _remove_duplication_in_list(orig_list: List[str]) -> List[str]: + new_list: List[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir): + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError( # noqa: TRY200 (Use `raise from`) + f"Fail to create path {path_dir}" + ) + + +def _remove_dir(path_dir): + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def run_command_line(cmd_line, cwd=None): + cmd = shlex.split(cmd_line) + try: + status = subprocess.check_output(args=cmd, cwd=cwd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + return status + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Acturally, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__(self) -> None: + self._compiler = "" + self._definations: List[str] = [] + self._include_dirs: List[str] = [] + self._cflags: List[str] = [] + self._ldflags: List[str] = [] + self._libraries_dirs: List[str] = [] + self._libraries: List[str] = [] + # Some args is hard to abstract to OS compatable, passthough it directly. 
+ self._passthough_args: List[str] = [] + + self._aot_mode: bool = False + self._use_absolute_path: bool = False + self._compile_only: bool = False + + def _remove_duplicate_options(self): + self._definations = _remove_duplication_in_list(self._definations) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthough_args = _remove_duplication_in_list(self._passthough_args) + + def get_compiler(self) -> str: + return self._compiler + + def get_definations(self) -> List[str]: + return self._definations + + def get_include_dirs(self) -> List[str]: + return self._include_dirs + + def get_cflags(self) -> List[str]: + return self._cflags + + def get_ldflags(self) -> List[str]: + return self._ldflags + + def get_libraries_dirs(self) -> List[str]: + return self._libraries_dirs + + def get_libraries(self) -> List[str]: + return self._libraries + + def get_passthough_args(self) -> List[str]: + return self._passthough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_absolute_path(self) -> bool: + return self._use_absolute_path + + def get_compile_only(self) -> bool: + return self._compile_only + + +def _get_warning_all_cflag(warning_all: bool = True) -> List[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]: + if _IS_WINDOWS: + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_linux_cpp_cflags(cpp_compiler) -> List[str]: + if not _IS_WINDOWS: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + cflags.append("Werror=ignored-optimization-argument") + return cflags + else: + return [] + + +def _get_optimization_cflags() -> List[str]: + if _IS_WINDOWS: + return ["O2"] + else: + cflags = ["O0", "g"] if config.aot_inductor.debug_compile else ["O3", "DNDEBUG"] + cflags.append("ffast-math") + cflags.append("fno-finite-math-only") + + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + if not config.cpp.enable_floating_point_contract_flag: + cflags.append("ffp-contract=off") + + if config.is_fbcode(): + # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies. + # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths. + # We will fix it later by exposing the lib path. 
+ return cflags + + if sys.platform == "darwin": + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + # Also, `-march=native` is unrecognized option on M1 + cflags.append("Xclang") + else: + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + # Internal cannot find libgomp.so + if not config.is_fbcode(): + cflags.append("fopenmp") + + return cflags + + +def _get_shared_cflag(compile_only: bool) -> List[str]: + if _IS_WINDOWS: + SHARED_FLAG = ["DLL"] + else: + if compile_only: + return ["fPIC"] + if platform.system() == "Darwin" and "clang" in _get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + else: + return ["shared", "fPIC"] + + return SHARED_FLAG + + +def get_cpp_options( + cpp_compiler, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), +): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + cflags = ( + _get_shared_cflag(compile_only) + + _get_optimization_cflags() + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_linux_cpp_cflags(cpp_compiler) + ) + + passthough_args.append(" ".join(extra_flags)) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. + This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. + """ + + def __init__( + self, + compile_only: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_absolute_path: bool = False, + ) -> None: + super().__init__() + self._compiler = _get_cpp_compiler() + self._use_absolute_path = use_absolute_path + self._compile_only = compile_only + + ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + compile_only=compile_only, + extra_flags=extra_flags, + warning_all=warning_all, + ) + + _append_list(self._definations, definations) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + _append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthough_args, passthough_args) + self._remove_duplicate_options() + + +def _get_glibcxx_abi_build_flags() -> List[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +def _get_torch_cpp_wrapper_defination() -> List[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER"] + + +def _use_custom_generated_macros() -> List[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> List[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + # TODO: this is to avoid FC breakage for fbcode. 
When using newly + # generated model.so on an older verion of PyTorch, need to use + # the v1 version for aoti_torch_create_tensor_from_blob + create_tensor_from_blob_v1 = "AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" + + fb_internal_macros.append(create_tensor_from_blob_v1) + + # TODO: remove comments later: + # Moved to _get_openmp_args + # openmp_lib = build_paths.openmp_lib() + # return [f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"] + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler, + aot_mode: bool, + use_absolute_path: bool, +): + cflags: List[str] = [] + include_dirs: List[str] = [] + passthough_args: List[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthough_args + + if config.is_fbcode(): + cflags.append("nostdinc") + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) + include_dirs.append(build_paths.libgcc_backward()) + include_dirs.append(build_paths.glibc()) + include_dirs.append(build_paths.linux_kernel()) + include_dirs.append("include") + + if aot_mode and not use_absolute_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthough_args.append(" --rtlib=compiler-rt") + passthough_args.append(" -fuse-ld=lld") + passthough_args.append(f" -Wl,--script={linker_script}") + passthough_args.append(" -B" + build_paths.glibc_lib()) + passthough_args.append(" -L" + build_paths.glibc_lib()) + + return cflags, include_dirs, passthough_args + + +@functools.lru_cache +def _cpp_prefix_path() -> str: + from torch._inductor.codecache import write # TODO + + path = Path(Path(__file__).parent).parent / "codegen/cpp_prefix.h" + with path.open() as f: + content = f.read() + _, filename = write( + content, + "h", + ) + return filename + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA): + macros = [] + build_flags = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. + for x in vec_isa.build_macro(): + macros.append(copy.deepcopy(x)) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode() and vec_isa != invalid_vec_isa: + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args(include_pytorch: bool, aot_mode: bool): + from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH + + include_dirs = [ + os.path.join(_TORCH_PATH, "include"), + os.path.join(_TORCH_PATH, "include", "torch", "csrc", "api", "include"), + # Some internal (old) Torch headers don't properly prefix their includes, + # so we need to pass -Itorch/lib/include/TH as well. 
+ os.path.join(_TORCH_PATH, "include", "TH"), + os.path.join(_TORCH_PATH, "include", "THC"), + ] + libraries_dirs = [TORCH_LIB_PATH] + libraries = [] + if sys.platform == "linux" and not config.is_fbcode(): + libraries = ["torch", "torch_cpu"] + if not aot_mode: + libraries.append("torch_python") + + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 + if not config.abi_compatible: + libraries.append("c10") + libraries_dirs.append(TORCH_LIB_PATH) + + return include_dirs, libraries_dirs, libraries + + +def _get_python_related_args(): + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_path = os.path.dirname(sys.executable) + python_lib_path = [os.path.join(python_path, "libs")] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python()) + + return python_include_dirs, python_lib_path + + +def _get_openmp_args(cpp_compiler): + cflags: List[str] = [] + ldflags: List[str] = [] + include_dir_paths: List[str] = [] + lib_dir_paths: List[str] = [] + libs: List[str] = [] + passthough_args: List[str] = [] + if _IS_MACOS: + from torch._inductor.codecache import ( + homebrew_libomp, + is_conda_llvm_openmp_installed, + ) + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + + cflags.append("openmp") + libs = [] + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp()) + + openmp_lib = build_paths.openmp_lib() + fb_openmp_extra_flags = f"-Wp,-fopenmp 
{openmp_lib}" + passthough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthough_args + + +def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]: + macros = [] + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + return macros + + +def get_cpp_torch_options( + cpp_compiler, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + compile_only: bool, + use_absolute_path: bool, + use_mmap_weights: bool, +): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination() + use_custom_generated_macros_definations = _use_custom_generated_macros() + + ( + sys_libs_cflags, + sys_libs_include_dirs, + sys_libs_passthough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_absolute_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthough_args, + ) = _get_openmp_args(cpp_compiler) + + cxx_abi_passthough_args = _get_glibcxx_abi_build_flags() + fb_macro_passthough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights) + + definations = ( + torch_cpp_wrapper_definations + + use_custom_generated_macros_definations + + isa_macros + + fb_macro_passthough_args + + mmap_self_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthough_args = ( + sys_libs_passthough_args + + isa_ps_args_build_flags + + cxx_abi_passthough_args + + omp_passthough_args + ) + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. 
MISC + """ + + def __init__( + self, + vec_isa: VecISA, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_absolute_path=use_absolute_path, + ) + + self._aot_mode = aot_mode + + ( + torch_definations, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + + if compile_only: + torch_libraries_dirs = [] + torch_libraries = [] + + _append_list(self._definations, torch_definations) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthough_args, torch_passthough_args) + self._remove_duplicate_options() + + +def get_cpp_torch_cuda_options(cuda: bool, aot_mode: bool = False): + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.cuda() + + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths(cuda) + libraries_dirs = cpp_extension.library_paths(cuda) + + if cuda: + definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode(): + libraries += ["amdhip64"] + else: + libraries += ["c10_hip", "torch_hip"] + definations.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + libraries += ["c10_cuda", "cuda", "torch_cuda"] + + if aot_mode: + cpp_prefix_include_dir = [f"{os.path.dirname(_cpp_prefix_path())}"] + include_dirs += cpp_prefix_include_dir + + if cuda and torch.version.hip is None: + _transform_cuda_paths(libraries_dirs) + + if config.is_fbcode(): + if torch.version.hip is not None: + include_dirs.append(os.path.join(build_paths.rocm(), "include")) + else: + include_dirs.append(os.path.join(build_paths.cuda(), "include")) + + if aot_mode and cuda and config.is_fbcode(): + if torch.version.hip is None: + # TODO: make static link better on Linux. + passthough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +class CppTorchCudaOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda device related build args. 
+ """ + + def __init__( + self, + vec_isa: VecISA, + include_pytorch: bool = False, + cuda: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + extra_flags=extra_flags, + ) + + cuda_definations: List[str] = [] + cuda_include_dirs: List[str] = [] + cuda_cflags: List[str] = [] + cuda_ldflags: List[str] = [] + cuda_libraries_dirs: List[str] = [] + cuda_libraries: List[str] = [] + cuda_passthough_args: List[str] = [] + + ( + cuda_definations, + cuda_include_dirs, + cuda_cflags, + cuda_ldflags, + cuda_libraries_dirs, + cuda_libraries, + cuda_passthough_args, + ) = get_cpp_torch_cuda_options(cuda=cuda, aot_mode=aot_mode) + + if compile_only: + cuda_libraries_dirs = [] + cuda_libraries = [] + + _append_list(self._definations, cuda_definations) + _append_list(self._include_dirs, cuda_include_dirs) + _append_list(self._cflags, cuda_cflags) + _append_list(self._ldflags, cuda_ldflags) + _append_list(self._libraries_dirs, cuda_libraries_dirs) + _append_list(self._libraries, cuda_libraries) + _append_list(self._passthough_args, cuda_passthough_args) + self._remove_duplicate_options() + + +def get_name_and_dir_from_output_file_path( + aot_mode: bool, use_absolute_path: bool, file_path: str +): + name_and_ext = os.path.basename(file_path) + name, ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + if config.is_fbcode(): + if not (aot_mode and not use_absolute_path): + dir = "." + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports mutliple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the taget file will output to. + 2. The default value is empty string, and then the use current dir as output dir. + 3. Final target file: output_dir/name.ext + """ + + def get_shared_lib_ext(self) -> str: + SHARED_LIB_EXT = ".dll" if _IS_WINDOWS else ".so" + return SHARED_LIB_EXT + + def get_object_ext(self) -> str: + EXT = ".obj" if _IS_WINDOWS else ".o" + return EXT + + def __init__( + self, + name: str, + sources: Union[str, List[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definations_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthough_parameters_args = "" + + self._output_dir = "" + self._target_file = "" + + self._use_absolute_path: bool = False + + self._name = name + + # Code start here, initial self internal veriables firstly. 
+ self._compiler = BuildOption.get_compiler() + self._use_absolute_path = BuildOption.get_use_absolute_path() + + if len(output_dir) == 0: + self._output_dir = os.path.dirname(os.path.abspath(__file__)) + else: + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + file_ext = ( + self.get_object_ext() if self._compile_only else self.get_shared_lib_ext() + ) + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + if isinstance(sources, str): + sources = [sources] + + if config.is_fbcode(): + if BuildOption.get_aot_mode() and not self._use_absolute_path: + inp_name = sources + # output process @ get_name_and_dir_from_output_file_path + else: + # We need to copy any absolute-path torch includes + inp_name = [os.path.basename(i) for i in sources] + self._target_file = os.path.basename(self._target_file) + + self._sources_args = " ".join(inp_name) + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for defination in BuildOption.get_definations(): + if _IS_WINDOWS: + self._definations_args += f"/D {defination} " + else: + self._definations_args += f"-D {defination} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f"/I {inc_dir} " + else: + self._include_dirs_args += f"-I{inc_dir} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for passthough_arg in BuildOption.get_passthough_args(): + self._passthough_parameters_args += f"{passthough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler, + sources, + include_dirs_args, + definations_args, + cflags_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + passthougn_args, + target_file, + ): + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definations_args} {cflags_args} {sources} " + f"{passthougn_args} /LD /Fe{target_file} /link {libraries_dirs_args} {libraries_args} {ldflags_args} " + ) + cmd = cmd.replace("\\", "/") + else: + compile_only_arg = "-c" if self._compile_only else "" + cmd = re.sub( + r"[ \n]+", + " ", + f""" + {compiler} {sources} {definations_args} {cflags_args} {include_dirs_args} + {passthougn_args} {ldflags_args} {libraries_args} {libraries_dirs_args} {compile_only_arg} -o {target_file} + """, + ).strip() + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + include_dirs_args=self._include_dirs_args, + definations_args=self._definations_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthougn_args=self._passthough_parameters_args, + target_file=self._target_file, + ) + return command_line + + def get_target_file_path(self): + return self._target_file + + def 
convert_to_cpp_extension_args(self): + include_dirs = self._include_dirs_args + cflags = ( + self._cflags_args + + self._definations_args + + self._passthough_parameters_args + ) + ldflags = self._ldflags_args + self._libraries_args + self._libraries_dirs_args + + return include_dirs, cflags, ldflags + + def build(self) -> Tuple[int, str]: + """ + It is must need a temperary directory to store object files in Windows. + After build completed, delete the temperary directory to save disk space. + """ + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + + status = run_command_line(build_cmd, cwd=_build_tmp_dir) + + _remove_dir(_build_tmp_dir) + return status, self._target_file diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 6fe00710b0af..2b6a9dab45da 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, which share the same memory pool. Sharing a memory pool is an extremely @@ -80,6 +81,7 @@ from torch._inductor.cudagraph_utils import ( check_for_mutation, FunctionID, + get_placeholder_stack_trace, log_cudagraph_skip_and_bump_counter, WrappedFunction, ) @@ -752,6 +754,11 @@ def __init__( self.device = device_index self.stack_traces = stack_traces self.stream = stream + # If we are inlining builtin nn modules we will re-record if static inputs change + # if not we should error because dynamo should have recompiled in this case + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + ) # if this is a root parent will be None. use weakref to prevent reference cycle self._parent = weakref.ref(parent) if parent is not None else None @@ -951,8 +958,13 @@ def _copy_inputs_and_remove_from_src(self, dsts, srcs): def check_static_inputs_are_stable(self, new_inputs): # avoid checking managed tensor static points since we already checked those in check_invariants - if not torch._C._tensors_data_ptrs_at_indices_equal( - new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) ): # this should error static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs] @@ -960,11 +972,17 @@ def check_static_inputs_are_stable(self, new_inputs): self.static_input_data_ptrs[i] for i in self.non_managed_static_input_idxs ] - for t, data_ptr in zip(static_tensors, data_ptrs): - torch._check( - t.data_ptr() == data_ptr, - lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}", - ) + error_msg = "static input data pointer changed.\n" + for i, (t, data_ptr) in enumerate(zip(static_tensors, data_ptrs)): + index = self.non_managed_static_input_idxs[i] + if t.data_ptr() != data_ptr: + placeholder = self.wrapped_function.placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {t.data_ptr()}. 
" + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + torch._check(False, lambda: error_msg) def run_first_inputs(self, new_inputs): if config.triton.fast_path_cudagraph_asserts: @@ -993,6 +1011,9 @@ def run(self, new_inputs): if config.triton.force_cudagraph_sync: torch.cuda.synchronize() + # Reset this to run the check in the future + self.static_inputs_stable = False + return outputs def reconstruct_outputs(self): @@ -1546,8 +1567,8 @@ def _allocate_and_copy_recording_inputs( def check_invariants(self, inputs: List[Tensor]) -> bool: """ - Checks if this node can be run. The same pattern of tensor liveness and tensors - managed in the cudagraph private pool must remain stable. + Checks if this node can be run. The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. """ # previously managed data pointers remain stable @@ -1558,6 +1579,18 @@ def check_invariants(self, inputs: List[Tensor]) -> bool: ): return False + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, self.static_input_data_ptrs, self.static_input_idxs + ) + ): + return False + if not self._check_liveness( self.expected_dead_indices_before_graph, self.path_weakrefs ): diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index c87022fcb788..188c91ba65f0 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from typing import Any, Callable, Dict, List, Optional, Tuple @@ -143,15 +144,16 @@ def set(self, device_idx: Optional[int]): def check_for_mutation_ignore_cuda_graph_managed_tensor( - gm: torch.fx.GraphModule, compiled_graph, num_fixed: int + gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: List[int] ) -> Optional[str]: default_msg = format_default_skip_message("mutated inputs") # doesnt work for non-trees because the warmup run would apply mutation twice if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = set(static_input_idxs) # checking if mutation is only on parameters/static inputs mutation_indices = [ - idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed + idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs ] has_mutation = len(mutation_indices) != 0 if not has_mutation: @@ -162,3 +164,17 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor( else: has_mutation = len(compiled_graph.mutated_inputs) != 0 return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: torch.fx.Node) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. 
+ """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index ef1beb7c15a4..b0ad369c4316 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -1,6 +1,6 @@ +# mypy: allow-untyped-defs import collections import contextlib -import cProfile import dataclasses import functools import itertools @@ -388,9 +388,6 @@ def reset_log_level(level): self._setup_log_capture("debug.log", logging.DEBUG) if config.trace.info_log: self._setup_log_capture("info.log", logging.INFO) - if config.trace.compile_profile: - self._prof = cProfile.Profile() - self._prof.enable() def _setup_log_capture(self, filename: str, level: int): log = logging.getLogger("torch._inductor") diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 960c3a42e1f1..c9c3eb579e6c 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import math diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index d7cd3ce64f4f..d5abfaa49696 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import collections import dataclasses diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 9e6aa6effae2..8a172d8c29b1 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import os @@ -45,7 +46,7 @@ def __init__(self, target, args, kwargs): There is a decomposition available for {target} in torch._decomp.get_decompositions(). Please add this operator to the - `decompositions` list in torch._inductor.decompositions + `decompositions` list in torch._inductor.decomposition """ ) ) diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 7d7cbed25193..9a5f12820a2b 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import itertools diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index 5cfabf9b7707..7453cde1ce9d 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 532a546dd4b6..6ef0f71a807c 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -22,9 +22,11 @@ import torch import torch.fx as fx from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from .. 
import config from ..fx_utils import get_fake_args_kwargs from ..virtualized import V @@ -578,14 +580,19 @@ def schedule_comm_wait(graph: fx.Graph) -> None: def fuse_ddp_communication( graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int ) -> None: - for pa in passes: - if isinstance(pa, str): - func = globals()[pa] - else: - func = pa - if "bucket_size_mb" in { - v.name for v in inspect.signature(func).parameters.values() - }: - func(graph, bucket_size_mb=bucket_size_mb) - else: - func(graph) + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, + f"fuse_ddp_communication_pass_{i}", + config.trace.log_url_for_graph_xform, + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in { + v.name for v in inspect.signature(func).parameters.values() + }: + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index 793d29383f56..dba2f62e7d6f 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import List @@ -19,12 +20,12 @@ min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION max_other_dimention_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION -if "decompose_mem_bound_mm" in config.post_grad_fusion_options: +if "decompose_mm_pass" in config.post_grad_fusion_options: min_first_dimension_decomposition = config.post_grad_fusion_options[ - "decompose_mem_bound_mm" + "decompose_mm_pass" ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) max_other_dimention_decomposition = config.post_grad_fusion_options[ - "decompose_mem_bound_mm" + "decompose_mm_pass" ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) diff --git a/torch/_inductor/fx_passes/dedupe_symint_uses.py b/torch/_inductor/fx_passes/dedupe_symint_uses.py index 7145508a3ae2..646e8d16f4d2 100644 --- a/torch/_inductor/fx_passes/dedupe_symint_uses.py +++ b/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import Union diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index 7ab01e0abbb2..7aecc3f15f33 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index fe39b13033a7..039fea2dcca2 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 3fbb67cb2776..fad49d404827 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -1,9 +1,11 @@ +# mypy: allow-untyped-defs import functools import inspect import logging import math import torch +from torch.nn.attention import sdpa_kernel, SDPBackend from ..._dynamo.utils import counters from ..pattern_matcher import ( filter_nodes, @@ -16,6 +18,16 @@ aten = torch.ops.aten 
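+# On ROCm builds, wrap scaled_dot_product_attention with the sdpa_kernel context
+# manager so the fused replacements only select the math or flash-attention SDPA
+# backends; on other builds the aten op is used directly.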
+if torch.version.hip: + + def _scaled_dot_product_attention(*args, **kwargs): + with sdpa_kernel(backends=[SDPBackend.MATH, SDPBackend.FLASH_ATTENTION]): + return aten.scaled_dot_product_attention(*args, **kwargs) + +else: + _scaled_dot_product_attention = aten.scaled_dot_product_attention + + def _sfdp_pattern_1(query, key, value, inv_scale): return ( torch.matmul(query, key.transpose(-2, -1)) @@ -27,7 +39,7 @@ def _sfdp_pattern_1(query, key, value, inv_scale): def _sfdp_replacement_1(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -49,7 +61,7 @@ def _sfdp_pattern_2(query, key, value, scale_factor): def _sfdp_replacement_2(query, key, value, scale_factor): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -71,7 +83,7 @@ def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -91,7 +103,7 @@ def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -112,7 +124,7 @@ def _sfdp_pattern_5(query, key, value, attn_mask): def _sfdp_replacement_5(query, key, value, attn_mask): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -132,7 +144,7 @@ def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.contiguous(), key.contiguous(), value.contiguous(), @@ -167,7 +179,7 @@ def _sfdp_replacement_7(query, key, value, dropout_p): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -194,7 +206,7 @@ def _sfdp_replacement_8(query, key, value): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -222,7 +234,7 @@ def _sfdp_replacement_9(query, key, value, dropout_p): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -250,7 +262,7 @@ def _sfdp_replacement_10(query, key, value): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( q, k, v, @@ -270,7 +282,7 @@ def _sfdp_pattern_11(query, key, value, inv_scale): def _sfdp_replacement_11(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return 
aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -293,7 +305,7 @@ def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -312,7 +324,7 @@ def _sfdp_pattern_13(query, key, value, dropout_p): def _sfdp_replacement_13(query, key, value, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), @@ -336,7 +348,7 @@ def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -369,11 +381,11 @@ def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): n_head = query.size(2) q_len = query.size(1) k_len = key.size(1) - # do attn_mask->logical_not() in aten.scaled_dot_product_attention + # do attn_mask->logical_not() in _scaled_dot_product_attention attn_mask = ( (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) ) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -403,7 +415,7 @@ def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -439,11 +451,11 @@ def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): n_head = query.size(2) q_len = query.size(1) k_len = key.size(1) - # do attn_mask->logical_not() in aten.scaled_dot_product_attention + # do attn_mask->logical_not() in _scaled_dot_product_attention attn_mask = ( (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) ) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), @@ -488,7 +500,7 @@ def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): permuted_key = key.transpose(1, 2) permuted_value = value.transpose(1, 2) return ( - aten.scaled_dot_product_attention( + _scaled_dot_product_attention( query.transpose(1, 2), permuted_key, permuted_value, @@ -525,7 +537,7 @@ def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): counters["inductor"]["fuse_attention"] += 1 fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) attn_mask = torch.where(causal_mask, attn_mask, fill_value) - return aten.scaled_dot_product_attention( + return _scaled_dot_product_attention( query, key, value, @@ -883,6 +895,9 @@ def _get_sfdp_patterns(): "pass_dicts": patterns, "extra_check": extra_check, "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, } diff --git 
a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 90c59a06bab7..9a9d4cd136da 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import logging import operator @@ -18,6 +19,7 @@ import torch from torch._dynamo.utils import counters, optimus_scuba_log from torch._utils_internal import upload_graph +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .. import config from ..pattern_matcher import ( @@ -44,6 +46,12 @@ MAX_FUSE_SEARCH_DEPTH = 5 # The maximum tensor size that can go into the fusion group MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False # exclude these nodes from BFS # excluding get item improves optimizer compilation time by 60s @@ -55,6 +63,9 @@ "max_fuse_set_size": MAX_FUSE_SET_SIZE, "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, } graph_search_options = default_graph_search_options @@ -125,14 +136,18 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]: def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any: unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] for input_tensor in input_tensors: unsqueezed_input = graph.call_function( aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} ) unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) stacked_inputs = graph.call_function( aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] return stacked_inputs @@ -165,19 +180,22 @@ class PostGradBatchLinearFusion(BatchFusion): """ def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type return ( node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value] ) def _is_input_2d(self, input: torch.fx.Node) -> bool: - input_shapes = input.meta["tensor_meta"].shape + input_shapes = input.meta["val"].shape return ( len(input_shapes) == 2 and isinstance(input_shapes[0], int) and isinstance(input_shapes[1], int) ) - def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]: + def match( + self, node: torch.fx.Node + ) -> Optional[Tuple[str, int, int, int, bool, str]]: if CallFunctionVarArgs(aten.mm).match(node): input_m, weight_m = node.args bias_m = None @@ -188,13 +206,17 @@ def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool] bias_m, input_m, weight_m = node.args else: return None - + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] # only handle the cases where inputs are 2D tensors 
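+        # the static (m, k, n) shapes, bias presence and (optionally) the shared
+        # users checked below form the batch key used to group fusible mm/addmm nodes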
if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] return None - m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr] - n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr] - batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None) + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) return batch_key def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): @@ -202,6 +224,9 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_weights = [] batch_biases = [] batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] for node in subset: if CallFunctionVarArgs(aten.addmm.default).match(node): @@ -213,24 +238,62 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_inputs.append(input) # type: ignore[possibly-undefined] batch_weights.append(weight) # type: ignore[possibly-undefined] batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) with graph.inserting_before(subset[-1]): fused_inputs = decompose_stack(graph, batch_inputs) fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) fused_bmm = graph.call_function( aten.bmm, args=(fused_inputs, fused_weights), ) - + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) for i, original_mm in enumerate(batch_nodes): has_bias = False with graph.inserting_after(fused_bmm): new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) if batch_biases[i]: has_bias = True - new_bias_add = graph.call_function( - aten.add, args=((batch_biases[i], new_mm)) - ) + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] + new_bias_add = graph.call_function( + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] original_mm.replace_all_uses_with(new_mm_cont) new_mm_cont.meta.update(original_mm.meta) @@ -241,8 +304,8 @@ def fuse(self, graph: torch.fx.GraphModule, 
subset: List[torch.fx.Node]): @register_fusion("group_linear", pre_grad=False) class GroupLinearFusion(GroupFusion): def _addmm_node_can_be_fused(self, node: torch.fx.Node): - input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] - weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr] + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] return ( node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 @@ -256,8 +319,8 @@ def _addmm_node_can_be_fused(self, node: torch.fx.Node): ) def _mm_node_can_be_fused(self, node: torch.fx.Node): - input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr] - weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] return ( len(input_shape) == 2 and len(weight_shape) == 2 @@ -319,9 +382,9 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): counters["inductor"]["group_linear"] += 1 -class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): """ - Batch pointwise operator (e.g., add, mul) in post grad pass. + Batch pointwise math operator (e.g., add, mul) in post grad pass. """ def __init__(self, op, **kwargs): @@ -336,11 +399,11 @@ def _pointwise_node_can_be_fused(self, node: torch.fx.Node): # its inputs, and cause dtype not same error in mm or addmm input, other = node.args return ( - input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape # type: ignore[union-attr] + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] if hasattr(input, "meta") and hasattr(other, "meta") - and "tensor_meta" in input.meta # type: ignore[union-attr] - and "tensor_meta" in other.meta # type: ignore[union-attr] + and "val" in input.meta # type: ignore[union-attr] + and "val" in other.meta # type: ignore[union-attr] else False ) @@ -351,14 +414,30 @@ def match(self, node: torch.fx.Node): alpha = node.kwargs.get("alpha", 1.0) rounding_mode = node.kwargs.get("rounding_mode", None) input, other = node.args - shape = list(input.meta["tensor_meta"].shape) # type: ignore[union-attr] + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target == aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" group_key = ( "batch_aten_" + self.op.__name__.lower().split(".")[0], str(shape), - str(input.meta["tensor_meta"].dtype), # type: ignore[union-attr] - str(other.meta["tensor_meta"].dtype), # type: ignore[union-attr] + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] str(alpha), str(rounding_mode), + str(parent), ) else: group_key = None @@ -367,21 +446,31 @@ def match(self, node: torch.fx.Node): def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): batch_inputs, batch_others = [], [] 
alpha = subset[0].kwargs.get("alpha", 1.0) + batch_inputs_meta, batch_others_meta = [], [] for node in subset: input, other = node.args batch_inputs.append(input) batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] with graph.inserting_before(subset[0]): stack_inputs = decompose_stack(graph, batch_inputs) stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + [other["val"] for other in batch_others_meta] + ) batch_op = graph.call_function( self.op, args=(stack_inputs, stack_others), kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) for i, original_add in enumerate(subset): with graph.inserting_after(batch_op): new_add = graph.call_function( @@ -475,7 +564,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): def is_node_meta_valid(node: Optional[torch.fx.Node]): if node is None: return True - if "example_value" not in node.meta: + if "example_value" not in node.meta and "val" not in node.meta: return False return True @@ -513,12 +602,17 @@ def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") weight = get_arg_value(node, 1, "weight") bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] group_key = ( "batch_linear", self._getitem_args(input), str(input.meta["example_value"].shape), str(weight.meta["example_value"].shape), bias is None, + str(users), ) else: group_key = None @@ -596,6 +690,10 @@ def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") weight = get_arg_value(node, 2, "weight") bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] group_key = ( ( "batch_layernorm", @@ -606,6 +704,7 @@ def match(self, node: torch.fx.Node): str(bias.meta["example_value"].shape) if bias is not None else "", str(get_arg_value(node, 1, "normalized_shape")), str(get_arg_value(node, 4, "eps")), + str(users), ) if "example_value" in input.meta and is_node_meta_valid(weight) @@ -761,11 +860,18 @@ def __init__(self, op, **kwargs): def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" # for relu op, we also use the inplace to construct the key group_key = ( "batch_" + self.op.__name__.lower().split(".")[0], str(input.meta["example_value"].shape), str(node.kwargs.get("inplace", False)), + str(parent), ) else: group_key = None @@ -810,6 +916,63 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. 
+ The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + @register_fusion("batch_tanh") class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): def __init__(self, **kwargs): @@ -828,26 +991,44 @@ def __init__(self, **kwargs): super().__init__(torch.nn.functional.relu, **kwargs) +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.relu.default, **kwargs) + + @register_fusion("batch_aten_add", pre_grad=False) -class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.add.Tensor, **kwargs) @register_fusion("batch_aten_sub", pre_grad=False) -class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.sub.Tensor, **kwargs) @register_fusion("batch_aten_div", pre_grad=False) -class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): super().__init__(aten.div.Tensor, **kwargs) @register_fusion("batch_aten_mul", pre_grad=False) -class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion): +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): def __init__(self, **kwargs): 
super().__init__(aten.mul.Tensor, **kwargs) @@ -1062,5 +1243,10 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): if has_fbgemm: fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) - for rule in fusions: - apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + config.trace.log_url_for_graph_xform, + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 3302dfd63292..ad134decd228 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import typing @@ -7,8 +8,8 @@ import torch import torch._guards from torch._inductor.constant_folding import ConstantFolder -from torch._inductor.virtualized import V from torch.fx.experimental.symbolic_shapes import statically_known_true +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.multiprocessing.reductions import StorageWeakRef from .. import config @@ -311,11 +312,21 @@ def joint_graph_passes(graph: torch.fx.GraphModule): lazy_init() count = 0 if config.joint_custom_pre_pass is not None: - config.joint_custom_pre_pass(graph.graph) - count += 1 + with GraphTransformObserver( + graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_pre_pass(graph.graph) + count += 1 + + from .post_grad import remove_noop_ops + + remove_noop_ops(graph.graph) if config.joint_graph_constant_folding: - constant_fold_uniform_value(graph) + with GraphTransformObserver( + graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform + ): + constant_fold_uniform_value(graph) if config.pattern_matcher: for patterns in pass_patterns: @@ -325,8 +336,11 @@ def joint_graph_passes(graph: torch.fx.GraphModule): count += replace_random_passes(graph) if config.joint_custom_post_pass is not None: - config.joint_custom_post_pass(graph.graph) - count += 1 + with GraphTransformObserver( + graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.joint_custom_post_pass(graph.graph) + count += 1 if count: stable_topological_sort(graph.graph) @@ -463,8 +477,7 @@ def repl(inp, other): max_ = torch.amax(inp, dim=dim, keepdim=keepdim) return (inp - max_) * (sign * other) - with V.fake_mode: - match.replace_by_example(repl, [inp, other]) + match.replace_by_example(repl, [inp, other]) for reverse, to_dtype in itertools.product((False, True), repeat=2): @@ -491,8 +504,7 @@ def repl(inp, other): max_ = torch.amax(inp, dim=dim, keepdim=keepdim) return (inp - max_) / (sign * other) - with V.fake_mode: - match.replace_by_example(repl, [inp, other]) + match.replace_by_example(repl, [inp, other]) for to_dtype in (False, True): diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 000000000000..fdac76f75e43 --- /dev/null +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +import operator +from dataclasses import dataclass +from typing import cast, List, Set, Tuple, Union + +import torch + +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, +) + +aten = torch.ops.aten 
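+# This pass pattern-matches funcol all_gather_tensor -> matmul and
+# matmul -> reduce_scatter_tensor sequences and rewrites them to the fused
+# torch.ops.cuda_p2p ops, so the collective and the matmul can be
+# micro-pipelined when the process group is a cuda_p2p group.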
+patterns = PatternMatcherPass() + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]: + ancestors = set() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return {node for node in ancestors if node.op != "placeholder"} + + +def _can_schedule_y_before_x( + x: torch.fx.Node, y: torch.fx.Node +) -> Tuple[bool, Set[torch.fx.Node]]: + """ + Check if y can be reordered before x and return the ancestors of y + (inclusive). + """ + y_ancestors = _find_ancestors(y) + if x in y_ancestors: + return False, y_ancestors + + return True, y_ancestors + + +@dataclass +class _2DMatmul: + node: torch.fx.Node + B_node: torch.fx.Node + B_node_ancestors: Set[torch.fx.Node] + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + self.node.replace_all_uses_with(new_node) + + +@dataclass +class _NDMatmul: + nodes: List[torch.fx.Node] + B_node: torch.fx.Node + B_node_ancestors: Set[torch.fx.Node] + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + + ND-matmul is a sequence of reshape -> mm -> reshape in the graph. The + second reshape node is replaced with `new_node`. + + In addition, we ensure that the original mm node ends up with zero + users by replacing it with a reverse reshape of `new_node`. + """ + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target == aten.mm.default + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(mm_node.meta["val"].shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + +def _find_consumer_matmuls(node: torch.fx.Node) -> List[Union[_2DMatmul, _NDMatmul]]: + """ + Find the matmuls that use `node` as the lhs argument. + This function effective normalizes 2D and ND matmuls. 
+ """ + matmuls: List[Union[_2DMatmul, _NDMatmul]] = [] + + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + for mm_node in user.users: + if mm_node.target != aten.mm.default: + continue + + B_node = cast(torch.fx.Node, mm_node.args[1]) + can_schedule, B_node_ancestors = _can_schedule_y_before_x(user, B_node) + if not can_schedule: + continue + + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + matmul_out_shape = torch.Size( + [ + *node.meta["val"].shape[:-1], + B_node.meta["val"].shape[-1], + ] + ) + if reshape_node.meta["val"].shape != matmul_out_shape: + continue + + matmuls.append( + _NDMatmul( + nodes=[user, mm_node, reshape_node], + B_node=B_node, + B_node_ancestors=B_node_ancestors, + ) + ) + # 2D matmuls + elif user.target == aten.mm.default: + B_node = cast(torch.fx.Node, user.args[1]) + can_schedule, B_node_ancestors = _can_schedule_y_before_x(user, B_node) + if not can_schedule: + continue + + matmuls.append( + _2DMatmul( + node=user, + B_node=B_node, + B_node_ancestors=B_node_ancestors, + ), + ) + return matmuls + + +def _find_all_gather_node_from_match(match) -> Tuple[torch.fx.Node, torch.fx.Node]: + """ + Processes match for ZeroDimAllGather and NonZeroDimAllGather. Returns the + all-gather node (all_gather_into_tensor.default) and the all-gather result + node (wait_tensor.default for gather_dim == 0 and aten.cat.default for + gather_dim == 1). This function effectively normalizes zero-dim and + non-zero-dim all_gather_tensor. + """ + # gather_dim == 0 + if len(match.nodes) == 2: + return match.nodes[0], match.nodes[1] + # gather_dim == 1 + ag_node = _filter_nodes_by_target( + match.nodes, + torch.ops._c10d_functional.all_gather_into_tensor.default, + )[0] + ag_res_node = _filter_nodes_by_target( + match.nodes, + aten.cat.default, + )[0] + shard_node = ag_node.args[0] + return ag_node, ag_res_node + + +def fuse_all_gather_matmul_zero_dim(match, shard, group_name): + fuse_all_gather_matmul(match, shard, 0, group_name) + + +def fuse_all_gather_matmul(match, shard, gather_dim, group_name): + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... 
+ + into + + A, Cs = torch.ops.cuda_p2p.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._cuda_p2p import is_cuda_p2p_group + from torch.distributed.distributed_c10d import _resolve_process_group + + if gather_dim >= len(shard.meta["val"].shape) - 1: + # Decomposing the matmul on the K dimension is not supported + return + + if not is_cuda_p2p_group(_resolve_process_group(group_name)): + return + + # Normalize zero-dim and non-zero-dim all_gather_tensor + ag_node, ag_res_node = _find_all_gather_node_from_match(match) + + # Find consumer matmuls for eligible for fusion + matmuls = _find_consumer_matmuls(ag_res_node) + if len(matmuls) == 0: + return + + shard_node = ag_node.args[0] + B_nodes = [matmul.B_node for matmul in matmuls] + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + fused_node = graph.call_function( + torch.ops.cuda_p2p.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + ag_res_node.replace_all_uses_with(new_ag_node) + + # Raise ancestors of B that are topologically ordered between ag_res_node + # and the matmul above fused_node. _find_consumer_matmuls guarantees that + # ag_res_node is not an ancestor of B. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + {x for matmul in matmuls for x in matmul.B_node_ancestors}, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + graph.eliminate_dead_code() + return + + +def fuse_matmul_reduce_scatter_zero_dim(match, rs_input, reduce_op, group_name): + fuse_matmul_reduce_scatter(match, rs_input, reduce_op, 0, group_name) + + +def fuse_matmul_reduce_scatter(match, rs_input, reduce_op, scatter_dim, group_name): + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops.cuda_p2p.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + from torch.distributed._cuda_p2p import is_cuda_p2p_group + from torch.distributed.distributed_c10d import _resolve_process_group + + if not is_cuda_p2p_group(_resolve_process_group(group_name)): + return + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. 
+ if len(rs_input.users) != 1: + return + + # 2D matmul + if rs_input.target == aten.mm.default: + A_node, B_node = rs_input.args[0], rs_input.args[1] + # ND matmul + elif rs_input.target == aten.reshape.default: + mm_node = rs_input.args[0] + if mm_node.target != aten.mm.default or len(mm_node.users) != 1: + return + + A_node, B_node = mm_node.args[0], mm_node.args[1] + if A_node.target != aten.reshape.default: + return + A_node = A_node.args[0] + # Not matmul + else: + return + + rs_res_node = _filter_nodes_by_target(match.nodes, c10d.wait_tensor.default)[0] + if not _can_schedule_y_before_x(rs_res_node, B_node): + return + + graph = rs_res_node.graph + with graph.inserting_before(rs_res_node): + fused_node = graph.call_function( + torch.ops.cuda_p2p.fused_matmul_reduce_scatter.default, + args=(A_node, B_node, reduce_op, scatter_dim, group_name), + ) + rs_res_node.replace_all_uses_with(fused_node) + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + _find_ancestors(B_node), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + graph.eliminate_dead_code() + + +def _register_passes(): + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + c10d = torch.ops._c10d_functional + + # Matches funcol.all_gather_tensor with gather_dim == 0 + ZeroDimAllGather = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + KeywordArg("shard"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + # NOTE: this pattern may need to be updated if funcol.all_gather_tensor changes + NonZeroDimAllGather = CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + KeywordArg("shard"), + Ignored(), + KeywordArg("group_name"), + ), + ), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + KeywordArg("gather_dim"), + _users=MULTIPLE, + ) + + register_graph_pattern( + ZeroDimAllGather, + pass_dict=patterns, + )(fuse_all_gather_matmul_zero_dim) + + register_graph_pattern( + NonZeroDimAllGather, + pass_dict=patterns, + )(fuse_all_gather_matmul) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + ZeroDimReduceScatter = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + KeywordArg("rs_input"), + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + # NOTE: this pattern may need to be updated if funcol.reduce_scatter_tensor + # changes + NonZeroDimReduceScatter = CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("rs_input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + ), + ) + + register_graph_pattern( + ZeroDimReduceScatter, + pass_dict=patterns, + )(fuse_matmul_reduce_scatter_zero_dim) + + register_graph_pattern( + NonZeroDimReduceScatter, + pass_dict=patterns, + )(fuse_matmul_reduce_scatter) + + +_register_passes() diff --git 
a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index 76c641e3e8eb..f2d943cab241 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Dict, Set, Tuple diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 3edb4a397932..97d45ae4f5f2 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import operator from functools import reduce @@ -197,9 +198,15 @@ def _binary_fusion_v1(computation_call, binary_fn): def _binary_fusion_v2(computation_call, binary_fn): return CallFunction(binary_fn, computation_call, KeywordArg("other")) - def _is_single_computation_op(computation_op): + def _is_single_computation_op(computation_op, lowp_dtype=None): def fn(match): computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + if len(computation_nodes) < 1: return False if any(n.args[-3] != "none" for n in computation_nodes): @@ -210,7 +217,7 @@ def fn(match): def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): def fn(match): - matched = _is_single_computation_op(computation_op)(match) + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) computation_node = filter_nodes(match.nodes, computation_op)[0] if lowp_dtype: conversion_dtype_nodes = filter_nodes( @@ -249,7 +256,7 @@ def fn(match, *args, **kwargs): def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): negative_slope = kwargs.get("negative_slope") @@ -291,7 +298,7 @@ def fn(match, *args, **kwargs): def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): @register_lowering_pattern( - pattern, extra_check=_is_single_computation_op(computation_op) + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) ) def fn(match, *args, **kwargs): min_value = kwargs.get("min_value") @@ -782,14 +789,22 @@ def get_val(val): def is_linear_add_bias(match): add_node = match.output_node() linear_node = add_node.args[0] - weight_meta = linear_node.args[1].meta.get("val") + packed_weight_node = linear_node.args[1] + assert packed_weight_node.name == "_reorder_linear_weight" + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.name == "permute_default" + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only fold the bias if it is a constant + return False bias_meta = add_node.args[1].meta.get("val") if weight_meta is None or bias_meta is None: return False return ( linear_node.args[2] is None and bias_meta.dim() == 1 - and bias_meta.size(0) == weight_meta.size(0) + and bias_meta.size(0) == weight_meta.size(1) ) # convert linear+bias to a single linear for applying fusion path.
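For readers skimming the new micro_pipeline_tp.py pass introduced above, the following single-process sketch spells out the decomposed computation that fuse_all_gather_matmul looks for and the fused call it is rewritten into. It is illustrative only: torch.cat stands in for the all-gather across ranks, and world_size, the shard shape and the weight list are invented values rather than anything prescribed by the pass.

import torch

# Hypothetical setup: 4 "ranks", each holding one shard of A along gather_dim=0.
world_size = 4
A_shards = [torch.randn(8, 16) for _ in range(world_size)]
Bs = [torch.randn(16, 32), torch.randn(16, 64)]  # the B_0, B_1 consumed by the matmuls

# Decomposed form matched by the pattern:
#     A = all_gather_tensor(A_shard, gather_dim=0, group)   (emulated here by torch.cat)
#     C_i = torch.matmul(A, B_i)
A = torch.cat(A_shards, dim=0)
Cs = [A @ B for B in Bs]

# The pass rewrites this into a single op that returns both the gathered tensor
# and every matmul output:
#     A, Cs = torch.ops.cuda_p2p.fused_all_gather_matmul(A_shard, Bs, gather_dim, group_name)
# matching the graph.call_function(...) emitted in fuse_all_gather_matmul above.
assert tuple(A.shape) == (world_size * 8, 16)
assert [tuple(C.shape) for C in Cs] == [(32, 32), (32, 64)]

The _find_consumer_matmuls helper is what collects those B operands (both the plain 2D aten.mm uses and the reshape -> mm -> reshape case), and each matched matmul output is then replaced by a getitem on the fused op's second result.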
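A matching sketch for fuse_matmul_reduce_scatter, again single-process and with invented shapes; stacking and summing the per-rank partial products and then chunking merely emulates a "sum" reduce-scatter along scatter_dim=0.

import torch

world_size = 4
B = torch.randn(16, 8)

# Each rank computes its own partial product locally.
A_per_rank = [torch.randn(32, 16) for _ in range(world_size)]
partials = [A @ B for A in A_per_rank]

# Decomposed form matched by the pattern:
#     reduce_scatter_tensor(A @ B, "sum", scatter_dim=0, group)
# i.e. sum the partials across ranks, then hand row-chunk r of the result to rank r.
reduced = torch.stack(partials).sum(dim=0)
per_rank_out = list(reduced.chunk(world_size, dim=0))

# The pass replaces the decomposed matmul + reduce-scatter with
#     torch.ops.cuda_p2p.fused_matmul_reduce_scatter(A, B, reduce_op, scatter_dim, group_name)
# matching the args passed to graph.call_function in fuse_matmul_reduce_scatter above.
assert all(tuple(out.shape) == (32 // world_size, 8) for out in per_rank_out)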
diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py index 44d0564fe3ea..5bad4ed9489c 100644 --- a/torch/_inductor/fx_passes/numeric_utils.py +++ b/torch/_inductor/fx_passes/numeric_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import logging import os diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 626897950746..f7b7977bffc1 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -1,6 +1,8 @@ +# mypy: allow-untyped-defs import functools import itertools import operator +import typing from typing import List, Optional, Union import torch @@ -11,7 +13,14 @@ from torch.utils._mode_utils import no_dispatch from ...utils._triton import has_triton -from ..pattern_matcher import fwd_only, gen_register_replacement, joint_fwd_bwd, Match +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) aten = torch.ops.aten @@ -636,22 +645,22 @@ def _pad_mm_init(): for pattern, replacement, args, workaround, extra_check in [ ( - mm_pattern, - mm_replace, + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), [dim2a(), dim2b()], {}, should_pad_mm, ), ( - bmm_pattern, - bmm_replace, + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), [dim3a(), dim3b()], {}, should_pad_bmm, ), ( - addmm_pattern, - addmm_replace, + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), [dim1a(), dim2a(), dim2b()], rep, should_pad_addmm, diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 585d261787e4..4d1dfe830e01 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging @@ -16,6 +17,7 @@ from torch._utils_internal import upload_graph from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .. 
import config, ir, pattern_matcher from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage @@ -43,6 +45,7 @@ from ..virtualized import V from .ddp_fusion import fuse_ddp_communication from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import patterns as micro_pipeline_tp_patterns from .pre_grad import is_same_dict, save_inductor_dict from .reinplace import reinplace_inplaceable_ops from .split_cat import POST_GRAD_PATTERNS @@ -80,7 +83,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): fake_tensor_updater = FakeTensorUpdater(gm.graph) if config.post_grad_custom_pre_pass is not None: - config.post_grad_custom_pre_pass(gm.graph) + with GraphTransformObserver( + gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_pre_pass(gm.graph) if config.pattern_matcher: lazy_init() @@ -103,6 +109,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): f"{pattern_matcher_pass.pass_name}_post_grad" ] = upload_graph(gm.graph) + if config._micro_pipeline_tp: + micro_pipeline_tp_patterns.apply(gm) + if config._fuse_ddp_communication: fuse_ddp_communication( gm.graph, @@ -111,7 +120,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): ) if config.post_grad_custom_post_pass is not None: - config.post_grad_custom_post_pass(gm.graph) + with GraphTransformObserver( + gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform + ): + config.post_grad_custom_post_pass(gm.graph) stable_topological_sort(gm.graph) @@ -268,11 +280,12 @@ def cuda_and_enabled_mixed_mm_and_not_int8(match): KeywordArg("mat2"), 0xF, ), - CallFunction( - aten.__rshift__.Scalar, - KeywordArg("mat2"), - 4, - ), + # CallFunction( + # aten.__rshift__.Scalar, + # KeywordArg("mat2"), + # 4, + # ), + True, ), 1, ), @@ -347,8 +360,7 @@ def repl(*shape): # only replace the output node, not all nodes match.nodes = [match.output_node()] - with V.fake_mode: - match.replace_by_example(repl, list(shape)) + match.replace_by_example(repl, list(shape)) def shape_of_mm(a, b): @@ -708,8 +720,7 @@ def decomp(*flat_args): args, kwargs = pytree.tree_unflatten(flat_args, spec) return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs) - with V.fake_mode: - match.replace_by_example(decomp, flat_args, run_dce=False) + match.replace_by_example(decomp, flat_args, run_dce=False) graph_pass.apply(graph) for node in graph.find_nodes( @@ -825,8 +836,7 @@ def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): def repl(inp, x1, x2): return x1 @ x2 + inp - with V.fake_mode: - match.replace_by_example(repl, [inp, mat1, mat2]) + match.replace_by_example(repl, [inp, mat1, mat2]) def is_valid_addmm_fusion(match): @@ -869,8 +879,7 @@ def addmm(match, mat1, mat2, *, inp): def repl(inp, mat1, mat2): return aten.addmm(inp, mat1, mat2) - with V.fake_mode: - match.replace_by_example(repl, [inp, mat1, mat2]) + match.replace_by_example(repl, [inp, mat1, mat2]) def check_shape_cuda_and_fused_int_mm_mul_enabled(match): diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index 9af2440eb80b..a93c987fe051 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import logging @@ -11,6 +12,7 @@ matches_module_pattern, replace_node_module, ) +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from 
torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights @@ -207,7 +209,10 @@ def shape_prop(mod) -> None: inductor_before_change = save_inductor_dict( [pattern_matcher_pass.pass_name] ) - pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + # we support running the same pattern multiple times; the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] if not is_same_dict(counters["inductor"], inductor_before_change): optimus_scuba_log[ f"{pattern_matcher_pass.pass_name}_pre_grad" @@ -216,7 +221,10 @@ def shape_prop(mod) -> None: efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] if config.pre_grad_custom_pass is not None: - config.pre_grad_custom_pass(gm.graph) + with GraphTransformObserver( + gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform + ): + config.pre_grad_custom_pass(gm.graph) stable_topological_sort(gm.graph) from .quantization import quant_lift_up @@ -257,16 +265,31 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: # For linear permute fusion, we need to check input info to identify # and perform proper permutation/transpose ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) - gm = linear_permute_fusion(gm) - gm = permute_linear_fusion(gm) - gm = permute_matmul_fusion(gm) + with GraphTransformObserver( + gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform + ): + gm = linear_permute_fusion(gm) + with GraphTransformObserver( + gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_linear_fusion(gm) + with GraphTransformObserver( + gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform + ): + gm = permute_matmul_fusion(gm) # make sure the autograd is disabled. if torch.is_grad_enabled() or not is_cpu: return gm if config.freezing: - gm = remove_identity(gm) - gm = fuse_conv_bn(gm) + with GraphTransformObserver( + gm, "remove_identity", config.trace.log_url_for_graph_xform + ): + gm = remove_identity(gm) + with GraphTransformObserver( + gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform + ): + gm = fuse_conv_bn(gm) return gm diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 4476a9ccd512..5d2a087face4 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import itertools diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 27730ea17905..bae75aae249d 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import operator from collections import defaultdict diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py index 59d4c3891226..c028eb353791 100644 --- a/torch/_inductor/fx_passes/replace_random.py +++ b/torch/_inductor/fx_passes/replace_random.py @@ -1,8 +1,9 @@ +# mypy: allow-untyped-defs import collections import logging import torch - +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import _extract_tensor_metadata from ..
import config, inductor_prims from ..pattern_matcher import ( @@ -24,7 +25,10 @@ def replace_random_passes(gm: torch.fx.GraphModule): return 0 count = patterns.apply(gm) - count += fuse_seed_creation_pass(gm.graph) + with GraphTransformObserver( + gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform + ): + count += fuse_seed_creation_pass(gm.graph) return count diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py index ce678d28833b..55d2216b4e1f 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -42,23 +42,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -123,11 +119,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, 
convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py index a9c38dd92fd0..860ef1c8551f 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -56,18 +56,14 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) @@ -137,7 +133,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, 
div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -147,17 +143,13 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py index e324c7943e21..d8119c33ed93 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -55,16 +55,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) 
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -144,11 +140,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py index 09220864f13e..40834960904a 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -59,11 +59,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = 
CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -71,8 +67,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) @@ -116,13 +111,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -158,11 +152,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -171,8 +161,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = 
CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -220,12 +209,11 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py index ad05c6ed4014..bef5eab2bee9 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -38,22 +38,17 @@ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) convert_element_type_default = 
CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) @@ -78,8 +73,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor) -_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) @@ -96,19 +90,14 @@ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -137,5 +126,4 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value'), _users=0) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py index a25976ad6672..a1e87c009fcc 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -47,7 +47,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -56,16 +56,12 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -148,11 +144,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = 
CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py index e5cc2e1cfb61..289585111a54 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -50,7 +50,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) @@ -60,16 +60,12 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) @@ -161,11 +157,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) 
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py index 8895782436b4..e3c1b5c60235 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -49,7 +49,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -72,8 +68,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -119,13 +114,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, 
Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -147,7 +141,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -157,11 +151,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -169,8 +159,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) @@ -214,8 +203,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, 
clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -256,11 +244,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -269,8 +253,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -320,13 +303,12 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, 
view_default_3, view_default_4) _sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -360,11 +342,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -373,8 +351,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -422,8 +399,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -451,7 +427,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -463,11 
+439,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -476,8 +448,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -524,14 +495,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -553,7 +523,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = 
CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -564,11 +534,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -577,8 +543,7 @@ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -623,8 +588,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py index 225dce51a19a..f741b23c0dd3 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -52,7 +52,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = 
CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -64,11 +64,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -76,8 +72,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) @@ -128,13 +123,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) _sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -175,11 +169,7 @@ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) 
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -188,8 +178,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -244,12 +233,11 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) -view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) _sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py index cf3fe7cff4a2..25c482876a99 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -51,7 +51,7 @@ sub_Tensor = 
CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -62,11 +62,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -126,13 +121,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, 
view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) _sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, @@ -160,7 +154,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -170,11 +164,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -182,8 +172,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -232,8 +221,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) @@ -280,11 +268,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, 
Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -293,8 +277,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -349,13 +332,12 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) _sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, @@ -395,11 +377,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = 
CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -408,8 +386,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -462,8 +439,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py index c2b71b521b2b..3cba2215bc76 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -57,11 +57,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) 
-alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -69,8 +65,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) @@ -114,8 +109,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -141,7 +135,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) @@ -151,11 +145,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = 
CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -163,9 +153,8 @@ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -211,8 +200,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py index cdaa975bcfc0..f573cb373491 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -42,23 +42,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = 
CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) @@ -123,11 +119,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py index 481c704f709e..d7eb251ba52d 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -44,7 +44,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -53,11 +53,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, 
alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -65,8 +61,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) @@ -103,8 +98,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -137,11 +131,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -150,8 +140,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -192,8 +181,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py index d9f8bf2ebc99..773b2be31bde 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -44,7 +44,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) @@ -53,11 +53,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -65,8 +61,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) -clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) -mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, 
div_Tensor, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) @@ -103,8 +98,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -137,11 +131,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -150,8 +140,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) -clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) @@ -192,8 +181,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, 
expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py index 64f99e2ac21e..fe481c8293be 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -43,23 +43,19 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -126,11 +122,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) 
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py index 9836142aade5..7de8b8229ea8 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -45,7 +45,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) @@ -54,11 +54,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) @@ -66,8 +62,7 @@ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -105,8 +100,7 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default = CallFunction(aten.clone.default, div_Tensor_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) @@ -140,11 +134,7 @@ view_default_4 = CallFunction(aten.view.default, expand_default_3, 
Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2) -convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) @@ -153,8 +143,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) -clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) @@ -196,8 +185,7 @@ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) -clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) -expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py index 87c233a2ae18..ff198232b5e6 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) 
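The recurring change in these serialized SFDP patterns is the removal of the `aten.alias`/`aten.clone` chains together with the bump of `_users=2` to `_users=3` on the softmax output (`div_Tensor_1`): once `neg` and the backward `mul` consume the softmax result directly instead of through an alias chain, that node has three direct consumers. The `_users` count mirrors how torch.fx tracks consumers in `node.users`; a minimal standalone sketch of that bookkeeping (not part of the patch, just for intuition):

import torch
import torch.fx as fx

def f(x):
    s = torch.softmax(x, dim=-1)
    # Two ops read `s` directly, so the softmax node reports two users.
    return s * 2.0 + (-s)

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.name, "->", len(node.users), "users")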
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -118,14 +113,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -149,7 +143,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = 
CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -161,11 +155,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -174,8 +164,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -220,13 +209,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_7_half_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py index eb6ffee4614c..8c4b27c8a6fb 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -46,7 +46,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -56,18 +56,14 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) @@ -137,7 +133,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) @@ -147,17 +143,13 @@ view_default_4 = 
CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) -mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py index f2456fbef495..78380c1bb341 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -48,7 +48,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -60,11 +60,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -74,8 +70,7 @@ convert_element_type_default_3 = 
CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) @@ -118,14 +113,13 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) @@ -149,7 +143,7 @@ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) -div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) @@ -161,11 +155,7 @@ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) -alias_default = CallFunction(aten.alias.default, div_Tensor_1) -alias_default_1 = CallFunction(aten.alias.default, alias_default) -alias_default_2 = CallFunction(aten.alias.default, alias_default_1) -alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) -neg_default = CallFunction(aten.neg.default, alias_default_3) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) 
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) @@ -174,8 +164,7 @@ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) -clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) -mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) @@ -220,13 +209,12 @@ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) -clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) -clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) -view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) _sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index ad6adf748dd2..8a2c571ee612 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import operator @@ -54,7 +55,9 @@ ] post_grad_pass_names = [ + "normalization_aten_pass", "decompose_mm_pass", + "unbind_stack_aten_pass", ] for pass_name in pre_grad_pass_names: @@ -78,7 +81,7 @@ ) -def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass: +def construct_pattern_matcher_pass(pass_name: str): """ Return the specific pattern_matcher_pass given the pass name. 
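The two new post-grad pass names added above ("normalization_aten_pass" and "unbind_stack_aten_pass") back the aten-level rewrites registered later in this file: the first canonicalizes `aten.cat` calls to a non-negative `dim=` keyword, and the second removes a cat-of-unsqueezed-selects chain that merely rebuilds its input. A small eager-mode sketch (illustrative only, not from the patch) of the two identities those rewrites rely on:

import torch

# 1) Dim normalization: a negative cat dim addresses the same axis as dim + ndim.
xs = [torch.randn(2, 3), torch.randn(2, 5)]
assert torch.equal(torch.cat(xs, dim=-1), torch.cat(xs, dim=1))

# 2) Unbind/stack round trip: concatenating every row, re-unsqueezed along the
#    same dim, just reproduces the original tensor, so the whole chain can be
#    replaced by its input.
x = torch.randn(5, 4)
rebuilt = torch.cat([x.select(0, i).unsqueeze(0) for i in range(x.size(0))], dim=0)
assert torch.equal(rebuilt, x)
print("both identities hold")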
""" @@ -1609,3 +1612,112 @@ def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int): split_sections = new_split_sections counters["inductor"]["merge_stack_tahn_unbind_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.info("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.warning("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if 
len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index 5ccff50c1d45..8f3ed2e9177c 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from collections import defaultdict from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index b7e8a1c48b74..f2bdf22e2d96 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import logging import operator @@ -296,7 +297,6 @@ def __init__( gm: torch.fx.GraphModule, example_inputs: Optional[List[torch.Tensor]] = None, shape_env=None, - num_static_inputs=None, graph_id=None, cpp_wrapper=False, aot_mode=False, @@ -311,7 +311,6 @@ def __init__( name=None, ): super().__init__(gm) - self.example_inputs = example_inputs self.layout_opt = ( layout_opt @@ -374,7 +373,6 @@ def __init__( Callable[[List[ir.ExternKernelNode]], Any] ] = extern_node_serializer self.current_node: torch.fx.Node = None # type: ignore[assignment] - self.num_static_inputs = num_static_inputs self.lists: Dict[str, List[str]] = {} self.mutated_inputs: Set[str] = set() self.mutated_input_idxs: List[int] = [] @@ -802,46 +800,46 @@ def get_original_value_of_constant(self, name: str): else self.constants[name] ) - def add_tensor_constant(self, data, name=None): - def allocate(name): - if not config.aot_inductor.use_runtime_constant_folding: - for constant_name, value in self.constants.items(): - if ( - not data.is_mkldnn - and data.size() == value.size() - and data.stride() == value.stride() - and data.dtype == value.dtype - and data.device == value.device - and data.untyped_storage().data_ptr() - == value.untyped_storage().data_ptr() - and data.storage_offset() == value.storage_offset() - ): - return constant_name - - if name is None: - name = f"constant{len(self.constants)}" - if name[0].isdigit(): - name = f"constant_{name}" - name = self.qualify_name(name) - # We may generate a var name for each constant in the codegen. - # Let's only keep sane characters. 
- prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) - name = prefix - cnt = 0 - while name in self.constants: - name = f"{prefix}_{cnt}" - cnt += 1 - self.constants[name] = data - self.constant_reprs[name] = ( - f"{data.device!r} {data.dtype!r} " - f"{tuple(data.size())!r} {tuple(data.stride())!r} " - f"{hash(data):x}" - ) - return name - - new_name = allocate(name) - self.allocated_constant_name[new_name] = name + def allocate_non_dup_const_name(self, name, data): + orig_name = name + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() + == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + self.allocated_constant_name[name] = orig_name + return name + def add_tensor_constant(self, data, name=None): + new_name = self.allocate_non_dup_const_name(name, data) return TensorBox.create( ir.ConstantBuffer( new_name, @@ -857,10 +855,13 @@ def constant_name(self, name: str, device_override: Optional[torch.device]): """ if self.constants[name].device == device_override or device_override is None: return name - alt_name = f"{name}_{device_override.type}{device_override.index or 0}" - if alt_name not in self.constants: - self.constants[alt_name] = self.constants[name].to(device_override) - return alt_name + with torch.utils._python_dispatch._disable_current_modes(): + # caller might have set fake tensor mode which will create a fake tensor + # when calling .to, so unset modes here + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) def placeholder(self, target: str, args, kwargs): example = super().placeholder(target, args, kwargs) @@ -929,22 +930,7 @@ def get_custom_op_layout_constraints(target, args, kwargs): # which run through implicit fallback must constrain their # arguments' fx strides layout_constraint = None - - def needs_fixed_stride_order(target): - if ( - torch._C.Tag.needs_fixed_stride_order in target.tags - and torch._C.Tag.does_not_need_fixed_stride_order in target.tags - ): - # If both tags were specified, pessimistically assume that we do need it. 
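The dedup check factored into `allocate_non_dup_const_name` above reuses an existing constant name whenever the new tensor matches a registered one in size, stride, dtype, device, storage pointer, and storage offset. A standalone sketch (not part of the patch) of what that comparison distinguishes:

import torch

def same_constant(a, b):
    # Mirrors the fields compared above (minus the is_mkldnn guard).
    return (
        a.size() == b.size()
        and a.stride() == b.stride()
        and a.dtype == b.dtype
        and a.device == b.device
        and a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
        and a.storage_offset() == b.storage_offset()
    )

w = torch.randn(4, 4)
print(same_constant(w, w.view(4, 4)))  # True: same storage and layout -> reuse the name
print(same_constant(w, w.clone()))     # False: equal values but separate storage
print(same_constant(w, w[1:]))         # False: shares storage but differs in size/offset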
- return True - if torch._library.utils.is_builtin(target): - return torch._C.Tag.needs_fixed_stride_order in target.tags - else: - return ( - torch._C.Tag.does_not_need_fixed_stride_order not in target.tags - ) - - if needs_fixed_stride_order(target): + if torch._C.Tag.needs_fixed_stride_order in target.tags: # We have to set the current args because call_function will immediately # evaluate this lowering after creating the fallback, without evaluating # the layout constraint @@ -1211,8 +1197,11 @@ def debug(msg): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance(n.meta["val"], torch.SymInt): + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): result = n.meta["val"].node.expr else: result = super().run_node(n) @@ -1681,15 +1670,32 @@ def count_bytes(self): node_runtimes.append((node, node.get_estimated_runtime())) return total_bytes, node_counts, node_runtimes - @dynamo_timed(phase_name="code_gen") + @dynamo_timed(phase_name="code_gen", fwd_only=False) def compile_to_module(self): from .codecache import PyCodeCache code, linemap = ( self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) - linemap = [(line_no, node.stack_trace) for line_no, node in linemap] - key, path = PyCodeCache.write(code) + + output_code_log.debug("Output code: \n%s", code) + try: + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] + key, path = PyCodeCache.write(code) + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! + payload_fn=lambda: code, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: {"filename": path}, + payload_fn=lambda: code, + ) + mod = PyCodeCache.load_by_key_path( key, path, @@ -1706,12 +1712,6 @@ def compile_to_module(self): log_module_code(mod.__file__) log.debug("Output code written to: %s", mod.__file__) - output_code_log.debug("Output code: \n%s", code) - trace_structured( - "inductor_output_code", - lambda: {"filename": mod.__file__}, - payload_fn=lambda: code, - ) output_code_log.info("Output code written to: %s", mod.__file__) if config.benchmark_kernel: print(f"Compiled module path: {mod.__file__}", file=sys.stderr) diff --git a/torch/_inductor/hooks.py b/torch/_inductor/hooks.py index 2b558f4350a7..bf4a8bb090aa 100644 --- a/torch/_inductor/hooks.py +++ b/torch/_inductor/hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Callable, List, TYPE_CHECKING diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 77b73ffd6842..2ec43bce36f0 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file implements the IndexPropagation ops handler, which wraps an underlying handler to add a limited form of constant propagation, as well as propagation of sympy expressions downstream of ops.index_expr calls. 
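The reshuffled logging in `compile_to_module` above is a try/except/else: `trace_structured` now records the generated source even when writing it to `PyCodeCache` raises, and records the filename only on success, without logging the code twice. A generic sketch of that control flow (the `write` and `log` callables here are hypothetical stand-ins, not the real Inductor APIs):

def write_with_logging(code, write, log):
    try:
        key, path = write(code)           # may raise (e.g. cache dir not writable)
    except Exception:
        log(filename=None, payload=code)  # still capture the code itself
        raise
    else:
        log(filename=path, payload=code)  # success: also record where it landed
        return key, path

events = []
write_ok = lambda code: ("key0", "/tmp/out.py")
log = lambda filename, payload: events.append((filename, len(payload)))
print(write_with_logging("print('hi')", write_ok, log))  # ('key0', '/tmp/out.py')
print(events)                                            # [('/tmp/out.py', 11)]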
diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 0a00650b1c38..c50686d9ee61 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import logging diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 59b563c4e660..da0c1b120676 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import contextlib import dataclasses @@ -44,7 +45,6 @@ is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, - make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -236,7 +236,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2766,6 +2766,7 @@ class FlexibleLayout(Layout): allow_indexing = False + # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -3415,17 +3416,9 @@ def simplify_and_reorder( *body.writes_name2expr.values(), ] - # the reordering_reindex in reads' simplify_reorder_and_tile - reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) - for i, reads_buf in enumerate(reads_bufs): - if isinstance(reads_buf, ComputedBuffer) and hasattr( - reads_buf, "iter_reordering_reindex" - ): - reordering_reindex[i] = reads_buf.iter_reordering_reindex # type: ignore[has-type] - - def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): + def simplify_and_reorder(x_vars, support_vars, sizes): sizes, reindex0, reindex1 = self._apply_loop_reordering( - x_vars, support_vars, sizes, memory_addrs, reordering_reindex + x_vars, support_vars, sizes, memory_addrs ) # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] x_vars = reindex0(x_vars) @@ -3442,16 +3435,15 @@ def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): return sizes, reindex, reindex1 support_vars = index_vars + reduce_vars - iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( - index_vars, support_vars, index_size, reordering_reindex + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, ) reduce_ranges, reduce_reindex, _ = simplify_and_reorder( reduce_vars, support_vars, reduce_size ) - # remember the reordering if not have loop collapse. 
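The switch above from `make_contiguous_strides_for` to `FlexibleLayout.contiguous_strides` comes with a new WARNING that the latter does not handle zero-size tensors correctly. A standalone sketch of a plain running-product stride formula (my reconstruction for illustration, not the actual Inductor code) showing the normal case and where such a formula diverges from eager PyTorch once a size-0 dim is involved:

import torch

def running_product_strides(sizes):
    # stride[i] = sizes[i+1] * sizes[i+2] * ...
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= s
    return list(reversed(strides))

print(running_product_strides([2, 3, 4]))  # [12, 4, 1]
print(torch.empty(2, 3, 4).stride())       # (12, 4, 1) -- matches

# Eager PyTorch treats size-0 dims as size 1 when accumulating strides,
# so a naive product formula disagrees on zero-size tensors.
print(running_product_strides([2, 0, 3]))  # [0, 3, 1]
print(torch.empty(2, 0, 3).stride())       # (3, 3, 1)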
- if len(iter_ranges) == len(index_vars): - self.iter_reordering_reindex = iter_reordering_reindex # retrace the loop body with simplification and reordering applied (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( iter_ranges, reduce_ranges, prefix="z" @@ -3467,7 +3459,6 @@ def _apply_loop_reordering( support_vars, sizes, memory_addrs, - reordering_reindex=None, priority_idx=None, ): """ @@ -3486,14 +3477,6 @@ def _apply_loop_reordering( assert len(strides) == len(memory_addrs) and len(strides[0]) == len( index_vars ) - # consider both layout(strides) and reordering(reordering_reindex) - if reordering_reindex is not None: - for i in range(len(memory_addrs)): - try: - strides[i] = reordering_reindex[i](strides[i]) - # if len(order) != len(strides), do not reorder - except AssertionError: - pass order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) except Exception: if config.debug: @@ -4010,13 +3993,6 @@ def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: def collect_arg_kwarg_properties(self): # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen - if ( - isinstance(self.op_overload, torch._ops.OpOverload) - and not self.ordered_kwargs_for_cpp_kernel - ): - self.ordered_kwargs_for_cpp_kernel = [ - x.name for x in self.op_overload._schema.arguments if x.kwarg_only - ] self.arg_properties = ( [ { @@ -4030,15 +4006,23 @@ def collect_arg_kwarg_properties(self): if isinstance(self.op_overload, torch._ops.OpOverload) else [{} for i in range(len(self.inputs))] ) - self.kwarg_properties = ( + self.allarg_properties = ( { x.name: {"type": x.real_type, "default_value": x.default_value} for x in self.op_overload._schema.arguments - if x.kwarg_only } if isinstance(self.op_overload, torch._ops.OpOverload) else {} ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicilty passed in. 
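`allarg_properties` above now records the type and default value of every schema argument of the `OpOverload`, not only the kwarg-only ones, while `ordered_kwargs_for_cpp_kernel` still lists just the kwarg-only names. A small sketch of the schema data both are built from, using `aten.add.Tensor` as an arbitrary example:

import torch

op = torch.ops.aten.add.Tensor
for arg in op._schema.arguments:
    print(arg.name, arg.real_type, "kwarg_only:", arg.kwarg_only, "default:", arg.default_value)
# self and other are positional; alpha is kwarg-only with default 1, so it is
# the only entry that would land in ordered_kwargs_for_cpp_kernel.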
+ if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): # Previously, we want to maintain forward-compatibility by skipping @@ -4400,7 +4384,21 @@ def apply_constraint(self): pass def codegen_const_args(self): - return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + if V.graph.cpp_wrapper: + result = [] + for i, x in enumerate(self.constant_args): + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[i].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append( + V.graph.wrapper_code.val_to_arg_str(x, type_) # type: ignore[arg-type] + ) + return result + else: + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) def codegen_args(self): args = [] @@ -4413,10 +4411,10 @@ def codegen_args(self): if V.graph.cpp_wrapper: assert self.arg_properties and i < len( self.arg_properties - ), "Invalid arg_properties accessing" + ), "Invalid access to ExternKernel.arg_properties" type_ = self.arg_properties[i].get("type") args.append( - V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] x, type_ ) ) @@ -4428,10 +4426,10 @@ def codegen_args(self): def get_kwargs_value(self, arg_name): if arg_name in self.kwargs: return self.kwargs.get(arg_name) - if self.kwarg_properties and self.kwarg_properties.get(arg_name): - return self.kwarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + if self.allarg_properties and self.allarg_properties.get(arg_name): + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] else: - raise AssertionError(f"{arg_name} not in self.kwarg_properties") + raise AssertionError(f"{arg_name} not in self.allarg_properties") def codegen_kwargs(self, skip_out=False): if V.graph.cpp_wrapper: @@ -4446,12 +4444,12 @@ def codegen_kwargs(self, skip_out=False): kwargs.append(v) else: type_ = ( - self.kwarg_properties.get(arg_name).get("type") # type: ignore[union-attr] - if self.kwarg_properties and arg_name in self.kwarg_properties + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties else None ) kwargs.append( - V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + V.graph.wrapper_code.val_to_arg_str( # type: ignore[arg-type] v, type_ ) ) @@ -4777,15 +4775,10 @@ def get_mutation_names(self): def __init__(self, layout, mutated_node, node_doing_mutating): # NB: Do not directly construct this - use `mark_node_as_mutating` - super().__init__(None, layout, [mutated_node], ()) + super().__init__(None, layout, [mutated_node, node_doing_mutating], ()) self.node_doing_mutating = node_doing_mutating self.name = V.graph.register_buffer(self) - def get_read_writes(self): - read_writes = super().get_read_writes() - read_writes.reads.add(dependencies.WeakDep(self.node_doing_mutating.get_name())) - return read_writes - def should_allocate(self): return False @@ -5434,7 +5427,7 @@ def __repr__(self): if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): args = self.fill_non_provided_args(args, kwargs) args = [ - V.graph.wrapper_code.val_to_cpp_arg_str(x, param.real_type) + V.graph.wrapper_code.val_to_arg_str(x, 
param.real_type) for param, x in zip(self.op_overload._schema.arguments, args) ] else: @@ -5488,6 +5481,9 @@ def export_extern_kernel_node(self): ordered_kwargs = [ kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel ] + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] @@ -5915,7 +5911,7 @@ def _original_deconv_weight_size( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5963,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6279,7 @@ def create(cls, x, packed_w, orig_w, B, batch_size): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6597,13 @@ def create( def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return make_contiguous_strides_for(output_shape) + return FlexibleLayout.contiguous_strides(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - make_contiguous_strides_for(hy_shape), - make_contiguous_strides_for(cy_shape), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), ] output_ir = [ MultiOutput( @@ -7546,7 +7542,7 @@ def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): """ The heuristic for realizing reused result of heavy ops on cpu """ - heavy_ops = ["exp"] # a list of heavy ops + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops fn_str = loops.inner_fn_str() return any((op + "(") in fn_str for op in heavy_ops) diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index a8650cd32c3f..7d1fbc0b35e8 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import torch diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 205919b48723..f3b4cd8ac430 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 5a1f45e767a7..932bcd50b920 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,12 +1,11 @@ +# mypy: allow-untyped-defs """ Triton Implementation of the flex_attention Kernel""" import logging -import math 
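`should_realize_on_cpu` above now treats `sigmoid` as a heavy op alongside `exp`; the check is a plain substring test for `op + "("` against the printed inner function. A tiny sketch of that heuristic (the example strings are made up; only the predicate mirrors the code above):

heavy_ops = ["exp", "sigmoid"]

def looks_heavy(fn_str: str) -> bool:
    return any((op + "(") in fn_str for op in heavy_ops)

print(looks_heavy("ops.sigmoid(ops.load('buf0', i0))"))  # True
print(looks_heavy("ops.add(a, ops.mul(b, c))"))          # False
print(looks_heavy("ops.expand(a)"))                      # False: the "(" suffix keeps exp from matching inside expand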
from enum import auto, Enum from typing import Any, List, Tuple import torch -from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -189,7 +188,8 @@ def build_subgraph_buffer( Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty @@ -197,26 +197,27 @@ def build_subgraph_buffer( start_m = tl.program_id(0) off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh + q_offset = off_hz * stride_qh + kv_offset = off_hz * stride_kh Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + base=Q + q_offset, + shape=(Q_LEN, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), + base=K + kv_offset, + shape=(BLOCK_DMODEL, KV_LEN), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), + base=V + kv_offset, + shape=(KV_LEN, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), @@ -236,7 +237,7 @@ def build_subgraph_buffer( q = (q * qk_scale).to(MATMUL_PRECISION) # loop over k, v and update accumulator lo = 0 - hi = N_CTX + hi = KV_LEN for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- load k, v -- @@ -299,7 +300,7 @@ def build_subgraph_buffer( # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: - l_ptrs = LSE + off_hz * N_CTX + offs_m + l_ptrs = LSE + off_hz * Q_LEN + offs_m lse = m_i + tl.math.log2(l_i) tl.store(l_ptrs, lse) """, @@ -388,7 +389,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - make_contiguous_strides_for(query.get_size()), + FlexibleLayout.contiguous_strides(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -426,6 +427,7 @@ def flex_attention(*args, **kwargs): ], num_stages=num_stages, num_warps=num_warps, + call_sizes=query.get_size(), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=query.get_size()[-1], @@ -446,13 +448,22 @@ def flex_attention(*args, **kwargs): # ---------------------------- Backward HOP Implementation ---------------------------- -def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): +def flex_attention_backward_grid( + batch_size, num_heads, num_queries, d_model, num_key_value, meta +): """How is this kernel parallelized? Currently this is only parallelizing over batch * num_heads, but we can, and want to parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require atomic updates to some grad values or to have a two pass kernel design. 
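The forward template above replaces the single N_CTX with separate Q_LEN and KV_LEN (and separate q/kv base offsets), since the query and key/value sequence lengths need not match. An eager-mode shape sketch with illustrative sizes only:

import torch

B, H, Q_LEN, KV_LEN, D = 2, 4, 128, 256, 64
q = torch.randn(B, H, Q_LEN, D)
k = torch.randn(B, H, KV_LEN, D)
v = torch.randn(B, H, KV_LEN, D)

scores = q @ k.transpose(-2, -1)          # (B, H, Q_LEN, KV_LEN): the qk dot product
out = torch.softmax(scores, dim=-1) @ v   # (B, H, Q_LEN, D): weighted sum over KV_LEN
print(out.shape)                          # torch.Size([2, 4, 128, 64])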
""" - return (batch_size * num_heads, 1, 1) + import triton + + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * num_heads, + ) flex_attention_backward_template = TritonTemplate( @@ -468,97 +479,96 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, # DK: Derivative of Key, is the written to via the store_output call due to some limitations with # inductor codegen # M: Number of queries, N: Number of keys/values, D: Model dimension - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim # (Modifiable) Config options: - # BLOCK_M - # BLOCK_N + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad # Define Q Strides stride_qz = {{stride("Q", 0)}} stride_qh = {{stride("Q", 1)}} stride_qm = {{stride("Q", 2)}} - stride_qk = {{stride("Q", 3)}} + stride_qd = {{stride("Q", 3)}} # Define K Strides stride_kz = {{stride("K", 0)}} stride_kh = {{stride("K", 1)}} - stride_kn = {{stride("K", 2)}} - stride_kk = {{stride("K", 3)}} + stride_km = {{stride("K", 2)}} + stride_kd = {{stride("K", 3)}} # Define V Strides stride_vz = {{stride("V", 0)}} stride_vh = {{stride("V", 1)}} - stride_vn = {{stride("V", 2)}} - stride_vk = {{stride("V", 3)}} + stride_vm = {{stride("V", 2)}} + stride_vd = {{stride("V", 3)}} Z = {{size("Q", 0)}} H = {{size("Q", 1)}} - N_CTX = {{size("Q", 2)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} - qk_scale = 1.0 MATMUL_PRECISION = Q.dtype.element_ty - off_hz = tl.program_id(0) + pid = tl.program_id(0) + NUM_KV_BLOCKS = KV_LEN // BLOCK_N1 + + off_hz = tl.program_id(2) off_z = off_hz // H # batch idx off_h = off_hz % H # head idx + off_chz = (off_hz * Q_LEN).to(tl.int64) + q_adj = (stride_qh * (off_hz % H) + stride_qz * (off_hz // H)).to(tl.int64) + k_adj = (stride_kh * (off_hz % H) + stride_kz * (off_hz // H)).to(tl.int64) + v_adj = (stride_vh * (off_hz % H) + stride_vz * (off_hz // H)).to(tl.int64) + # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh - - # Asserting contiguous for now... 
- DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_vz + off_h * stride_vh - - # TODO I think that this should be N_CTX/BLOCK_N blocks - for start_n in range(0, NUM_Q_BLOCKS): - # We are not doing the causal optimization yet allowing us to start further down the - # kv column - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_DMODEL) - - # initialize pointers to value-like data - q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) - do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) - - # pointer to row-wise quantities in value-like data - D_ptrs = DELTA + off_hz * N_CTX - l_ptrs = LSE + off_hz * N_CTX - - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - - # Key and Value stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - - for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - - if SCORE_MOD_IS_LINEAR: - qk_scale *= 1.44269504 - q = (q * qk_scale).to(MATMUL_PRECISION) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) - pre_mod_scores = qk + Q += q_adj + K += k_adj + V += v_adj + DO += q_adj + DQ += q_adj + DV += v_adj + LSE += off_chz + DELTA += off_chz + + offs_k = tl.arange(0, BLOCK_DMODEL) + + if pid >= NUM_KV_BLOCKS: + # THIS BLOCK DOES DQ + off_pid = pid - NUM_KV_BLOCKS + start_m2 = off_pid * BLOCK_M2 + + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) + + lse = tl.load(LSE + offs_m2) + lse = lse[:, None] + + start_n2 = 0 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + offs_n2 = start_n2 + tl.arange(0, BLOCK_N2) + kT_ptrs = K + offs_n2[None, :] * stride_km + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vm + offs_k[:, None] * stride_vd + Di = tl.load(DELTA + offs_m2) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + curr_n = start_n2 + num_steps = KV_LEN // BLOCK_N2 + for blk_idx in range(num_steps): + offs_n2= curr_n + tl.arange(0, BLOCK_N2) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = offs_m_curr[:, None] - n = offs_n[None, :] + pre_mod_scores = qk + m = offs_m2[:, None] + n = offs_n2[None, :] {{ modification( subgraph_number=0, output_name="post_mod_scores", @@ -569,25 +579,13 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, n="n", out="qk" ) | indent_except_first(3) }} - # TODO: In the case that score_mod is linear, this can be LICMed + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if not SCORE_MOD_IS_LINEAR: post_mod_scores *= 1.44269504 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - l_i = tl.load(l_ptrs + offs_m_curr) - p = tl.math.exp2(post_mod_scores - l_i[:, None]) - - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) - - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] - - # compute ds = p * (dp - delta[:, None]) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, tl.trans(v)) - ds = p * dp - + p = tl.math.exp2(post_mod_scores - lse).to(MATMUL_PRECISION) + # Compute dP and dS. + dp = tl.dot(do, vT) + ds = p * (dp - Di[:, None]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( subgraph_number=1, @@ -601,32 +599,101 @@ def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, ) | indent_except_first(3) }} ds = grad_scores # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) - # compute dq - dq = tl.load(dq_ptrs) - dq += tl.dot(ds.to(MATMUL_PRECISION), k) - - # Store grad_query - tl.store(dq_ptrs, dq) - - # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - - # write-back - index_n = offs_n[:, None] - index_k = offs_k[None, :] + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += BLOCK_N2 + kT_ptrs += BLOCK_N2 * stride_km + vT_ptrs += BLOCK_N2 * stride_km + # Write back dQ. + dq_ptrs = DQ + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd + tl.store(dq_ptrs, dq) + else: + # THIS BLOCK DOES DK & DV + start_n1 = pid * BLOCK_N1 + start_m1 = 0 + + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n1[:, None] * stride_km + offs_k[None, :] * stride_kd) + v = tl.load(V + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd) + + offs_m1 = start_m1 + tl.arange(0, BLOCK_M1) + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + curr_m = start_m1 + num_steps = Q_LEN // BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load LSE before computing qk to reduce pipeline stall. 
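# For reference, a dense NumPy sketch of the identities both branches implement,
# written in the natural-log domain and without score_mod; the kernel itself
# works in base 2 (the 1.44269504 factor) and tiles these matmuls. Here lse is
# the row-wise logsumexp of the scores, and delta is assumed to be the
# precomputed row-wise sum of dO * O that the kernel loads from DELTA.
import numpy as np

def attention_backward_ref(q, k, v, do, lse, delta):
    s = q @ k.T                         # scores                 [M, N]
    p = np.exp(s - lse[:, None])        # softmax probabilities  [M, N]
    dv = p.T @ do                       # dV = P^T dO            [N, D]
    dp = do @ v.T                       # dP = dO V^T            [M, N]
    ds = p * (dp - delta[:, None])      # dS = P * (dP - delta)  [M, N]
    dq = ds @ k                         # dQ = dS K              [M, D]
    dk = ds.T @ q                       # dK = dS^T Q            [N, D]
    return dq, dk, dv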
+ offs_m1 = curr_m + tl.arange(0, BLOCK_M1) + lse = tl.load(LSE + offs_m1) + qkT = tl.dot(k, qT) + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_h", + m="m", + n="n", + out="qkT" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not SCORE_MOD_IS_LINEAR: + post_mod_scores *= 1.44269504 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do) + Di = tl.load(DELTA + offs_m1) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m1[None, :] + n = offs_n1[:, None] + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(3) }} + dsT = grad_scores + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) + # Increment pointers. + curr_m += BLOCK_M1 + qT_ptrs += BLOCK_M1 * stride_qm + do_ptrs += BLOCK_M1 * stride_qm - # Store grad_key and grad_value - dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + dv_ptrs = DV + offs_n1[:, None] * stride_vm + offs_k[None, :] * stride_vd tl.store(dv_ptrs, dv) + # Write back dK. + index_n = offs_n1[:, None] + index_k = offs_k[None, :] # TODO generalize and add proper mask support mask = (index_n != -1) & (index_k != -1) {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - """, ) @@ -678,7 +745,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - make_contiguous_strides_for(key.get_size()), + FlexibleLayout.contiguous_strides(key.get_size()), ) # Create delta which will is needed for the bwd's kernel @@ -720,12 +787,14 @@ def flex_attention_backward(*args, **kwargs): layout=layout_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], mutated_inputs=[grad_query, grad_value], + call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, + BLOCK_M1=BLOCK_M, + BLOCK_N1=BLOCK_N, + BLOCK_M2=BLOCK_N, + BLOCK_N2=BLOCK_M, BLOCK_DMODEL=query.get_size()[-1], - NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, ) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index a90fdbfa33d9..de811fd41c0f 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging from typing import Any, Dict, List, Optional @@ -26,6 +27,7 @@ from .mm_common import ( addmm_epilogue, int8_mm_configs, + mixed_mm_configs, mm_args, mm_configs, mm_grid, @@ -407,7 +409,8 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) skip_triton = ( - mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() + mat1.layout.dtype != torch.float32 + and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed()) ) or _is_sm7x_or_older_gpu(layout.device.index) if 
inductor_config.force_mixed_mm: @@ -415,7 +418,7 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype): if not skip_triton: b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) - for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2), diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 97741cc0f8eb..9ffaba040e7f 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging @@ -31,10 +32,10 @@ def filtered_configs( ): """Heuristic to shrink configs when they are bigger than the input size""" - # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 - # it's safer to use at least [32, 32] block size for int8/uint8 - # tensors - min_block_size = 32 if has_int8_tensor else 16 + min_block_size = 16 + # block_k=16 seems to be causing issues + # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 + min_block_size_k = 32 if has_int8_tensor else 16 m = max( next_power_of_2( V.graph.sizevars.size_hint( @@ -57,14 +58,14 @@ def filtered_configs( k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] ) ), - min_block_size, + min_block_size_k, ) used = set() for block_m, block_n, block_k, num_stages, num_warps in configs: # shrink configs for small sizes block_m = max(min(block_m, m), min_block_size) block_n = max(min(block_n, n), min_block_size) - block_k = max(min(block_k, k), min_block_size) + block_k = max(min(block_k, k), min_block_size_k) # each warp computes 16x16 tile = 256 num_warps = min(num_warps, block_m * block_n // 256) if torch.version.hip: @@ -166,6 +167,18 @@ def filtered_configs( {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, ] +# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). 
+mixed_mm_kernel_configs_small_m = [ + {"config": (16, 128, 256, 3, 4), "cond": True}, + {"config": (16, 128, 256, 5, 8), "cond": True}, +] + +mixed_mm_kernel_configs = ( + mm_kernel_configs + mixed_mm_kernel_configs_small_m + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else mm_kernel_configs +) + # Create filtered list of configs based on cond evaluation @@ -179,6 +192,11 @@ def filtered_configs( for config in int8_mm_kernel_configs if config["cond"] ) +mixed_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mixed_mm_kernel_configs + if config["cond"] +) # On ROCm convert num_stages to 0 to enable software pipelining if torch.version.hip: @@ -190,6 +208,10 @@ def filtered_configs( (config[0], config[1], config[2], 0, config[4]) for config in mm_platform_configs ) + mixed_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mixed_mm_platform_configs + ) mm_configs = functools.partial( filtered_configs, @@ -201,6 +223,11 @@ def filtered_configs( configs=int8_platform_configs, ) +mixed_mm_configs = functools.partial( + filtered_configs, + configs=mixed_mm_platform_configs, +) + def mm_grid(m, n, meta): """ diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index 931aa592556b..f2f810d1fe02 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py index c0053b15c16a..c483dbff2b85 100644 --- a/torch/_inductor/kernel/unpack_mixed_mm.py +++ b/torch/_inductor/kernel/unpack_mixed_mm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import List, TYPE_CHECKING diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index fcf77cae6e3a..3b59620c7b89 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging @@ -34,7 +35,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . 
import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -170,7 +171,7 @@ def is_boolean_type(x): def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): def construct_input(inp): - if isinstance(inp, (Number, sympy.Expr)): + if isinstance(inp, (Number, sympy.Basic)): return inp else: assert hasattr(inp, "get_dtype") @@ -209,7 +210,7 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): promoting_args = [ a for a in args - if isinstance(a, (Number, sympy.Expr)) + if isinstance(a, (Number, sympy.Basic)) or getattr(a, "dtype", None) is not None ] dtype = get_promoted_dtype( @@ -361,15 +362,15 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No if override_return_dtype is None and type_promotion_kind is None: type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs): + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): return inputs - if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs): + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): dtype = override_return_dtype or get_promoted_dtype( *inputs, type_promotion_kind=type_promotion_kind ) def const_func(x): - if isinstance(x, sympy.Expr): + if isinstance(x, sympy.Basic): return ir.IndexingConstant(x, dtype, decode_device(None)) else: return ir.Constant(x, dtype, decode_device(None)) @@ -384,7 +385,7 @@ def const_func(x): ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) ) ) - elif isinstance(x, sympy.Expr): + elif isinstance(x, sympy.Basic): out.append( ExpandView.create( IndexingConstant(x, ex.get_dtype(), ex.get_device()), @@ -892,7 +893,7 @@ def repeat(x, repeats): if zero_tensor: return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): - return expand(x, new_size) + return clone(expand(x, new_size)) x_loader: Callable[[Any], Any] @@ -2146,7 +2147,6 @@ def is_aligned(x): # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp -make_fallback(aten.avg_pool3d_backward) make_fallback(aten.max_pool3d_with_indices_backward) make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) make_fallback(aten._adaptive_avg_pool3d_backward) @@ -2421,7 +2421,7 @@ def inner_fn(idx): ops.index_expr( ModularIndexing(idx[dim] - start, 1, step), torch.int64 ), - ops.constant(0, torch.torch.int64), + ops.constant(0, torch.int64), ) ) assert mask @@ -2462,7 +2462,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): ranges: List[sympy.Expr] = [] - if isinstance(data, sympy.Expr): + if isinstance(data, sympy.Basic): def inner_fn(index): return ops.index_expr(data, dtype) @@ -2588,7 +2588,7 @@ def _full(fill_value, device, dtype, size): def inner_fn(index): return ops.constant(value, dtype) - elif isinstance(value, sympy.Expr): + elif isinstance(value, sympy.Basic): def inner_fn(index): return ops.index_expr(value, dtype) @@ -2780,18 +2780,29 @@ def gather(x, dim, index, sparse_grad=False): # sparse_grad doesn't affect forward computation, # and backward tracing is taken care of by AOT Autograd assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. 
Return an empty array with the same shape + return new_empty(x, index.get_size()) + assert index.get_dtype() == torch.int64 size = x.get_size() offset = len(size) == 0 dim = _validate_dim(x, dim, offset) + if offset: + x = expand(x, [1]) + size = [1] + x_loader = x.make_loader() index_loader = index.make_loader() def fn(idx): idx = list(idx) - if len(idx) != 0: - idx[dim] = ops.indirect_indexing(index_loader(idx), size[dim]) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx return x_loader(idx) return Pointwise.create( @@ -3271,6 +3282,9 @@ def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = if isinstance(index, TensorBox) and len(index.get_size()) == 0: index = view(index, [1]) + if index.get_numel() == 0: + return self + dim = _validate_dim(self, dim) self.realize() @@ -4011,11 +4025,32 @@ def load(prefix, increments, start_indices, end_indices): return load -def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): - h_start_index_fn, w_start_index_fn = start_index_fns - h_end_index_fn, w_end_index_fn = end_index_fns +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - def fn_sum(idx, loader): + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): *prefix, bh, bw = idx h_start_index = h_start_index_fn(bh) @@ -4024,7 +4059,7 @@ def fn_sum(idx, loader): w_start_index = w_start_index_fn(bw) w_end_index = w_end_index_fn(bw) - total = None + result = None for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): val = loader( prefix, @@ -4032,13 +4067,66 @@ def fn_sum(idx, loader): [h_start_index, w_start_index], [h_end_index, w_end_index], ) - if total is None: - total = val + if result is None: + result = val else: - total = ops.add(val, total) - return total + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex 
= ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex - return fn_sum + return fn fallback_adaptive_avg_pool2d = fallback_handler( @@ -4076,27 +4164,24 @@ def _adaptive_avg_pool2d(x, output_size): new_size = list(batch) + [h_out, w_out] dtype = x.get_dtype() + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + def start_index(index, out_dim, inp_dim): return FloorDiv((index * inp_dim), out_dim) def end_index(index, out_dim, inp_dim): return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) - h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) - h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - - w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) - w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) - - window_size = h_kernel_max * w_kernel_max - if window_size > 25: - # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. - return fallback_adaptive_avg_pool2d(x, output_size) - - fn_sum = _adaptive_pooling_idx_sum( - [h_kernel_max, w_kernel_max], - [h_start_index, w_start_index], - [h_end_index, w_end_index], + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, ) ones_loader = pad_adaptive_loader(ones_like(x)) @@ -4116,60 +4201,6 @@ def fn(idx): return rv -def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader): - # NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d - # Look into refactoring/deduplication after #116418 is merged. - h_in, w_in = in_sizes - h_out, w_out = out_sizes - - def start_index(index, out_dim, inp_dim): - return FloorDiv((index * inp_dim), out_dim) - - def end_index(index, out_dim, inp_dim): - return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) - - h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) - h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) - w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) - w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) - - def fn_max(idx): - *prefix, bh, bw = idx - - h_start_index = h_start_index_fn(bh) - h_end_index = h_end_index_fn(bh) - - w_start_index = w_start_index_fn(bw) - w_end_index = w_end_index_fn(bw) - maxval = None - maxindex = None - for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): - val = loader( - prefix, - [ih, iw], - [h_start_index, w_start_index], - [h_end_index, w_end_index], - ) - index = ops.index_expr( - (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 - ) - if return_index: - if maxindex is None: - maxindex = index - else: - maxindex = ops.where(ops.gt(val, maxval), index, maxindex) - if maxval is None: - maxval = val - else: - maxval = ops.maximum(val, maxval) - if return_index: - return maxindex - else: - return maxval - - return fn_max - - fallback_adaptive_max_pool2d = fallback_handler( aten.adaptive_max_pool2d.default, add_to_fallback_set=False ) @@ -4222,32 +4253,46 @@ def adaptive_max_pool2d(x, output_size): # Kernel size too big. Results in hard-to-optimize Triton code. 
Use fallback. return fallback_adaptive_max_pool2d(x, output_size) - inner_func_max_val = _adaptive_pooling_idx_max( + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, kernel_maxes=[h_kernel_max, w_kernel_max], in_sizes=[h_in, w_in], out_sizes=[h_out, w_out], - return_index=False, - loader=pad_adaptive_loader(x, float("-inf")), + pooling_fn=ops.maximum, ) - inner_func_max_idx = _adaptive_pooling_idx_max( + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, kernel_maxes=[h_kernel_max, w_kernel_max], in_sizes=[h_in, w_in], out_sizes=[h_out, w_out], - return_index=True, - loader=pad_adaptive_loader(x, float("-inf")), + pooling_fn=ops.maximum, ) + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + rv = Pointwise.create( device=x.get_device(), dtype=dtype, - inner_fn=inner_func_max_val, + inner_fn=inner_fn_max_val, ranges=new_size, ) ri = Pointwise.create( device=x.get_device(), dtype=torch.int64, - inner_fn=inner_func_max_idx, + inner_fn=inner_fn_max_idx, ranges=new_size, ) return rv, ri @@ -4262,7 +4307,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = (in_sz - kernel_sz) / (out_sz - 1) + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,21 +4417,18 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, out_dim) + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) - h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h) - h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h) - - w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w) - w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w) - - fn_sum = _adaptive_pooling_idx_sum( - [h_kernel_max, w_kernel_max], - [h_start_index, w_start_index], - [h_end_index, w_end_index], + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, ) def fn(idx): @@ -4738,6 +4780,207 @@ def fn(idx): return rv +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *batch, 
depth, height, width = x.get_size() + + d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) + h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, 
ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + def _validate_reduction_axis(x, axis): size = x.get_size() if isinstance(axis, int): diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 76f15243c5ba..3d8de535542e 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import csv diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index f9f2e66ab68c..721c54385d33 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5630061b4426..1f0a0bc1a6b3 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import ( Any, @@ -138,6 +139,38 @@ def to_dtype( """ ... + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -398,21 +431,23 @@ def isinf(self, x0: T) -> T: def isnan(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties def round(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - def to_int(self, x0: T) -> T: - ... - + # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... 
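# A pure-Python sketch of the semantics documented for the new *_to_int ops
# above; presumably each lowers the way trunc_to_int is described to (the
# matching float op followed by to_dtype with the requested index dtype).
import math

def trunc_to_int_ref(x: float) -> int:
    return math.trunc(x)   # toward zero, like Python's int() constructor

def ceil_to_int_ref(x: float) -> int:
    return math.ceil(x)

def floor_to_int_ref(x: float) -> int:
    return math.floor(x)   # floor semantics, despite the copy-pasted docstring

def round_to_int_ref(x: float) -> int:
    return round(x)        # round half to even, matching torch.round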
@@ -449,6 +484,7 @@ def sub(self, x0: T, x1: T) -> T: def mul(self, x0: T, x1: T) -> T: ... + # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -617,14 +653,21 @@ def truncdiv(self, x0: T, x1: T) -> T: def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid: to do - Python style (int, int) -> float division, promote the inputs to float - first.""" + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ ... def div(self, x0: T, x1: T) -> T: @@ -640,6 +683,10 @@ def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" ... + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index 0d5f2d0b2db7..63887b347364 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import sympy diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index b9b66874aba3..7c43b23efdd2 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1,5 +1,44 @@ +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. + +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. 
+ +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. +""" + +# mypy: disallow-untyped-defs + from __future__ import annotations +import contextlib + import dataclasses import functools import importlib @@ -11,6 +50,7 @@ import re import textwrap import typing +from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path from typing import ( @@ -18,12 +58,18 @@ Callable, DefaultDict, Dict, + Generator, Iterable, List, + Mapping, NoReturn, Optional, + Protocol, + Sequence, Set, Tuple, + Type, + TypeVar, Union, ) from typing_extensions import Self, TypeGuard @@ -34,10 +80,12 @@ import torch.utils._pytree as pytree from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import counters +from torch._inductor.config import trace as trace_config from torch._prims_common import is_integer_dtype from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from .._functorch import config as functorch_config from .._functorch.aot_autograd import aot_function, make_boxed_func @@ -48,9 +96,6 @@ from .decomposition import select_decomp_table from .lowering import fallback_node_due_to_unsupported_type -if typing.TYPE_CHECKING: - from torch.fx import Node - log = logging.getLogger(__name__) aten = torch.ops.aten prims = torch.ops.prims @@ -59,8 +104,33 @@ NodeOrConstant = Union[Constant, torch.fx.Node] +class SearchFn(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class ReplaceFn(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +class TraceFn(Protocol): + def __call__( + self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any + ) -> torch.fx.GraphModule: + ... + + +T = TypeVar("T") + +# What's a better name for this? +FnsType = Union[torch.fx.node.Target, str] + + class Multiple: - def __init__(self): + def __init__(self) -> None: # Ensure we're really a singleton. assert "MULTIPLE" not in globals() or self is MULTIPLE @@ -72,27 +142,47 @@ def __init__(self): class Match: """ Represents a successfully matched pattern. + + The `Match` object is returned to represent a successfully matched + pattern. Included in the Match are the pattern that was matched, the graph + nodes matched, and any args that were used during the matching. + + The args and kwargs are specific to the type of pattern that was matched and + provide hints about what was matched. 
""" - def __init__(self, pattern: PatternExpr, args=None, kwargs=None): + pattern: PatternExpr + args: List[Any] + kwargs: Dict[str, Any] + nodes: List[torch.fx.Node] + targets: Dict[_TargetExpr, torch.fx.node.Target] + ctx: MatchContext + replacement_graph: Optional[torch.fx.Graph] + + def __init__( + self, + ctx: MatchContext, + pattern: PatternExpr, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: super().__init__() self.pattern = pattern # The input nodes that must be passed in to the result - self.args = args or [] + self.args = list(args or []) self.kwargs = kwargs or {} # The nodes matched in this expression - self.nodes: List[torch.fx.Node] = [] + self.nodes = [] # Mapping CallFunction to the node.target - self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {} - self.ctx: Optional[MatchContext] = None - self.replacement_graph: Optional[torch.fx.Graph] = None + self.targets = {} + self.ctx = ctx + self.replacement_graph = None @property def graph(self) -> torch.fx.Graph: - assert self.ctx return self.ctx.graph - def extend(self, other: Match): + def extend(self, other: Match) -> None: if self.kwargs: for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): if self.kwargs[key] != other.kwargs[key]: @@ -107,16 +197,15 @@ def bundle(self) -> Match: self.args = [tuple(self.args)] if self.args else [] return self - def __repr__(self): + def __repr__(self) -> str: return f"Match(..., {self.args}, {self.kwargs})" - def erase_nodes(self, graph: torch.fx.Graph): + def erase_nodes(self, graph: torch.fx.Graph) -> None: for n in reversed(self.nodes): if not n._erased: graph.erase_node(n) def output_nodes(self) -> List[Optional[torch.fx.Node]]: - assert self.ctx return [ (self.ctx.pattern_to_node[p] if p is not None else None) for p in self.ctx.outputs @@ -125,29 +214,49 @@ def output_nodes(self) -> List[Optional[torch.fx.Node]]: def output_node(self) -> torch.fx.Node: return next(p for p in self.output_nodes() if p) - def replace_with_graph(self, replacement_graph, args): - assert self.ctx + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: ReplacementPatternEntry.replace_with_graph( self, self.ctx.graph, replacement_graph, args ) - def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): - assert self.ctx - if trace_fn is None: - trace_fn = functools.partial(fwd_only, run_dce=run_dce) - replacement = trace_fn( - replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) - ) - ReplacementPatternEntry.replace_with_graph( - self, - self.ctx.graph, - replacement, - args, - ) + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_dce: bool = True, + ) -> None: + from torch._inductor.virtualized import V + + context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext + + with context: + if trace_fn is None: + trace_fn = functools.partial(fwd_only, run_dce=run_dce) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) class FailedMatch(RuntimeError): - def __init__(self, format_string, *args, **kwargs): + """ + Represents a unsuccessful match. + + The `FailedMatch` object is returned to represent a failure to match a + pattern. 
+ """ + + format_string: str + + def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: self.format_string = format_string # We want to construct error messages lazily instead of eagerly, as # constructing them eagerly can significantly worsen compile times. @@ -158,14 +267,17 @@ def __init__(self, format_string, *args, **kwargs): self.args = args self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: return self.format_string.format(*self.args, **self.kwargs) - def __bool__(self): + def __bool__(self) -> bool: return False -def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]: +MatchResult = Union[Match, FailedMatch] + + +def is_match(m: MatchResult) -> TypeGuard[Match]: """ TypeGuards cannot act on `self`. Thus this function exists to let mypy recognize FailedMatch.__bool__ as a TypeGuard. @@ -175,35 +287,39 @@ def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]: class MatchContext: """ - State needed while running PatternExpr._match(). + Internal state needed while running PatternExpr._match(). """ + outputs: List[Optional[PatternExpr]] + pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] + graph: torch.fx.Graph + exclusive_node_set: List[NodeOrConstant] + def __init__( self, outputs: List[Optional[PatternExpr]], - pattern_to_node: Optional[Dict[PatternExpr, Node]] = None, + pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, *, graph: torch.fx.Graph, - ): + ) -> None: self.outputs = outputs - self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node + self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) self.graph = graph - self.exclusive_node_set: List[NodeOrConstant] = [] + self.exclusive_node_set = [] - def match(self, pattern, node): + def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: """wrapper to check reused nodes in patterns""" if pattern in self.pattern_to_node: if self.pattern_to_node[pattern] == node: - return Match(pattern) # already checked this node + return Match(self, pattern) # already checked this node else: return FailedMatch("repeated pattern differs") m = pattern._match(node, self) assert pattern not in self.pattern_to_node self.pattern_to_node[pattern] = node if m else None - m.ctx = self return m - def filter_multi_user_patterns(self): + def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: return { pattern: node for pattern, node in self.pattern_to_node.items() @@ -211,17 +327,16 @@ def filter_multi_user_patterns(self): } -class PatternExpr: +class PatternExpr(ABC): """ - Base class for types of patterns + Base class for types of patterns. """ - def _match( - self, node: torch.fx.Node, ctx: MatchContext - ) -> Union[Match, FailedMatch]: - raise NotImplementedError + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + ... 
- def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + def match(self, node: torch.fx.Node) -> MatchResult: try: return MatchContext([self], graph=node.graph).match(self, node) except FailedMatch as e: @@ -230,10 +345,12 @@ def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: def has_multiple_users(self) -> bool: return False - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + "()" - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: if self in ctx.pattern_to_node: yield ctx.pattern_to_node[self] @@ -252,8 +369,8 @@ class Arg(PatternExpr): passed in depth first order. """ - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self, args=[node]) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything class Ignored(PatternExpr): @@ -261,13 +378,13 @@ class Ignored(PatternExpr): Match an arg, but don't pass it to handler """ - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything - def __repr__(self): + def __repr__(self) -> str: return "*" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: return "Ignored()" @@ -276,15 +393,15 @@ class KeywordArg(PatternExpr): Capture a kwarg which will become an input to the handler. """ - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() self.name = name - def __repr__(self): + def __repr__(self) -> str: return f"KeywordArg({self.name!r})" - def _match(self, node: NodeOrConstant, ctx: MatchContext): - return Match(self, kwargs={self.name: node}) # matches anything + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, kwargs={self.name: node}) # matches anything def pattern_eq(self, other: Any) -> bool: other = typing.cast(Self, other) # super makes sure this is true @@ -296,19 +413,21 @@ class ExclusiveKeywordArg(PatternExpr): Capture a kwarg which will become an input to the handler. 
""" - def __init__(self, name): + name: str + + def __init__(self, name: str) -> None: super().__init__() self.name = name - def __repr__(self): + def __repr__(self) -> str: return f"ExclusiveKeywordArg({self.name!r})" - def _match(self, node: NodeOrConstant, ctx: MatchContext): + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: if node in ctx.exclusive_node_set: return FailedMatch("exclusive arg appears twice") ctx.exclusive_node_set.append(node) - return Match(self, kwargs={self.name: node}) # matches anything + return Match(ctx, self, kwargs={self.name: node}) # matches anything def pattern_eq(self, other: Any) -> bool: other = typing.cast(Self, other) # super makes sure this is true @@ -320,21 +439,27 @@ class _TargetExpr(PatternExpr): Base class for filtering match by node.target """ - op: Optional[str] = None + fns: List[FnsType] + fns_set: Set[FnsType] - def __init__(self, fns, users: Union[Multiple, int] = 1): - if not self.op: - raise NotImplementedError("Shouldn't directly use _BaseNodeMatch") + def __init__( + self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 + ) -> None: super().__init__() fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) - for fn in list(fns): + for fn in fns: if isinstance(fn, torch._ops.OpOverloadPacket): - fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + fns.extend(getattr(fn, overload) for overload in fn.overloads()) - self.fns: List[Union[Callable[..., Any], str]] = fns - self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns) + self.fns = fns + self.fns_set = set(fns) self.users = users + @property + @abstractmethod + def op(self) -> str: + ... + def fns_repr(self) -> str: first_repr = self.fns[0] if not isinstance(first_repr, str): @@ -349,7 +474,7 @@ def fns_repr(self) -> str: else: return first_repr - def __repr__(self): + def __repr__(self) -> str: if self.users is MULTIPLE: comma_users = ", MULTIPLE" elif self.users != 1: @@ -361,17 +486,19 @@ def __repr__(self): def has_multiple_users(self) -> bool: return isinstance(self.users, Multiple) or self.users > 1 - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: raise NotImplementedError - def _match_fns(self, node: torch.fx.Node): + def _match_fns(self, node: torch.fx.Node) -> bool: return ( isinstance(node, torch.fx.Node) and node.op == self.op and extract_target(node) in self.fns_set ) - def _match_users(self, node: torch.fx.Node, ctx: MatchContext): + def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: return ( self in ctx.outputs or self.users is MULTIPLE @@ -388,12 +515,21 @@ def pattern_eq(self, other: Any) -> bool: ) +_SimpleSpec = Tuple[Any, ...] 
+ + class _TargetArgsExpr(_TargetExpr): """ Base class for filtering match by node.{target,args,kwargs} """ - def __init__(self, fns, *args, _users=1, **kwargs): + def __init__( + self, + fns: Union[torch.fx.node.Target, str, Sequence[Any]], + *args: Any, + _users: Union[int, Multiple] = 1, + **kwargs: Any, + ) -> None: super().__init__(fns, _users) self.args = tuple(args) self.kwargs = dict(kwargs) @@ -407,12 +543,18 @@ def __init__(self, fns, *args, _users=1, **kwargs): self.flat_args_kwargs = self.flatten(self.args, self.kwargs) @staticmethod - def simple_flatten(args, kwargs: Dict[Any, Any]): - return (*args, *kwargs.values()), (len(args), *kwargs.keys()) + def simple_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + values = (*args, *kwargs.values()) + spec = (len(args), *kwargs.keys()) + return values, spec @staticmethod - def pytree_flatten(args, kwargs: Dict[Any, Any]): - def norm_spec(s: pytree.TreeSpec): + def pytree_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: if s.type is None: return s mapping = {immutable_list: list, tuple: list, immutable_dict: dict} @@ -426,7 +568,7 @@ def norm_spec(s: pytree.TreeSpec): spec = norm_spec(spec) return flat, spec - def __repr__(self): + def __repr__(self) -> str: args = [ self.fns_repr(), *map(repr, self.args), @@ -438,7 +580,7 @@ def __repr__(self): args.append(f"_users={self.users}") return f"{self.__class__.__name__}({', '.join(args)})" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: args = [ self.fns_repr(), *(pp.pretty_print(x) for x in self.args), @@ -452,7 +594,7 @@ def pretty_print(self, pp: PatternPrettyPrinter): joiner_str = ", " return f"{self.__class__.__name__}({joiner_str.join(args)})" - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: if not self._match_fns(node) or len(node.args) != len(self.args): return FailedMatch("function_mismatch: node={}, pattern={}", node, self) @@ -487,11 +629,11 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): return FailedMatch("args_structure {} {}", node_spec, self_spec) assert len(node_items) == len(self_items) - m = Match(self) + m = Match(ctx, self) for i, pattern, child_node in zip(itertools.count(), self_items, node_items): if isinstance(pattern, PatternExpr): child_match = ctx.match(pattern, child_node) - if not child_match: + if not is_match(child_match): return child_match m.extend(child_match) elif isinstance(child_node, torch.fx.Node) or child_node != pattern: @@ -502,7 +644,9 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): m.targets[self] = node.target return m - def find_anchor_nodes(self, ctx: MatchContext, searched): + def find_anchor_nodes( + self, ctx: MatchContext, searched: Set[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: """ This is used when we are matching a pattern with multiple outputs. 
There is a partial match (stored in ctx) and we want to walk @@ -566,14 +710,14 @@ class _TargetExprVarArgs(_TargetExpr): Matches a call_function node with any arguments which are passed into the pattern """ - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: if not self._match_fns(node): return FailedMatch("function_mismatch") if not self._match_users(node, ctx): return FailedMatch("multiple_users") - m = Match(self) + m = Match(ctx, self) m.nodes.append(node) m.targets[self] = node.target m.args.extend(node.args) @@ -598,19 +742,19 @@ class ListOf(PatternExpr): Matches a repeated pattern """ - def __init__(self, pattern: PatternExpr, partial=False): + def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: super().__init__() assert isinstance(pattern, PatternExpr) self.pattern = pattern self.partial = partial - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern})" - def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[override] + def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] if not isinstance(node, (list, tuple)) or len(node) == 0: return FailedMatch("non_list") - m = Match(self) + m = Match(ctx, self) # Propagating patterns with multiple users will ensure we don't revisit # the same nodes pattern_to_node = ctx.filter_multi_user_patterns() @@ -621,7 +765,7 @@ def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[ ) child_match = child_ctx.match(self.pattern, child_node) pattern_to_node = child_ctx.filter_multi_user_patterns() - if not child_match: + if not is_match(child_match): if not self.partial: return FailedMatch("list[{}]: {}", i, child_match) continue @@ -641,54 +785,61 @@ def pattern_eq(self, other: Any) -> bool: class MultiOutputPattern(PatternExpr): - def __init__(self, outputs): + outputs: List[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: super().__init__() - assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs - self.outputs: List[Optional[PatternExpr]] = outputs + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) self.op = outputs[0].op @property - def fns(self): - assert self.outputs[0] and hasattr(self.outputs[0], "fns") - return self.outputs[0].fns + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.outputs})" - def pretty_print(self, pp: PatternPrettyPrinter): + def pretty_print(self, pp: PatternPrettyPrinter) -> str: args = [pp.pretty_print(x) for x in self.outputs] joiner_str = f",\n{' '}" str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" str_out = f"{str_out}\n])" return str_out - def _match(self, node: torch.fx.Node, ctx: MatchContext): - m = ctx.match(self.outputs[0], node) - if not m: + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): return m for pattern in self.outputs[1:]: if pattern is None: continue child_match = self._match_from_anchors(pattern, ctx) - if not 
child_match: + if not is_match(child_match): return child_match m.extend(child_match) return m - def _match_from_anchors(self, pattern, ctx): + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: prior = dict(ctx.pattern_to_node) - m = FailedMatch("no anchor found") + m: MatchResult = FailedMatch("no anchor found") for node in pattern.find_anchor_nodes(ctx, set()): m = ctx.match(pattern, node) - if m: + if is_match(m): return m # revert any partial matches ctx.pattern_to_node = dict(prior) return m - def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + def match(self, node: torch.fx.Node) -> MatchResult: try: return MatchContext(self.outputs, graph=node.graph).match(self, node) except FailedMatch as e: @@ -711,19 +862,18 @@ class RepeatedExpr(PatternExpr): Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind` """ - def __init__(self, inner_pattern: PatternExpr): + def __init__(self, inner_pattern: _TargetExpr) -> None: super().__init__() - assert hasattr(inner_pattern, "fns") self.inner_pattern = inner_pattern - self.op = inner_pattern.op # type: ignore[attr-defined] + self.op = inner_pattern.op @property - def fns(self): + def fns(self) -> Sequence[FnsType]: return self.inner_pattern.fns - def _match(self, node: torch.fx.Node, ctx: MatchContext): + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: m = ctx.match(self.inner_pattern, node) - if not m: + if not is_match(m): return m ctx.pattern_to_node.pop( self.inner_pattern, @@ -733,7 +883,7 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext): anchor_m = MatchContext([self], graph=node.graph).match( self.inner_pattern, anchor_node ) - if not anchor_m: + if not is_match(anchor_m): return anchor_m m.extend(anchor_m) return m @@ -752,13 +902,14 @@ class PatternPrettyPrinter: all patterns. """ - def __init__(self): + def __init__(self) -> None: self.namespace = torch.fx.graph._Namespace() self.memoized_objs_names: Dict[PatternExpr, str] = {} self.memoized_objs_pp: Dict[PatternExpr, str] = {} @staticmethod - def run(obj: PatternExpr, output_name="output"): + @functools.lru_cache(None) + def run(obj: PatternExpr, output_name: str = "output") -> str: """ Serializes obj to python code with obj written out to `output_name` """ @@ -775,7 +926,7 @@ def run(obj: PatternExpr, output_name="output"): return "\n".join(output) - def pretty_print(self, obj): + def pretty_print(self, obj: Any) -> str: if isinstance(obj, _TargetArgsExpr): if memoized_name := self.memoized_objs_names.get(obj): return memoized_name @@ -786,7 +937,7 @@ def pretty_print(self, obj): return repr(obj) - def memoize(self, obj): + def memoize(self, obj: _TargetArgsExpr) -> str: obj_str = obj.pretty_print(self) obj_name = obj.fns_repr() for prefix in ("aten.", "torch.", "prims."): @@ -798,15 +949,25 @@ def memoize(self, obj): return tmp_name +class _PassDictsType(Protocol): + def __getitem__(self, k: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + ... 
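For context on the `_PassDictsType` protocol defined just above: it is a structural (duck-typed) interface, so any object indexable by an `(op, target)` pair that yields a list of `PatternEntry` objects satisfies it, including `PatternMatcherPass` through the `__getitem__` shown later in this diff. A minimal sketch of the idea, with placeholder keys and entry values that are illustrative only:

    import collections
    from typing import Any, DefaultDict, List, Tuple

    # A plain defaultdict-of-lists keyed by (op, target) already has the shape
    # that _PassDictsType describes; typing.Protocol checks structure, not bases.
    pass_dict: DefaultDict[Tuple[str, Any], List[Any]] = collections.defaultdict(list)
    entries = pass_dict[("call_function", "aten.add.Tensor")]  # [] on first access
    entries.append("a PatternEntry would go here")  # placeholder, not a real entry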
+ + @dataclasses.dataclass class PatternEntry: pattern: PatternExpr extra_check: Callable[[Match], bool] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: raise NotImplementedError - def register(self, pass_dicts, target=None, prepend=False): + def register( + self, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + target: Union[torch.fx.node.Target, None] = None, + prepend: bool = False, + ) -> None: if target is None: assert hasattr(self.pattern, "fns") for fn in self.pattern.fns: @@ -818,6 +979,7 @@ def register(self, pass_dicts, target=None, prepend=False): else: pass_dicts[(self.pattern.op, target)].append(self) else: + pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) for x in pass_dicts: self.register(x, target, prepend=prepend) @@ -826,7 +988,7 @@ def register(self, pass_dicts, target=None, prepend=False): class LoweringPatternEntry(PatternEntry): handler: Callable[..., Any] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) with graph.inserting_before(node): replacement = graph.call_function(handler, tuple(match.args), match.kwargs) @@ -844,7 +1006,7 @@ class GraphPatternEntry(PatternEntry): handler: Callable[..., Any] - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: with graph.inserting_before(node): self.handler(match, *match.args, **match.kwargs) @@ -857,9 +1019,9 @@ class ReplacementPatternEntry(PatternEntry): def replace_with_graph( match: Match, graph: torch.fx.Graph, - replacement_graph: torch.fx.Graph, - args: List[Any], - ): + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: output_nodes = match.output_nodes() first_node = output_nodes[0] @@ -868,7 +1030,7 @@ class Replacer(torch.fx.Interpreter): call_module = None # type: ignore[assignment] get_attr = None # type: ignore[assignment] - def run_node(self, node) -> Any: + def run_node(self, node: torch.fx.Node) -> Any: if node.op in ("placeholder", "output"): return super().run_node(node) if node.op == "call_function": @@ -897,7 +1059,9 @@ def run_node(self, node) -> Any: ] last_node = min(indices, key=operator.itemgetter(0))[1] - def percolate_tags(node, recompute_tag, input_stops): + def percolate_tags( + node: torch.fx.Node, recompute_tag: str, input_stops: Set[torch.fx.Node] + ) -> None: queue = [node] visited = set() @@ -917,7 +1081,7 @@ def percolate_tags(node, recompute_tag, input_stops): if isinstance(replacement, torch.fx.Node): replacement = [replacement] - def maybe_getitem(node): + def maybe_getitem(node: torch.fx.Node) -> Any: if node.op != "call_function": return None if node.target != operator.getitem: @@ -925,7 +1089,10 @@ def maybe_getitem(node): assert len(node.args) == 2 return node.args[1] - def replace(old, new): + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: if old is None: assert new is None return @@ -947,12 +1114,13 @@ def replace(old, new): # recomputable tags. It is possible in some scenarios that we # incorrectly tag some nodes as recomputables. 
if "recompute" in old.meta: - percolate_tags(new, old.meta["recompute"], args) + percolate_tags(new, old.meta["recompute"], set(args)) old.replace_all_uses_with(new) graph.erase_node(old) return + new = typing.cast(Sequence[torch.fx.Node], new) # `new` is not a node: it's a list of nodes. # # This happens when we want to replace a node that has a single @@ -997,20 +1165,21 @@ def replace(old, new): match.erase_nodes(graph) - def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None self.replace_with_graph( match, graph, - match.replacement_graph, # type: ignore[arg-type] + match.replacement_graph, self.normalize_args(*match.args, **match.kwargs), ) -def _return_true(match): +def _return_true(match: Match) -> bool: return True -def log_trace_failure(search_fn, e): +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: log.info( "Replacement pattern %s failed to apply due to shape mismatch: %s", search_fn.__name__, @@ -1019,16 +1188,16 @@ def log_trace_failure(search_fn, e): def register_replacement( - search_fn, - replace_fn, + search_fn: SearchFn, + replace_fn: ReplaceFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - pass_dicts, - extra_check=_return_true, - scalar_workaround=(), - exclusive_arg_names=(), - search_fn_pattern=None, -): + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, +) -> bool: """ Create a replacement rule based on example functions that get traced to create patterns. This supports both training and inference when @@ -1044,7 +1213,7 @@ def register_replacement( """ argnames_static = [*inspect.signature(search_fn).parameters.keys()] - def check_fn(match: Match): + def check_fn(match: Match) -> bool: """ Often shapes get burned into the pattern, so our initial match ran with `ignore_types=(int, ...)`. @@ -1098,7 +1267,7 @@ def check_fn(match: Match): # Later, when we actually do the replacement, the symbolic shape # sizes will get re-traced and added to the graph. 
- def search_fn_new(*args_new): + def search_fn_new(*args_new: Any) -> Any: return search_fn(*args_new[len(args_new) - len(args) :]) try: @@ -1140,15 +1309,17 @@ def search_fn_new(*args_new): scalar_workaround=scalar_workaround, ) - specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) # type: ignore[arg-type] + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) - if specific_pattern_match and extra_check(specific_pattern_match): + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): # trace the pattern using the shapes from the user program match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] return True return False - def normalize_args(**kwargs): + def normalize_args(**kwargs: Any) -> List[Any]: args = [] for name in argnames_static: args.append(kwargs.pop(name)) @@ -1198,11 +1369,11 @@ def normalize_args(**kwargs): def _serialize_pattern( unique_name: str, - search_fn, + search_fn: SearchFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - scalar_workaround, -): + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None], +) -> PatternExpr: def get_file_template() -> str: auto_generated_msg = textwrap.dedent( """\ @@ -1266,6 +1437,8 @@ def get_file_template() -> str: f.write(serialized_pattern) f.write("\n") + return pattern + SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" @@ -1278,22 +1451,23 @@ def get_file_template() -> str: Iterable[Any], Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], Any, - str, + PatternExpr, ] ] = [] def gen_register_replacement( unique_name: str, - search_fn, - replace_fn, + search_fn: SearchFn, + replace_fn: ReplaceFn, example_inputs: Iterable[Any], - trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], - pass_dicts, - extra_check=_return_true, - scalar_workaround=(), - exclusive_arg_names=(), -): + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, +) -> None: # Make sure the example_inputs is materialized. example_inputs = tuple(example_inputs) @@ -1308,7 +1482,7 @@ def gen_register_replacement( ) if not m or not hasattr(m, unique_name): log.warning( - "Precompiled pattern %r not found. Run torchen/fuse/gen_patterns.py.", + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", unique_name, ) pat = getattr(m, unique_name) @@ -1321,6 +1495,8 @@ def gen_register_replacement( # Since this is just an optimization we can clear it out. 
arg.constant = None + if PatternPrettyPrinter.run(pat) in _seen_patterns and skip_duplicates: + return _known_precompiled_patterns.append( (search_fn, example_inputs, trace_fn, scalar_workaround, pat) ) @@ -1339,11 +1515,15 @@ def gen_register_replacement( @functorch_config.patch(functionalize_rng_ops=False) def gen_pattern( - search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=() + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: argnames = [*inspect.signature(search_fn).parameters.keys()] - if scalar_workaround == (): + if scalar_workaround is None: scalar_workaround = {} flat_inputs = [] input_idx = 0 # Positional arguments index @@ -1366,34 +1546,42 @@ def gen_pattern( def register_lowering_pattern( - pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False -): + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Register an aten to inductor IR replacement pattern. The decorated function is saved and then called a lowering time allowing direct pattern to inductor IR conversion. """ - def decorator(handler): + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: assert callable(handler) LoweringPatternEntry( pattern=pattern, extra_check=extra_check, handler=handler ).register(pass_dict, prepend=prepend) - handler._inductor_lowering_function = True + handler._inductor_lowering_function = True # type: ignore[attr-defined] return handler return decorator def register_graph_pattern( - pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False -): + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Register a pattern that runs a function on the FX graph, allowing custom transformation code. 
""" - def decorator(handler): + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: assert callable(handler) GraphPatternEntry( pattern=pattern, extra_check=extra_check, handler=handler @@ -1439,7 +1627,7 @@ def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: return "mutation_region_id" not in next(iter(graph.nodes)).meta -def compute_mutation_region_ids(graph: torch.fx.GraphModule): +def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None: mutation_region_id = 0 for nd in graph.nodes: if is_mutation_op(nd): @@ -1449,8 +1637,10 @@ def compute_mutation_region_ids(graph: torch.fx.GraphModule): class PatternMatcherPass: def __init__( - self, prevent_match_across_mutations=False, pass_name: Optional[str] = None - ): + self, + prevent_match_across_mutations: bool = False, + pass_name: Optional[str] = None, + ) -> None: super().__init__() self.patterns: DefaultDict[ Tuple[str, torch.fx.node.Target], List[PatternEntry] @@ -1461,11 +1651,18 @@ def __init__( def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: return self.patterns[item] - def apply(self, graph: torch.fx.GraphModule) -> int: + def apply(self, gm: torch.fx.GraphModule) -> int: if not self.patterns: return 0 - if isinstance(graph, torch.fx.GraphModule): - graph = graph.graph + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) if self.prevent_match_across_mutations: if should_compute_mutation_region_ids(graph): compute_mutation_region_ids(graph) @@ -1482,52 +1679,56 @@ def apply(self, graph: torch.fx.GraphModule) -> int: nodes.append(graph.find_nodes(op=op, target=target, sort=False)) if has_call_module: nodes.append(graph.find_nodes(op="call_module", sort=False)) - for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): - target = extract_target(node) - if node.op == "call_module": - if (node.op, target) not in self.patterns: - continue - - # conservatively not applying pattern for cpu input, - # since some of the patterns induce codegen and split nodes. - # Note: we will only skip cpu compute if disable_cpp_codegen=True - if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): - continue + pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" + with GraphTransformObserver( + gm, pass_name, trace_config.log_url_for_graph_xform + ): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue - for entry in self.patterns[(node.op, target)]: - if node._erased: - break - m = entry.pattern.match(node) - # pattern match crosses mutation barrier - discard - if ( - self.prevent_match_across_mutations - and is_match(m) - and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] - ): + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. 
+ # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): continue - if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: - log.warning("%s%s %s %s", node, node.args, m, entry.pattern) - if is_match(m) and entry.extra_check(m): - count += 1 - entry.apply(m, graph, node) # type: ignore[arg-type] - counters["inductor"]["pattern_matcher_count"] += 1 - counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + self.prevent_match_across_mutations + and is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) return count - def clear(self): + def clear(self) -> None: self.patterns.clear() -def _not_implemented(*args, **kwargs) -> NoReturn: +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: raise NotImplementedError def fx_to_pattern( - gm, - ignore_types=(), - argnames=(), - scalar_workaround=(), - exclusive_arg_names=(), + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[Type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: """ Convert an FX graph into a PatternExpr. This is useful for simple @@ -1539,7 +1740,7 @@ def fx_to_pattern( inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} assert len(inv_scalar_workaround) == len(scalar_workaround) - def process_arg(x): + def process_arg(x: T) -> Union[T, KeywordArg, Ignored]: if isinstance(x, (float, int)) and x in inv_scalar_workaround: return KeywordArg(inv_scalar_workaround[x]) if type(x) in ignore_types: @@ -1555,7 +1756,9 @@ class Converter(torch.fx.Interpreter): call_module = _not_implemented get_attr = _not_implemented - def placeholder(self, target, args, kwargs): + def placeholder( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> Union[ExclusiveKeywordArg, KeywordArg]: n = next(argnum) if n < len(argnames): name = argnames[n] @@ -1570,7 +1773,9 @@ def placeholder(self, target, args, kwargs): else: return KeywordArg(name) - def call_function(self, target, args, kwargs): + def call_function( + self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> PatternExpr: args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) if list in ignore_types: # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 
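As a usage note for the `fx_to_pattern` conversion being annotated in this hunk: in practice it is usually driven through `gen_pattern`, which traces a search function with example inputs and maps burned-in scalars back to `KeywordArg` placeholders via `scalar_workaround`. A rough sketch under assumed inputs (the function, shapes, and scalar value below are invented for illustration, not taken from the patch):

    import torch
    from torch._inductor.pattern_matcher import fwd_only, gen_pattern

    def search_fn(x, scale):
        # `scale` would otherwise be burned into the traced graph as a constant;
        # listing it in scalar_workaround maps it back to KeywordArg("scale").
        return torch.relu(x) * scale

    pattern = gen_pattern(
        search_fn,
        example_inputs=[torch.randn(8, 8)],
        trace_fn=fwd_only,
        scalar_workaround={"scale": 0.1234},
    )
    print(pattern)  # a CallFunction(...) tree with KeywordArg placeholders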
@@ -1578,11 +1783,11 @@ def call_function(self, target, args, kwargs): kwargs = {k: process_arg(a) for k, a in kwargs.items()} return CallFunction(target, *args, **kwargs) - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: rv = super().run_node(n) if n.op == "output" and isinstance(rv, tuple): - assert len(rv) == len(n.args[0]) - for r, arg in zip(rv, n.args[0]): + assert len(rv) == len(n.args[0]) # type: ignore[arg-type] + for r, arg in zip(rv, n.args[0]): # type: ignore[arg-type] r.users = len(arg.users) else: rv.users = len(n.users) @@ -1595,11 +1800,18 @@ def run_node(self, n): @torch.no_grad() -def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: +def fwd_only( + fn: Callable[..., Any], args: Sequence[Any], *, run_dce: bool = True +) -> torch.fx.GraphModule: """Build a normalized inference graph, for use with fx_to_pattern""" # TODO - look into using aot autograd, asserting no mutating ops here with enable_python_dispatcher(): gm = make_fx(fn, select_decomp_table(), tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + if run_dce: gm.graph.eliminate_dead_code() gm.recompile() @@ -1607,11 +1819,13 @@ def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: @torch.enable_grad() -def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule: +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: """Build a normalized training graph, for use with fx_to_pattern""" gm: Optional[torch.fx.GraphModule] = None - def record_joint_graph(joint_graph, inputs, **kwargs): + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> Tuple[torch.fx.GraphModule, torch.fx.GraphModule]: nonlocal gm assert not gm gm = clone_graph(joint_graph) @@ -1628,6 +1842,10 @@ def record_joint_graph(joint_graph, inputs, **kwargs): )(*args) assert gm + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + from .fx_passes.joint_graph import pointless_view matcher_pass = PatternMatcherPass() @@ -1653,7 +1871,7 @@ def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: return args -def stable_topological_sort(graph: torch.fx.Graph): +def stable_topological_sort(graph: torch.fx.Graph) -> None: # Nodes are in exactly one of these three collections: # - Nodes in `pending` are waiting to be processed (in reverse order): @@ -1689,12 +1907,12 @@ def stable_topological_sort(graph: torch.fx.Graph): assert not waiting and len(ready) == len(graph.nodes) -def init_once_fakemode(fn: Callable[..., Any]): +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: """Wrapper around lazy init functions in fx_passes/""" @functools.lru_cache(None) @functools.wraps(fn) - def lazy_init(): + def lazy_init() -> Any: counters_ref = counters["inductor"].copy() with torch._guards.tracing( @@ -1710,10 +1928,10 @@ def lazy_init(): return lazy_init -def config_flag(name): +def config_flag(name: str) -> Callable[[Match], Any]: """Function for extra_check to put pass behind a flag""" - def flag_check(match): + def flag_check(match: Match) -> Any: return getattr(config, name) return flag_check @@ -1721,7 +1939,7 @@ def flag_check(match): def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: class CopyGraph(Transformer): - def run_node(self, old_node): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: new_node = super().run_node(old_node) if isinstance(new_node, torch.fx.Proxy): 
new_node.node.meta.update(old_node.meta) @@ -1738,7 +1956,7 @@ def run_node(self, old_node): def get_arg_value( node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None -): +) -> Any: return ( node.args[arg_number] if len(node.args) > arg_number @@ -1746,7 +1964,7 @@ def get_arg_value( ) -def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: fns = [fn] if isinstance(fn, torch._ops.OpOverloadPacket): fns.extend([getattr(fn, overload) for overload in fn.overloads()]) @@ -1754,7 +1972,7 @@ def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: return [node for node in nodes if node.target in fns] -def extract_target(node: Node): +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: """For call_function and call_method, we directly use the target function; For call_module, the target is string, and we treat the module class as a function. diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index 7b4edf0627dd..954a85abe52e 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from . import lowering diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py new file mode 100644 index 000000000000..91b69b3bf6f5 --- /dev/null +++ b/torch/_inductor/remote_cache.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import os +from abc import abstractmethod + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, cache_id: str): + pass + + @abstractmethod + def get(self, key: str): + pass + + @abstractmethod + def put(self, key: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + """ + A Redis implementation of a remote/distributed cache. 
+ """ + + def __init__(self, cache_id: str): + import redis + + self._key_fmt = f"pt2:{cache_id}:{{key}}" + self._redis = redis.Redis( + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), + ) + + def _get_key(self, key: str) -> str: + return self._key_fmt.format(key=key) + + def get(self, key: str): + return self._redis.get(self._get_key(key)) + + def put(self, key: str, data: bytes): + return self._redis.set(self._get_key(key), data) diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 878125c9fabc..da30bd46b112 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index b5d10478a03c..31ff94774613 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import logging diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 46acd83c7377..ba36f40a2263 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import typing from dataclasses import fields diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 7d24be0ded47..51a6c22644b8 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools @@ -187,7 +188,7 @@ def get_first_attr(obj, *attrs): dynamo_timed = torch._dynamo.utils.dynamo_timed except AttributeError: # Compile workers only have a mock version of torch - def dynamo_timed(original_function=None, phase_name=None): + def dynamo_timed(original_function=None, phase_name=None, fwd_only=True): if original_function: return original_function return dynamo_timed diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 71b746bdf49a..845bec583f6d 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs try: import triton import triton.language as tl @@ -16,15 +17,21 @@ class tl: # type: ignore[no-redef] # In the latest triton, math functions were shuffled around into different modules: # https://github.com/openai/triton/pull/3172 -if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): - libdevice = tl.extra.cuda.libdevice - math = tl.math -elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): - libdevice = tl.extra.intel.libdevice +try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 math = tl.math -else: - libdevice = tl.math - math = tl +except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl @triton.jit @@ -195,7 +202,7 @@ def bucketize_binary_search( while full_range > 1: mid = (high + 
low) // 2 mask = mid < OFFSETS_SIZE - bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask) + bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0) if right: is_above = values >= bucket_upper_bound else: diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 82c8f9a4fb71..5396ccf3e70d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import copy import functools @@ -748,7 +749,6 @@ def save_cuda_kernel(self, grid, stream, launcher): # User defined triton kernels will have arbitrary kwarg names "meta": launcher.config.kwargs, } - from torch._inductor.codecache import CudaKernelParamCache binary = ( @@ -1031,7 +1031,7 @@ def should_use_remote_autotune_cache(inductor_meta): if inductor_meta.get("is_hip"): return False - from triton.runtime.fb_memcache import MEMCACHE_VERSION + from triton.fb.fb_memcache import MEMCACHE_VERSION return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int( "pytorch/remote_cache:autotune_memcache_version" @@ -1075,11 +1075,17 @@ def cached_autotune( try: if inductor_meta.get("is_fbcode"): - remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( - key + import triton.fb.fb_memcache + + remote_cache = ( + triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend( + key + ) ) else: - remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + from torch._inductor.remote_cache import RedisRemoteCacheBackend + + remote_cache = RedisRemoteCacheBackend(key) except Exception: remote_cache = None log.warning("Unable to create a remote cache", exc_info=True) @@ -1737,7 +1743,7 @@ def grid_fn(meta): max_y_grid = get_max_y_grid() if znumel is None: div = ceildiv(y_grid, max_y_grid) - y_grid = y_grid // div + y_grid = ceildiv(y_grid, div) z_grid = div else: z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None)) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 3b8a13c49cb1..46d80569125f 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -28,6 +28,7 @@ import sympy import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols @@ -62,107 +63,10 @@ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") -class WhyNoFuse: - # TODO when we drop support for Python < 3.10, we can use - # @dataclass(slots=True) instead of manually specifying __slots__. - __slots__ = ["node1", "node2", "reason", "args"] - reason: str - args: Tuple[Any, ...] 
- - def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): - self.node1 = node1 - self.node2 = node2 - - def __call__(self, reason: str, *args: Any) -> None: - self.reason = reason - self.args = args - fusion_log.debug(self) - - def __str__(self) -> str: - return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( - self.reason % self.args - ) - - -def pformat(obj: Any) -> str: - if isinstance(obj, set): - # pformat has trouble with sets of sympy exprs - obj = sorted(obj, key=str) - result = pprint.pformat(obj, indent=4) - if "\n" in result: - return f"\n{textwrap.indent(result, ' '*4)}" - return result - - -class OutputNode: - def __init__(self, dep: StarDep) -> None: - self.unmet_dependencies = {dep} - self.inverse_users: List[BaseSchedulerNode] = [] - - def is_reduction(self) -> bool: - return False - - def get_inputs_that_alias_output(self) -> Sequence[str]: - return () - - def get_name(self) -> str: - return "OUTPUT" - - __repr__ = get_name - - -def _prune_redundant_deps( - node: "BaseSchedulerNode", name_to_fused_node: Dict[str, "BaseSchedulerNode"] -) -> None: - """ - Prunes weakdeps intended for mutation ordering - on an upstream fused node if after fusion there is another dependency - on the fused upstream node, making the weakdep redundant - - In essence this enforces an ordering on fusions. As fusions occur, weakdeps will - be incrementally removed, enabling other fusions, ensuring they are fused in order. - """ - name_to_dep_count: Counter[str] = collections.Counter() - - for dep in node.unmet_dependencies: - if not isinstance(dep, WeakDep): - name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 - - def should_prune(dep: Dep) -> bool: - if isinstance(dep, WeakDep): - is_redundant = ( - name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 - ) - # These can occur because fused nodes always gather deps from their snodes - # If B has a weakdep on A - # B gets fused with C, then any time BC is fused, the weakdep will reappear - is_self_dep = name_to_fused_node[dep.name] == node - return is_redundant or is_self_dep - else: - return False - - deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} - - if deps_to_prune: - node.unmet_dependencies = node.unmet_dependencies - deps_to_prune - node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) - - -# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel -kernel_name_to_op = { - "extern_kernels.convolution": torch.ops.aten.convolution, - "extern_kernels.mm": torch.ops.aten.mm, - "extern_kernels.bmm": torch.ops.aten.bmm, - "extern_kernels.addmm": torch.ops.aten.addmm, -} - - class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites unmet_dependencies: Set[Dep] - # Processed deps used while scoring fusion - read_and_write_deps_with_hint: Set[Tuple[Dep, int]] def __init__(self, scheduler: "Scheduler", node: ir.Buffer) -> None: self.scheduler: Scheduler = scheduler @@ -251,25 +155,6 @@ def set_read_writes(self, rw: dependencies.ReadWrites) -> None: self.unmet_dependencies = self.read_writes.reads self.prune_deps() - # read_and_write_deps_with_hint are a summary of read_writes used by - # score_fusion_memory() - def dep_size_hint(dep: Dep) -> int: - try: - if dep.has_unbacked_symbols(): - return 0 - return dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that 
doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - return 0 - - self.read_and_write_deps_with_hint = { - (dep, hint) - for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) - if (hint := dep_size_hint(dep)) > 0 - } - def op_counts(self) -> Counter[str]: return self.read_writes.op_counts @@ -725,6 +610,101 @@ def get_template_node(self) -> Optional[ir.TemplateBuffer]: return None +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] + + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, set): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' '*4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = {dep} + self.inverse_users: List[BaseSchedulerNode] = [] + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, name_to_fused_node: Dict[str, BaseSchedulerNode] +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + is_redundant = ( + name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 + ) + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[dep.name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + class ExternKernelSchedulerNode(BaseSchedulerNode): def debug_str_extra(self) -> str: return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" @@ -741,36 +721,6 @@ class NopKernelSchedulerNode(BaseSchedulerNode): pass -def debug_triton_code(node: Union["SchedulerNode", "FusedSchedulerNode"]) -> List[str]: - lines = [] - multi_template = node.get_template_node() - assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) - if multi_template and multi_template.make_kernel_render is None: - lines.append(f"{node.get_name()} Unfinalized multi template buffer") - else: - from torch._inductor.codegen.cuda_combined_scheduling import ( - CUDACombinedScheduling, - ) - from torch._inductor.codegen.triton import TritonScheduling - - snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes - device = snodes[0].get_device() - backend = node.scheduler.get_backend(device) - assert isinstance(backend, (TritonScheduling, CUDACombinedScheduling)) - V.graph.scheduler.current_device = device - - # Don't increment kernel count when generating debug string. - # This will confuse some unit tests that check the number of - # generated kernels. 
- old_generated_kernel_count = metrics.generated_kernel_count - triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() - metrics.generated_kernel_count = old_generated_kernel_count - - lines.append(f"{node.get_name()} Triton code:") - lines.append(textwrap.indent(triton_code, " ")) - return lines - - class SchedulerNode(BaseSchedulerNode): def __init__( self, @@ -1329,8 +1279,10 @@ def index_cmp(a: int, b: int) -> int: # 1-sizes don't matter, just move them to the end return cmp(sizes[a] == 1, sizes[b] == 1) - stride_len_a = [sl[a] for sl in stride_lengths] - stride_len_b = [sl[b] for sl in stride_lengths] + # Take abs, otherwise flipped dimensions are treated as smaller + # strides than contiguous dims + stride_len_a = [abs(sl[a]) for sl in stride_lengths] + stride_len_b = [abs(sl[b]) for sl in stride_lengths] # equivalent to # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() @@ -1393,9 +1345,12 @@ def merge(self, other: "NodeUser") -> "NodeUser": class Scheduler: + __dep_size_hint_cache: Dict[Dep, int] + @dynamo_timed def __init__(self, nodes: List[ir.Buffer]) -> None: super().__init__() + self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: Dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -1638,8 +1593,9 @@ def add_user( # generate a dependency because if we do, Inductor will start trying # to free the unbacked int but that's pointless for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Symbol): - unbacked_symbol_to_origin_node[val] = None + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None for node in self.nodes: log.debug("scheduling %s", node.node) @@ -2504,6 +2460,22 @@ def score_fusion( proximity_score, ) + def dep_size_hint(self, dep: Dep) -> int: + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res + def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> int: @@ -2511,11 +2483,10 @@ def score_fusion_memory( The first term in our fusion score that estimates number of saved memory operations. """ - return sum( - hint - for dep, hint in node1.read_and_write_deps_with_hint - & node2.read_and_write_deps_with_hint + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes ) + return sum(self.dep_size_hint(dep) for dep in common_memory_deps) def get_possible_fusions_with_highest_priority( self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] @@ -2772,9 +2743,6 @@ def codegen(self) -> None: assert isinstance(node, NopKernelSchedulerNode) node.allocate() - if config.debug_check_inf_and_nan: - V.graph.wrapper_code.generate_inf_and_nan_checker(node) - if config.triton.debug_sync_kernel: self.get_backend(device).codegen_sync() @@ -2889,3 +2857,33 @@ def get_fusion_pair_priority( The smaller is with higher priority. 
""" return 0 + + +def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + from torch._inductor.codegen.triton import TritonScheduling + + snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes + device = snodes[0].get_device() + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (TritonScheduling, CUDACombinedScheduling)) + V.graph.scheduler.current_device = device + + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 7aafcfe31488..467af6f57812 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import contextlib import functools @@ -22,6 +23,7 @@ from filelock import FileLock import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.testing import rand_strided from torch._dynamo.utils import counters, identity, preserve_rng_state @@ -39,6 +41,7 @@ ) from .codegen.triton_utils import config_of, signature_to_meta +from .codegen.wrapper import pexpr from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType from .runtime.hints import DeviceProperties @@ -150,7 +153,7 @@ def __init__( @contextlib.contextmanager def set_subgraph_body(self, body_name: str): old_body = self.body - assert body_name in self.subgraph_bodies + assert body_name in self.subgraph_bodies, body_name self.body = self.subgraph_bodies[body_name] yield self.body = old_body @@ -309,7 +312,10 @@ def modification( Args: subgraph_number (int): The index of the subgraph in self.subgraphs """ - with self.create_subgraph_body(f"modification_{subgraph_number}"): + num = 0 + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): assert isinstance(subgraph_number, int) assert isinstance(self.subgraphs, list) assert ( @@ -381,7 +387,7 @@ def store_output( assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x) for x in indices] + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -405,7 +411,7 @@ def store_output( output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex") + output_index = sympy.Symbol("xindex", integer=True) epilogue_args = [val] for input_node in itertools.chain( @@ -452,7 
+458,7 @@ def make_load(self, name, indices, mask): index = " + ".join( f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) ) - return f"tl.load({name} + ({index}), {mask})" + return f"tl.load({name} + ({index}), {mask}, other=0.0)" def template_env(self): """ @@ -533,7 +539,7 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): meta = wrapper.add_meta_once(self.meta) grid_call = [ - texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + pexpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes ] + [meta] grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" wrapper.writeline( @@ -576,6 +582,7 @@ def generate( epilogue_fn=identity, subgraphs=None, mutated_inputs=None, + call_sizes=None, **kwargs, ): """This function generates a TritonTemplateCaller @@ -610,6 +617,9 @@ def generate( "64-bit indexing is not yet implemented for triton templates" ) + if call_sizes is None: + call_sizes = layout.size + kernel_options = dict( input_nodes=input_nodes, defines=defines, @@ -617,13 +627,14 @@ def generate( num_warps=num_warps, grid_fn=self.grid, meta=kwargs, - call_sizes=layout.size, + call_sizes=call_sizes, prefix_args=prefix_args, suffix_args=suffix_args, epilogue_fn=epilogue_fn, index_dtype="tl.int32", subgraphs=subgraphs, ) + with patch.object( V.graph, "get_dtype", self._fake_get_dtype(fake_out) ), TritonTemplateKernel( @@ -697,7 +708,7 @@ def make_kernel_render(out_node): assert mod.__file__ is not None grid = self.grid( *V.graph.sizevars.size_hints( - layout.size, + call_sizes, fallback=config.unbacked_symint_fallback, ), kwargs, @@ -769,7 +780,7 @@ def to_callable(self): def call_name(self): return f"extern_kernels.{self.name}" - @functools.lru_cache(None) + @functools.lru_cache(None) # noqa: B019 def hash_key(self): fn = self.to_callable() parts = [ diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index bc8803a5e715..f48c0884d3ad 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools import logging @@ -161,9 +162,9 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base"), - sympy.Wild("divisor"), - sympy.Wild("modulus"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), ), visit_modular_indexing, ) @@ -171,8 +172,8 @@ def visit_modular_indexing(base, divisor, modulus): if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base"), - sympy.Wild("divisor"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), ), visit_indexing_div, ) @@ -192,7 +193,16 @@ def _simplify_loops_impl( """ sizes = list(map(self.simplify, sizes)) - strides = [self.stride_vars(x, index_vars) for x in index_formulas] + strides = [ + # index_formulas may contain boolean expressions (e.g. s0 < 10), + # for which "strides" don't make sense so we ignore them here. + # NOTE: These expressions may still block merging dims in the sound + # substitution test performed in can_merge_dims. 
+ self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + for x in index_formulas + ] assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) for i in range(len(sizes)): @@ -583,6 +593,137 @@ def lookup_precomputed_size(self, expr: Expr) -> Expr: def free_symbols(self) -> Set[sympy.Symbol]: return set(self.var_to_val.keys()) - set(self.replacements.keys()) + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. x is non negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should conatins a nested ModularIndex + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should constains a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, Tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplfied. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' can not. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! + + Return false if this optimization can be applied; + Return the new expression and the denominator otherwise. + The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, term like '2*s1*x1' has 3 child nodes. + # - A integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correceness of the transformation + # for non-negative integers. 
+ if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # can not handle multi FloorDiv yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.Integer(0) + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + def join_dimensions(expr: Expr) -> Expr: if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): @@ -604,11 +745,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0]) - base = sympy.Wild("base") - divisor = sympy.Wild("divisor") - mod1 = sympy.Wild("modulus") - mod2 = sympy.Wild("modulus2") + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 9413ac1b2659..4f7eec8ff50c 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for lowering subgraphs used by higher order operators """ diff --git a/torch/_inductor/test_case.py b/torch/_inductor/test_case.py index 3933c9dbc004..3acc68ff22a5 100644 --- a/torch/_inductor/test_case.py +++ b/torch/_inductor/test_case.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import os diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index e8421722568c..3c105ba7db2d 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -1,24 +1,26 @@ +# mypy: allow-untyped-defs import torch.library from torch import Tensor from torch.autograd import Function -_test_lib_def = torch.library.Library("_inductor_test", "DEF") -_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) -_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") -for dispatch_key in ("CPU", "CUDA", "Meta"): - _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + class Realize(Function): + @staticmethod + def forward(ctx, x): + return torch.ops._inductor_test.realize(x) -class Realize(Function): - @staticmethod - def forward(ctx, x): - return torch.ops._inductor_test.realize(x) + @staticmethod + def backward(ctx, grad_output): + return grad_output - @staticmethod - def 
backward(ctx, grad_output): - return grad_output - - -def realize(x: Tensor) -> Tensor: - return Realize.apply(x) + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 8b66b496fd43..d39713be81dd 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import collections @@ -51,6 +52,7 @@ from torch._dynamo.utils import detect_fake_mode from torch.autograd import DeviceType from torch.autograd.profiler_util import EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import make_symbol, SymT @@ -63,7 +65,36 @@ _T = TypeVar("_T") VarRanges = Dict[sympy.Expr, sympy.Expr] -ALIGNMENT = 16 +GPU_ALIGN_BYTES = 16 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: @@ -192,7 +223,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(numer, denom) + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes @@ -390,7 +421,7 @@ def sort_func(elem): RV = TypeVar("RV", covariant=True) -class CachedMethod(Generic[P, RV], Protocol): +class CachedMethod(Protocol, Generic[P, RV]): @staticmethod def clear_cache(self) -> None: ... @@ -1373,7 +1404,8 @@ def pass_execution_and_save(func, gm, inp, msg): print(f"Before:\n{gm.graph}", file=f) print(gm.graph, file=before_io) start_time = datetime.now() - func(gm.graph) + with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform): + func(gm.graph) time_elapsed = datetime.now() - start_time # recompile graph stable_topological_sort(gm.graph) @@ -1554,7 +1586,9 @@ def tensor_is_aligned(tensor: torch.Tensor): # but symbolic storage_offsets are. For consistency, we suppress guard creation # upon performing this check: that ensures that we don't add recompiles when we # add this logic. 
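The alignment helpers added to torch/_inductor/utils.py above are the usual power-of-two rounding trick plus a symbolic variant. A small sketch of the same arithmetic, assuming only sympy and the 64-byte constant from the hunk (the helper is renamed round_up here to make clear it is an illustration, not the patched function):

import sympy

ALIGN_BYTES = 64  # must be a power of two, as the hunk asserts

def round_up(nbytes: int) -> int:
    # (n + A - 1) & -A rounds n up to the next multiple of A when A is a power of two
    return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES

assert round_up(1) == 64 and round_up(64) == 64 and round_up(65) == 128

# _is_aligned proves alignment symbolically via gcd: gcd(128*s, 64) == 64,
# so align(128*s) can fold to 128*s without knowing the value of s.
s = sympy.Symbol("s", integer=True, positive=True)
assert sympy.gcd(128 * s, ALIGN_BYTES) == ALIGN_BYTES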
- return (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % ALIGNMENT == 0 + return ( + tensor.storage_offset() * get_dtype_size(tensor.dtype) + ) % GPU_ALIGN_BYTES == 0 def should_assume_input_aligned(example_input: torch.Tensor): diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 07c6ea8190a6..ac8d3c640141 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file provides a number of "global" variables/handlers that are actually thread local and dynamically scoped, with Inductor patching them to various diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 3e952765695f..976d0c7458e7 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import tempfile from collections import defaultdict diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 57458a0801ab..4ed425f0435a 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The weak_script annotation needs to be here instead of inside torch/jit/ so it can be used in other places in torch/ (namely torch.nn) without running into diff --git a/torch/_lazy/__init__.py b/torch/_lazy/__init__.py index 249ce9b11578..c074abd14372 100644 --- a/torch/_lazy/__init__.py +++ b/torch/_lazy/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading import torch._C._lazy diff --git a/torch/_lazy/closure.py b/torch/_lazy/closure.py index 07f1055ee827..32b2c58ba2b8 100644 --- a/torch/_lazy/closure.py +++ b/torch/_lazy/closure.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import threading from queue import Empty as EmptyQueue, Queue diff --git a/torch/_lazy/computation.py b/torch/_lazy/computation.py index 27b73c42e5c0..17a61e36cb9f 100644 --- a/torch/_lazy/computation.py +++ b/torch/_lazy/computation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy import torch._C._lazy_ts_backend diff --git a/torch/_lazy/config.py b/torch/_lazy/config.py index e7a4d1dd24f8..f7ebca12de7f 100644 --- a/torch/_lazy/config.py +++ b/torch/_lazy/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/debug.py b/torch/_lazy/debug.py index 286aa049280c..84534fb23250 100644 --- a/torch/_lazy/debug.py +++ b/torch/_lazy/debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/device_context.py b/torch/_lazy/device_context.py index 840c7f8e50d0..bc47835fd912 100644 --- a/torch/_lazy/device_context.py +++ b/torch/_lazy/device_context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading from typing import Any, Dict diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py index 033d000c69d8..7c1cb95855b9 100644 --- a/torch/_lazy/extract_compiled_graph.py +++ b/torch/_lazy/extract_compiled_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import itertools diff --git a/torch/_lazy/ir_cache.py b/torch/_lazy/ir_cache.py index 4270684d2943..a6e654566f29 100644 --- a/torch/_lazy/ir_cache.py +++ b/torch/_lazy/ir_cache.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/metrics.py b/torch/_lazy/metrics.py index 2d7db7305567..a77981feb90d 100644 --- a/torch/_lazy/metrics.py +++ b/torch/_lazy/metrics.py @@ -1,3 +1,4 @@ +# 
mypy: allow-untyped-defs import torch._C._lazy diff --git a/torch/_lazy/ts_backend.py b/torch/_lazy/ts_backend.py index 184223771932..5c6ce13746e9 100644 --- a/torch/_lazy/ts_backend.py +++ b/torch/_lazy/ts_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C._lazy_ts_backend diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index 2946b743ee53..1f0f4c87bab7 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools from typing import Callable, Optional diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index ebd35361a940..1ff5696417f3 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from typing import Any, Callable, Optional, Protocol diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 3272ffc1a18f..ce692f16a097 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import weakref from typing import ( @@ -453,7 +454,7 @@ def _register_to_dispatcher(self) -> None: lib.define( schema_str, - tags=[_C.Tag.pt2_compliant_tag], + tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order], ) self._opoverload = _library.utils.lookup_op(self._qualname) diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index d77989cd829b..f206b68fc3be 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Any, Dict, Optional, Protocol, Tuple @@ -16,6 +17,23 @@ def __init__(self, wrapped_obj: Any, script_class_name: str): self.script_class_name = script_class_name +class FakeScriptMethod: + def __init__( + self, + self_fake_obj: FakeScriptObject, + method_name: str, + schema: Optional[torch.FunctionSchema], + ): + self.self_fake_obj = self_fake_obj + self.method_name = method_name + self.schema = schema + + def __call__(self, *args, **kwargs): + from torch._higher_order_ops.torchbind import call_torchbind + + return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs) + + class HasStaticMethodFromReal(Protocol): @classmethod def from_real(cls, real_obj: torch.ScriptObject): @@ -95,25 +113,25 @@ def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject: fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) - def _call_torchbind(method_name): - from torch._higher_order_ops.torchbind import call_torchbind - - def wrapped(self_, *args, **kwargs): - return call_torchbind(self_, method_name, *args, **kwargs) - - return wrapped - fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name()) # type: ignore[attr-defined] + for name in x._method_names(): # type: ignore[attr-defined] attr = getattr(fake_x, name, None) if attr: if not callable(attr): raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + real_attr = getattr(x, name) # type: ignore[attr-defined] + + # real attr sometimes is not torch.ScriptMethod thus doesn't have schema e.g. 
__init___ or __eq__ + method_schema: Optional[torch.FunctionSchema] = None + if isinstance(real_attr, torch.ScriptMethod): + method_schema = real_attr.schema # type: ignore[attr-defined] + setattr( fake_x_wrapped, name, - _call_torchbind(name).__get__(fake_x_wrapped), + FakeScriptMethod(fake_x_wrapped, name, method_schema), ) else: log.warning("fake object of %s doesn't implement method %s.", x, name) diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index fd03f9182434..6305375e4433 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import typing diff --git a/torch/_library/simple_registry.py b/torch/_library/simple_registry.py index 64a543e99b0b..65ecf8ef0d75 100644 --- a/torch/_library/simple_registry.py +++ b/torch/_library/simple_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .abstract_impl import AbstractImplHolder __all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] diff --git a/torch/_library/utils.py b/torch/_library/utils.py index d3577dbbf9d1..27d1ef92b5b3 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import inspect import sys diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index c9d5cde41f60..fd5f574ad7eb 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Various linear algebra utility methods for internal use. """ @@ -43,30 +44,9 @@ def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: return torch.matmul(A, B) -def conjugate(A): - """Return conjugate of tensor A. - - .. note:: If A's dtype is not complex, A is returned. - """ - if A.is_complex(): - return A.conj() - return A - - -def transpose(A): - """Return transpose of a matrix or batches of matrices.""" - ndim = len(A.shape) - return A.transpose(ndim - 1, ndim - 2) - - -def transjugate(A): - """Return transpose conjugate of a matrix or batches of matrices.""" - return conjugate(transpose(A)) - - def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: """Return bilinear form of matrices: :math:`X^T A Y`.""" - return matmul(transpose(X), matmul(A, Y)) + return matmul(X.mT, matmul(A, Y)) def qform(A: Optional[Tensor], S: Tensor): diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 6ca1e7294217..3f7bdf456c39 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Locally Optimal Block Preconditioned Conjugate Gradient methods. """ # Author: Pearu Peterson @@ -924,7 +925,7 @@ def _update_ortho(self): S_, mm( Z[:, n - nc :], - _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])), + _utils.basis(Z[: n - nc, n - nc :].mT), ), ) np = P.shape[-1] @@ -1045,7 +1046,7 @@ def _get_svqb( # The original algorithm 4 from [DuerschPhD2015]. d_col = (d**-0.5).reshape(d.shape[0], 1) - DUBUD = (UBU * d_col) * _utils.transpose(d_col) + DUBUD = (UBU * d_col) * d_col.mT E, Z = _utils.symeig(DUBUD) t = tau * abs(E).max() if drop: @@ -1057,7 +1058,7 @@ def _get_svqb( else: E[(torch.where(E < t))[0]] = t - return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5) + return torch.matmul(U * d_col.mT, Z * E**-0.5) def _get_ortho(self, U, V): """Return B-orthonormal U with columns are B-orthogonal to V. 
@@ -1105,7 +1106,7 @@ def _get_ortho(self, U, V): BV_norm = torch.norm(mm_B(self.B, V)) BU = mm_B(self.B, U) - VBU = mm(_utils.transpose(V), BU) + VBU = mm(V.mT, BU) i = j = 0 stats = "" for i in range(i_max): @@ -1125,7 +1126,7 @@ def _get_ortho(self, U, V): self.ivars["ortho_j"] = j return U BU = mm_B(self.B, U) - UBU = mm(_utils.transpose(U), BU) + UBU = mm(U.mT, BU) U_norm = torch.norm(U) BU_norm = torch.norm(BU) R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) @@ -1136,7 +1137,7 @@ def _get_ortho(self, U, V): self.fvars[vkey] = rerr if rerr < tau_ortho: break - VBU = mm(_utils.transpose(V), BU) + VBU = mm(V.mT, BU) VBU_norm = torch.norm(VBU) U_norm = torch.norm(U) rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 28a57e39bf3b..bfc071b0d53a 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import hashlib import itertools @@ -795,7 +796,11 @@ def format(self, record): ) if self._is_trace: assert s == "" - r = f"{prefix} {json.dumps(record.metadata)}" + try: + r = f"{prefix} {json.dumps(record.metadata)}" + except TypeError: + log.warning("failing metadata: %r", record.metadata) + raise if record.payload is not None: r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) return r diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 7a920ef4a455..4641c4c4717c 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -17,10 +17,11 @@ def get_approximate_basis( """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` - approximates :math:`A - M`. + approximates :math:`A - M`. without instantiating any tensors + of the size of :math:`A` or :math:`M`. .. note:: The implementation is based on the Algorithm 4.4 from - Halko et al, 2009. + Halko et al., 2009. .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be @@ -46,7 +47,7 @@ def get_approximate_basis( default value 2 is more than enough. M (Tensor, optional): the input tensor's mean of size - :math:`(*, 1, n)`. + :math:`(*, m, n)`. 
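The _linalg_utils and _lobpcg changes above drop the local transpose/transjugate helpers in favor of the built-in Tensor.mT and Tensor.mH properties. A quick equivalence check (illustration only):

import torch

A = torch.randn(3, 4, 5, dtype=torch.complex64)
assert torch.allclose(A.mT, A.transpose(-2, -1))         # batched matrix transpose
assert torch.allclose(A.mH, A.transpose(-2, -1).conj())  # conjugate (Hermitian) transpose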
References:: - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding @@ -57,27 +58,27 @@ def get_approximate_basis( """ niter = 2 if niter is None else niter - m, n = A.shape[-2:] - dtype = _utils.get_floating_dtype(A) + dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype matmul = _utils.matmul - R = torch.randn(n, q, dtype=dtype, device=A.device) + R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device) # The following code could be made faster using torch.geqrf + torch.ormqr # but geqrf is not differentiable - A_H = _utils.transjugate(A) - if M is None: - Q = torch.linalg.qr(matmul(A, R)).Q - for i in range(niter): - Q = torch.linalg.qr(matmul(A_H, Q)).Q - Q = torch.linalg.qr(matmul(A, Q)).Q - else: - M_H = _utils.transjugate(M) - Q = torch.linalg.qr(matmul(A, R) - matmul(M, R)).Q - for i in range(niter): - Q = torch.linalg.qr(matmul(A_H, Q) - matmul(M_H, Q)).Q - Q = torch.linalg.qr(matmul(A, Q) - matmul(M, Q)).Q + X = matmul(A, R) + if M is not None: + X = X - matmul(M, R) + Q = torch.linalg.qr(X).Q + for i in range(niter): + X = matmul(A.mH, Q) + if M is not None: + X = X - matmul(M.mH, Q) + Q = torch.linalg.qr(X).Q + X = matmul(A, Q) + if M is not None: + X = X - matmul(M, Q) + Q = torch.linalg.qr(X).Q return Q @@ -89,19 +90,26 @@ def svd_lowrank( ) -> Tuple[Tensor, Tensor, Tensor]: r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, batches of matrices, or a sparse matrix :math:`A` such that - :math:`A \approx U diag(S) V^T`. In case :math:`M` is given, then + :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then SVD is computed for the matrix :math:`A - M`. .. note:: The implementation is based on the Algorithm 5.1 from - Halko et al, 2009. + Halko et al., 2009. - .. note:: To obtain repeatable results, reset the seed for the - pseudorandom number generator + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. - .. note:: The input is assumed to be a low-rank matrix. + .. note:: This is a randomized method. To obtain repeatable results, + set the seed for the pseudorandom number generator .. note:: In general, use the full-rank SVD implementation - :func:`torch.linalg.svd` for dense matrices due to its 10-fold + :func:`torch.linalg.svd` for dense matrices due to its 10x higher performance characteristics. The low-rank SVD will be useful for huge sparse matrices that :func:`torch.linalg.svd` cannot handle. @@ -116,7 +124,7 @@ def svd_lowrank( integer, and defaults to 2 M (Tensor, optional): the input tensor's mean of size - :math:`(*, 1, n)`, which will be broadcasted + :math:`(*, m, n)`, which will be broadcasted to the size of A in this function. 
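As re-documented above, get_approximate_basis is a randomized range finder: it returns Q with q orthonormal columns such that Q Q^H A reproduces A whenever rank(A) <= q. A short numerical illustration of that property, written against public torch APIs rather than the private helpers in this file:

import torch

torch.manual_seed(0)
# an exactly rank-5 matrix
A = torch.randn(100, 5, dtype=torch.float64) @ torch.randn(5, 80, dtype=torch.float64)
# one-shot range finder (the niter=0 case): Q spans the column space of A
Q = torch.linalg.qr(A @ torch.randn(80, 5, dtype=torch.float64)).Q
assert Q.shape == (100, 5)
assert torch.allclose(Q @ (Q.mH @ A), A)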
References:: @@ -144,48 +152,30 @@ def _svd_lowrank( niter: Optional[int] = 2, M: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: + # Algorithm 5.1 in Halko et al., 2009 + q = 6 if q is None else q m, n = A.shape[-2:] matmul = _utils.matmul - if M is None: - M_t = None - else: + if M is not None: M = M.broadcast_to(A.size()) - M_t = _utils.transpose(M) - A_t = _utils.transpose(A) - - # Algorithm 5.1 in Halko et al 2009, slightly modified to reduce - # the number conjugate and transpose operations - if m < n or n > q: - # computing the SVD approximation of a transpose in - # order to keep B shape minimal (the m < n case) or the V - # shape small (the n > q case) - Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) - Q_c = _utils.conjugate(Q) - if M is None: - B_t = matmul(A, Q_c) - else: - B_t = matmul(A, Q_c) - matmul(M, Q_c) - assert B_t.shape[-2] == m, (B_t.shape, m) - assert B_t.shape[-1] == q, (B_t.shape, q) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) - V = Vh.mH - V = Q.matmul(V) - else: - Q = get_approximate_basis(A, q, niter=niter, M=M) - Q_c = _utils.conjugate(Q) - if M is None: - B = matmul(A_t, Q_c) - else: - B = matmul(A_t, Q_c) - matmul(M_t, Q_c) - B_t = _utils.transpose(B) - assert B_t.shape[-2] == q, (B_t.shape, q) - assert B_t.shape[-1] == n, (B_t.shape, n) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) - V = Vh.mH - U = Q.matmul(U) + + # Assume that A is tall + if m < n: + A = A.mH + if M is not None: + M = M.mH + + Q = get_approximate_basis(A, q, niter=niter, M=M) + B = matmul(Q.mH, A) + if M is not None: + B = B - matmul(Q.mH, M) + U, S, Vh = torch.linalg.svd(B, full_matrices=False) + V = Vh.mH + U = Q.matmul(U) + + if m < n: + U, V = V, U return U, S, V @@ -198,7 +188,7 @@ def pca_lowrank( This function returns a namedtuple ``(U, S, V)`` which is the nearly optimal approximation of a singular value decomposition of - a centered matrix :math:`A` such that :math:`A = U diag(S) V^T`. + a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}` .. note:: The relation of ``(U, S, V)`` to PCA is as follows: @@ -293,7 +283,7 @@ def pca_lowrank( ) ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) - M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t)) + M = torch.sparse.mm(C_t, ones_m1_t).mT return _svd_lowrank(A, q, niter=niter, M=M) else: C = A.mean(dim=(-2,), keepdim=True) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 89b452bca505..89262a7a203c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,6 +1,6 @@ +# mypy: allow-untyped-defs import math from enum import Enum -from functools import partial from typing import List, Optional, Sequence, Tuple, Union import torch @@ -19,7 +19,6 @@ corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, - FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -369,8 +368,14 @@ def meta_copy_(self, src, non_blocking=False): # which runs most of the meta checks that we care about. # In theory, we should make this more robust by carefully # auditing our C++ copy_() kernel and copying the checks here. 
- - if torch._debug_has_internal_overlap(self) == 1: # 1 == MemOverlap::Yes + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are + # calling an actual copy_, you'll get that automatically + # https://github.com/pytorch/pytorch/issues/122477 + if ( + not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1 + ): # 1 == MemOverlap::Yes raise RuntimeError( "more than one element of the written-to tensor refers to a single memory location" ) @@ -3135,327 +3140,6 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): return self.new_empty(self.size()) -def register_meta_foreach(ops): - def wrapper(fn): - def register(op): - op_name = str(op).split(".")[1] - scalar_op = getattr(aten, op_name.replace("_foreach_", "")) - - _add_op_to_registry( - meta_table, - op, - partial( - fn, - _scalar_op=scalar_op, - ), - ) - - pytree.tree_map_(register, ops) - return fn - - return wrapper - - -@register_meta_foreach( - [ - aten._foreach_abs, - aten._foreach_acos, - aten._foreach_asin, - aten._foreach_atan, - aten._foreach_ceil, - aten._foreach_cos, - aten._foreach_cosh, - aten._foreach_erf, - aten._foreach_erfc, - aten._foreach_exp, - aten._foreach_expm1, - aten._foreach_frac, - aten._foreach_floor, - aten._foreach_lgamma, - aten._foreach_log, - aten._foreach_log10, - aten._foreach_log1p, - aten._foreach_log2, - aten._foreach_neg, - aten._foreach_reciprocal, - aten._foreach_round, - aten._foreach_sigmoid, - aten._foreach_sign, - aten._foreach_sin, - aten._foreach_sinh, - aten._foreach_sqrt, - aten._foreach_tan, - aten._foreach_tanh, - aten._foreach_trunc, - aten._foreach_zero, - aten._foreach_add, - aten._foreach_sub, - aten._foreach_mul, - aten._foreach_div, - aten._foreach_clamp_min, - aten._foreach_clamp_max, - aten._foreach_lerp, - ], -) -def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs): - torch._check( - isinstance(args[0], list), - lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."), - ) - - nelem = len(args[0]) - torch._check( - nelem > 0, - lambda: ("Tensor list must have at least one tensor."), - ) - - nlists = 1 - for iarg, arg in enumerate(args[1:]): - if isinstance(arg, list): - nlists += 1 - torch._check( - len(arg) == nelem, - lambda: ( - f"self and argument-{iarg+2} must match in length, " - f"but got {nelem} and {len(arg)}." - ), - ) - elif isinstance(arg, Tensor): - torch._check( - arg.dim() == 0 and arg.numel() == 1, - lambda: ( - "scalar tensor expected to be 0 dim but it has " - f"{arg.dim()} dimensions and {arg.numel()} elements." 
- ), - ) - else: - break - - result = [] - for elem in range(nelem): - each_args = [args[i][elem] for i in range(nlists)] - result.append(_scalar_op(*each_args, *args[nlists:], **kwargs)) - - return result - - -@register_meta_foreach( - [ - aten._foreach_abs_, - aten._foreach_acos_, - aten._foreach_asin_, - aten._foreach_atan_, - aten._foreach_ceil_, - aten._foreach_cos_, - aten._foreach_cosh_, - aten._foreach_erf_, - aten._foreach_erfc_, - aten._foreach_exp_, - aten._foreach_expm1_, - aten._foreach_frac_, - aten._foreach_floor_, - aten._foreach_lgamma_, - aten._foreach_log_, - aten._foreach_log10_, - aten._foreach_log1p_, - aten._foreach_log2_, - aten._foreach_neg_, - aten._foreach_reciprocal_, - aten._foreach_round_, - aten._foreach_sigmoid_, - aten._foreach_sign_, - aten._foreach_sin_, - aten._foreach_sinh_, - aten._foreach_sqrt_, - aten._foreach_tan_, - aten._foreach_tanh_, - aten._foreach_trunc_, - aten._foreach_zero_, - aten._foreach_add_, - aten._foreach_sub_, - aten._foreach_mul_, - aten._foreach_div_, - aten._foreach_clamp_min_, - aten._foreach_clamp_max_, - aten._foreach_lerp_, - aten._foreach_copy_, - ] -) -def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs): - _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs) - return - - -@register_meta([aten._foreach_pow_.Scalar]) -def meta__foreach_pow__scalar(self, exponent): - torch._check( - isinstance(exponent, FloatLike), - lambda: f"exponent must be a float but got {type(exponent)}", - ) - return - - -@register_meta([aten._foreach_pow.ScalarAndTensor]) -def meta__foreach_pow_scalar_and_tensor(self, exponent): - # Only foreach_pow has a ScalarAndTensor method and needs special - # handling because it does not work with _meta_foreach_out_of_place. - torch._check( - isinstance(exponent, List), - lambda: f"exponent must be a tensor list but got {type(exponent)}", - ) - return [torch.empty_like(e) for e in exponent] - - -@register_meta([aten._foreach_norm]) -def meta__foreach_norm(self, ord=2, dtype=None): - torch._check( - isinstance(self, list), - lambda: f"self must be a tensor list but got {type(self)}", - ) - torch._check( - isinstance(ord, Number), - lambda: f"ord must be an integer but got {type(ord)}", - ) - torch._check( - dtype is None or isinstance(dtype, torch.dtype), - lambda: f"dtype must be either None or torch.dtype but got {type(dtype)}", - ) - return [ - torch.empty( - (), - device=t.device, - dtype=t.dtype.to_real() if dtype is None else dtype.to_real(), - ) - for t in self - ] - - -def _check_foreach_binop_tensor_lists(self, other): - torch._check( - isinstance(self, List) and isinstance(other, List), - lambda: ( - "The first two arguments of must be List[Tensor], " - f"but got {type(self)} and {type(other)}." - ), - ) - torch._check( - len(self) > 0 and len(self) == len(other), - lambda: ( - "self and other must be non-empty and match in length, " - f"but got {len(self)} and {len(other)}." - ), - ) - - -@register_meta( - [ - aten._foreach_maximum, - aten._foreach_minimum, - ] -) -def meta__foreach_binop_scalar(*args): - # aten.maximum(Tensor, Scalar) does not exist. 
- return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min) - - -@register_meta( - [ - aten._foreach_maximum_, - aten._foreach_minimum_, - ] -) -def meta__foreach_binop__scalar(*args): - # aten.maximum(Tensor, Scalar) does not exist - _meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_) - return - - -@register_meta( - [ - aten._foreach_addcdiv.Scalar, - aten._foreach_addcmul.Scalar, - ] -) -def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): - # forach_addcdiv and addcdiv have different signatures and - # cannot use _meta_foreach_out_of_place. - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]), - lambda: ( - "All arguments must be List[Tensor], " - f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - return [torch.empty_like(s) for s in self] - - -@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor]) -def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]) - and isinstance(scalars, torch.Tensor), - lambda: ( - "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, " - f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - -@register_meta( - [ - aten._foreach_addcdiv_.Scalar, - aten._foreach_addcmul_.Scalar, - ] -) -def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2]), - lambda: ( - "All arguments of _foreach_addc*_ must be List[Tensor], " - f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) and len(self) == len(tensor2), - lambda: "All input tensor lists must have the same length", - ) - - -@register_meta( - [ - aten._foreach_addcdiv_.ScalarList, - aten._foreach_addcmul_.ScalarList, - ] -) -def meta__foreach_addcop__scalarlist(self, tensor1, tensor2, scalars): - torch._check( - all(isinstance(l, List) for l in [self, tensor1, tensor2, scalars]), - lambda: ( - "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], List[Scalar], " - f"but got {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" - ), - ) - torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") - torch._check( - len(self) == len(tensor1) - and len(self) == len(tensor2) - and len(self) == len(scalars), - lambda: "All input tensor lists must have the same length", - ) - - @register_meta([aten._fused_adam_.default]) def meta__fused_adam_( self, @@ -3945,6 +3629,14 @@ def meta_masked_fill_(self, mask, value): return self +@register_meta(aten._masked_scale.default) +def meta__masked_scale(self, mask, scale): + masked_scale = self.new_empty(self.size()).to( + memory_format=utils.suggest_memory_format(self) + ) + return masked_scale + + @register_meta(aten.masked_scatter_) def meta_masked_scatter_(self, mask, source): torch._check( @@ 
-5551,10 +5243,29 @@ def meta__efficient_attention_backward( bias_requires_grad: bool, scale: Optional[float] = None, num_splits_key: Optional[int] = None, + shared_storage_dqdkdv: bool = False, ): - grad_query = torch.empty_like(query) - grad_key = torch.empty_like(key) - grad_value = torch.empty_like(value) + if shared_storage_dqdkdv: + torch._check( + query.shape[1] == key.shape[1], + lambda: "seqlen must match for `shared_storage_dqdkdv", + ) + torch._check( + query.shape[3] == key.shape[3], + lambda: "embedding dim must match for `shared_storage_dqdkdv", + ) + chunk = torch.empty( + (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), + dtype=query.dtype, + device=query.device, + ) + grad_query = chunk.select(-3, 0) + grad_key = chunk.select(-3, 1) + grad_value = chunk.select(-3, 2) + else: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) if bias is not None: lastDim = bias.size(-1) @@ -6344,6 +6055,11 @@ def meta_channel_shuffle(input, groups): ) +@register_meta(aten._local_scalar_dense) +def meta_local_scalar_dense(self: Tensor): + raise RuntimeError("Tensor.item() cannot be called on meta tensors") + + def _create_unary_float_meta_func(func): @register_meta(func) @out_wrapper() diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py index cbc9de2de091..3791d17c2e42 100644 --- a/torch/_namedtensor_internals.py +++ b/torch/_namedtensor_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict """ diff --git a/torch/_numpy/_casting_dicts.py b/torch/_numpy/_casting_dicts.py index 513e73ef2efe..b30ce7c55604 100644 --- a/torch/_numpy/_casting_dicts.py +++ b/torch/_numpy/_casting_dicts.py @@ -3,7 +3,7 @@ import torch # These two dicts are autogenerated with autogen/gen_dtypes.py, -# using numpy version 1.23.5. +# using numpy version 1.24.3. 
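With shared_storage_dqdkdv set, the meta kernel above models dq/dk/dv as three slices of a single allocation instead of three separate tensors. A rough sketch of that layout with made-up sizes (the real shapes come from the attention backward):

import torch

q_shape = (2, 128, 8, 64)  # hypothetical (batch, seqlen, heads, head_dim)
chunk = torch.empty(*q_shape[:-2], 3, *q_shape[-2:])
dq, dk, dv = (chunk.select(-3, i) for i in range(3))
assert dq.shape == q_shape
# all three gradients are views into the same storage
assert dq.untyped_storage().data_ptr() == dv.untyped_storage().data_ptr()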
_can_cast_dict = { "no": { @@ -14,6 +14,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -27,6 +30,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -40,6 +46,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -53,6 +62,9 @@ torch.complex64: True, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -66,6 +78,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -79,6 +94,57 @@ torch.complex64: False, torch.complex128: False, torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, torch.int8: False, torch.int16: False, torch.int32: False, @@ -92,6 +158,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: False, torch.int32: False, @@ -105,6 +174,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: False, @@ -118,6 +190,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -131,6 +206,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -144,6 +222,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -159,6 +240,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + 
torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -172,6 +256,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -185,6 +272,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -198,6 +288,9 @@ torch.complex64: True, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -211,6 +304,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -224,6 +320,57 @@ torch.complex64: False, torch.complex128: False, torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, torch.int8: False, torch.int16: False, torch.int32: False, @@ -237,6 +384,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: False, torch.int32: False, @@ -250,6 +400,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: False, @@ -263,6 +416,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -276,6 +432,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -289,6 +448,9 @@ torch.complex64: False, torch.complex128: False, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -304,6 +466,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -317,6 +482,9 @@ 
torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -330,6 +498,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -343,6 +514,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -356,6 +530,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -369,12 +546,63 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: False, torch.int16: True, torch.int32: True, torch.int64: True, torch.bool: False, }, + torch.uint16: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, torch.int8: { torch.float16: True, torch.float32: True, @@ -382,6 +610,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -395,6 +626,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: True, torch.int32: True, @@ -408,6 +642,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: True, @@ -421,6 +658,9 @@ torch.complex64: False, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -434,6 +674,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -449,6 +692,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -462,6 +708,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + 
torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -475,6 +724,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -488,6 +740,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -501,6 +756,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: False, torch.int16: False, torch.int32: False, @@ -514,6 +772,57 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -527,6 +836,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -540,6 +852,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -553,6 +868,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -566,6 +884,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, torch.int8: True, torch.int16: True, torch.int32: True, @@ -579,6 +900,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -594,6 +918,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -607,6 +934,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -620,6 +950,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + 
torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -633,6 +966,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -646,6 +982,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -659,6 +998,57 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -672,6 +1062,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -685,6 +1078,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -698,6 +1094,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -711,6 +1110,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -724,6 +1126,9 @@ torch.complex64: True, torch.complex128: True, torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, torch.int8: True, torch.int16: True, torch.int32: True, @@ -742,6 +1147,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float16, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float16, torch.int16: torch.float32, torch.int32: torch.float64, @@ -755,6 +1163,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.float32, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float32, torch.int16: torch.float32, torch.int32: torch.float64, @@ -768,6 +1179,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.float64, + torch.uint16: torch.float64, + torch.uint32: 
torch.float64, + torch.uint64: torch.float64, torch.int8: torch.float64, torch.int16: torch.float64, torch.int32: torch.float64, @@ -781,6 +1195,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.complex64, + torch.uint16: torch.complex64, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, torch.int8: torch.complex64, torch.int16: torch.complex64, torch.int32: torch.complex128, @@ -794,6 +1211,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.complex128, + torch.uint16: torch.complex128, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, torch.int8: torch.complex128, torch.int16: torch.complex128, torch.int32: torch.complex128, @@ -807,12 +1227,63 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, torch.int64: torch.int64, torch.bool: torch.uint8, }, + torch.uint16: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint16, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int32, + torch.int16: torch.int32, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.uint16, + }, + torch.uint32: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint32, + torch.uint16: torch.uint32, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bool: torch.uint32, + }, + torch.uint64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint64, + torch.uint16: torch.uint64, + torch.uint32: torch.uint64, + torch.uint64: torch.uint64, + torch.int8: torch.float64, + torch.int16: torch.float64, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.uint64, + }, torch.int8: { torch.float16: torch.float16, torch.float32: torch.float32, @@ -820,6 +1291,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, @@ -833,6 +1307,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int16, torch.int16: torch.int16, torch.int32: torch.int32, @@ -846,6 +1323,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int32, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int32, torch.int16: torch.int32, torch.int32: torch.int32, @@ -859,6 +1339,9 @@ torch.complex64: torch.complex128, torch.complex128: torch.complex128, torch.uint8: torch.int64, + torch.uint16: torch.int64, + torch.uint32: 
torch.int64, + torch.uint64: torch.float64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64, @@ -872,6 +1355,9 @@ torch.complex64: torch.complex64, torch.complex128: torch.complex128, torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, torch.int8: torch.int8, torch.int16: torch.int16, torch.int32: torch.int32, diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index f8b8f4f722be..27799adaf563 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -113,6 +113,24 @@ class uint8(unsignedinteger): torch_dtype = torch.uint8 +class uint16(unsignedinteger): + name = "uint16" + typecode = "H" + torch_dtype = torch.uint16 + + +class uint32(signedinteger): + name = "uint32" + typecode = "I" + torch_dtype = torch.uint32 + + +class uint64(signedinteger): + name = "uint64" + typecode = "L" + torch_dtype = torch.uint64 + + # floating point @@ -160,6 +178,7 @@ class bool_(generic): "byte": int8, "short": int16, "longlong": int64, # XXX: is this correct? + "ulonglong": uint64, "ubyte": uint8, "half": float16, "single": float32, @@ -180,7 +199,7 @@ class bool_(generic): # cf tests/core/test_scalar_methods.py sctypes = { "int": [int8, int16, int32, int64], - "uint": [uint8], + "uint": [uint8, uint16, uint32, uint64], "float": [float16, float32, float64], "complex": [complex64, complex128], "others": [bool_], diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index de165d5db768..93f8a8ab1198 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -941,7 +941,7 @@ def choose( return choices[idx_list].squeeze(0) -# ### unique et al ### +# ### unique et al. ### def unique( @@ -1021,7 +1021,7 @@ def resize(a: ArrayLike, new_shape=None): return reshape(a, new_shape) -# ### diag et al ### +# ### diag et al. 
### def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): diff --git a/torch/_ops.py b/torch/_ops.py index 83a7b6b849df..ed8c788b8af6 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import ctypes import importlib diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 498b6fa9a2cb..603658ea6151 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import itertools import operator diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 2c7a030b3509..81cc47dc86e5 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import nullcontext from typing import Any, Callable, Dict, Optional, Sequence diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py index ea3854d04bbd..9683c163827d 100644 --- a/torch/_prims/debug_prims.py +++ b/torch/_prims/debug_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Optional diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py index bb2fafce8726..8d80af720e79 100644 --- a/torch/_prims/executor.py +++ b/torch/_prims/executor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Optional from torch._prims.context import TorchRefsMode diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index 616940d57036..1345ff0334f5 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Tuple import torch diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 10290535f930..c05b0ebf10e7 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import operator @@ -1045,17 +1046,17 @@ def type_to_dtype(typ: type) -> torch.dtype: assert isinstance(typ, type) - if typ is bool: + if typ in (bool, torch.SymBool): return torch.bool - if typ in [int, torch.SymInt]: + if typ in (int, torch.SymInt): return torch.long - if typ in [float, torch.SymFloat]: + if typ in (float, torch.SymFloat): return torch.get_default_dtype() # TODO: sym_complex_float? if typ is complex: return corresponding_complex_dtype(torch.get_default_dtype()) - raise ValueError("Invalid type!") + raise ValueError(f"Invalid type {typ}!") def get_dtype(x: Union[torch.Tensor, NumberType]): @@ -1362,8 +1363,12 @@ def number_type( return type(x) -def expr_type(x: sympy.Expr) -> Type: - if x.is_integer: # type: ignore[attr-defined] +def expr_type(x: sympy.Basic) -> Type: + import sympy + + if x.kind is sympy.core.kind.BooleanKind: + return bool + elif x.is_integer: # type: ignore[attr-defined] return int else: # NB: Not strictly correct, but we don't support SymPy complex or bool. @@ -1470,13 +1475,13 @@ def elementwise_dtypes( import sympy for x in args: - if not isinstance(x, (Number, TensorLike, sympy.Expr)): + if not isinstance(x, (Number, TensorLike, sympy.Basic)): msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" 
raise ValueError(msg) if isinstance(x, Number): highest_type = get_higher_type(highest_type, number_type(x)) - elif isinstance(x, sympy.Expr): + elif isinstance(x, sympy.Basic): highest_type = get_higher_type(highest_type, expr_type(x)) else: # x is a TensorLike diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 9057edc87594..89088aaaf049 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings from functools import wraps diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index bfd208eddb9e..644cf92fda2b 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re import torch._C as C diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 68675c751736..e0157368c62c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import collections import inspect @@ -232,6 +233,7 @@ # View & Shape Ops # "alias", + "alias_copy", "atleast_1d", "atleast_2d", "atleast_3d", @@ -3649,6 +3651,16 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL else: return _a + if a.is_contiguous(): + # Special-cases for nd_to_1d + if len(shape) == 1 and a.ndim > 1: + return torch.as_strided(a, [a.numel()], [1]) + # Special-cases for 1d_to_2d + if len(shape) == 2 and a.ndim == 1: + dim0 = shape[0] + dim1 = shape[1] + return torch.as_strided(a, [dim0, dim1], [dim1, 1]) + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape # NOTE [Reshape Algorithm] @@ -4451,6 +4463,9 @@ def alias(a: TensorLikeType) -> TensorLikeType: return prims.view_of(a) +alias_copy = _make_copy_from_view(alias) + + @register_decomposition(aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index fa1ca2428255..b312f8f6eada 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch._prims_common as utils diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index bffc9a3df2c8..411087b773ea 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial from typing import List, Optional, Tuple, Union diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index dd06febbcd6c..8383d888bbe8 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from functools import wraps from typing import Callable, Optional, Union diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 14ec33cf208f..1e98deaeb16d 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Optional, Union diff --git a/torch/_size_docs.py b/torch/_size_docs.py index 58587be32f1d..b678e3dfd12a 100644 --- a/torch/_size_docs.py +++ b/torch/_size_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to torch.Size functions""" import torch._C diff --git a/torch/_sources.py 
b/torch/_sources.py index 3f56bd8ef247..dd2a863bfc7e 100644 --- a/torch/_sources.py +++ b/torch/_sources.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import functools import inspect diff --git a/torch/_storage_docs.py b/torch/_storage_docs.py index 5d6df58d2b6b..edf5d696ad89 100644 --- a/torch/_storage_docs.py +++ b/torch/_storage_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to Storage functions""" import torch._C diff --git a/torch/_streambase.py b/torch/_streambase.py index b06946523fa3..85e203a3d993 100644 --- a/torch/_streambase.py +++ b/torch/_streambase.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod diff --git a/torch/_strobelight/examples/cli_function_profiler_example.py b/torch/_strobelight/examples/cli_function_profiler_example.py index 8142ef1bdc77..2ddf62f065f5 100644 --- a/torch/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/_strobelight/examples/cli_function_profiler_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._strobelight.cli_function_profiler import ( diff --git a/torch/_strobelight/examples/compile_time_profile_example.py b/torch/_strobelight/examples/compile_time_profile_example.py index 338727206076..93fffa4ad01a 100644 --- a/torch/_strobelight/examples/compile_time_profile_example.py +++ b/torch/_strobelight/examples/compile_time_profile_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 4376d24255ef..90b2c878ab2a 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -258,9 +258,8 @@ def dyn_shape(fake_mode, func, *args, **kwargs): raise DynamicOutputShapeException(func) -@register_op_impl(aten._unique2.default) -def unique2( - fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +def _unique( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False ): if ( fake_mode.shape_env is None @@ -269,7 +268,8 @@ def unique2( # Without symints/symfloats, cannot handle this raise DynamicOutputShapeException(func) - if (nnz := arg.unique_memo) is None: + # Do not use a memo for unique_dim + if dim is not None or (nnz := arg.unique_memo) is None: # Avoid importing sympy at a module level from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, @@ -291,28 +291,59 @@ def unique2( maxval = sys.maxsize - 1 - if not has_free_symbols(arg.numel()): - maxval = int(arg.numel()) + numel = arg.numel() if dim is None else arg.size(dim) + if not has_free_symbols(numel): + maxval = int(numel) _constrain_range_for_size(nnz, max=maxval) - arg.unique_memo = nnz + if dim is None: + arg.unique_memo = nnz - ret = [arg.new_empty((nnz,))] + if dim is None: + ret = [arg.new_empty((nnz,))] + else: + ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])] - if return_inverse: - ret.append(torch.empty_like(arg)) + return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") + if return_inverse or return_if_dim_and_cpu: + inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) else: - ret.append(arg.new_empty(0)) + inverse = arg.new_empty(0) + ret.append(inverse) - if return_counts: - ret.append(torch.empty_like(arg)) + if return_counts or return_if_dim_and_cpu: + counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) else: 
- ret.append(arg.new_empty(0)) + counts = arg.new_empty(0) + ret.append(counts) return tuple(ret) +@register_op_impl(aten._unique2.default) +def unique2( + fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +): + return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) + + +@register_op_impl(aten.unique_dim.default) +def unique_dim( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False +): + return _unique( + fake_mode, + func, + arg, + # normalize dim to be non-negative + dim if dim >= 0 else dim % max(arg.ndim, 1), + sorted, + return_inverse, + return_counts, + ) + + @register_op_impl(aten.repeat_interleave.Tensor) def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): if output_size is None: @@ -872,7 +903,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): max_seqlen_q = kwargs["max_seqlen_q"] max_seqlen_k = kwargs["max_seqlen_k"] compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k + # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k def convert_tensor(t, device): return FakeTensor(fake_mode, t, device) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 2c75847c92a1..c5a549860f47 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import logging @@ -291,6 +292,7 @@ def from_real_tensor( *, source=None, symbolic_context=None, + trace=True, ): # see note [Tensor Fakification and Symbol Caching] if not symbolic_context and not source and shape_env: @@ -333,6 +335,7 @@ def mk_fake_tensor(make_meta_t): callback=mk_fake_tensor, source=source, symbolic_context=symbolic_context, + trace=trace, ) if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") @@ -1725,7 +1728,7 @@ def go(t, real_t): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: + if op_impl_out is not NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): @@ -1925,6 +1928,7 @@ def from_tensor( static_shapes=None, source: Optional[Source] = None, symbolic_context=None, + trace=True, ): shape_env: Optional[ShapeEnv] = self.shape_env if static_shapes is None: @@ -1940,6 +1944,7 @@ def from_tensor( shape_env=shape_env, source=source, symbolic_context=symbolic_context, + trace=trace, ) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index dfef5951ab26..4040774fe225 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings from abc import ABC, abstractmethod diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 647c03861768..4ea0db56aae2 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,13 +1,16 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib + +import dataclasses import warnings import weakref - from dataclasses import dataclass from typing import ( Any, Callable, + ClassVar, ContextManager, Dict, List, @@ -20,6 +23,7 @@ from typing_extensions import TypeAlias import torch +from 
torch._C._autograd import CreationMeta from torch._C._functorch import ( _add_batch_dim, _unwrap_functional_tensor, @@ -33,13 +37,13 @@ maybe_get_level, peek_interpreter_stack, ) +from torch._logging import trace_structured from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils.weak import WeakIdKeyDictionary if TYPE_CHECKING: - from torch._C._autograd import CreationMeta from torch._C._functorch import CInterpreter from torch._guards import Source @@ -142,6 +146,9 @@ def is_sparse_any(t): MetaTensorId: TypeAlias = int +DESCRIBER_NEXT_ID = 0 + + class MetaTensorDescriber: """ Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc @@ -154,6 +161,9 @@ class MetaTensorDescriber: """ def __init__(self, *, copy_data=False): + global DESCRIBER_NEXT_ID + self.id = DESCRIBER_NEXT_ID + DESCRIBER_NEXT_ID += 1 self.next_tensor_id: MetaTensorId = 0 self.next_storage_id: MetaStorageId = 0 # Tensor -> int @@ -161,6 +171,8 @@ def __init__(self, *, copy_data=False): # Storage -> int self.lookup_storage = WeakIdKeyDictionary() self.copy_data = copy_data + self.traced_tensors = set() + self.traced_storages = set() def get_tensor_id(self, t: torch.Tensor): if t not in self.lookup_tensor: @@ -174,19 +186,25 @@ def get_storage_id(self, s: torch.UntypedStorage): self.next_storage_id += 1 return self.lookup_storage[s] - # NB: the describe functions NOT maintain a cache and will happily regen the - # description - - def describe_storage(self, s: torch.UntypedStorage): - return MetaStorageDesc( + def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False): + r = MetaStorageDesc( id=self.get_storage_id(s), size=s.size(), # NB: We don't do the copy yet; copy happens when we start # creating the new storages data=s if self.copy_data else None, ) + if trace and r.id not in self.traced_storages: + trace_structured( + "describe_storage", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_storages.add(r.id) + return r - def describe_tensor(self, t: torch.Tensor, recurse: bool = True): + def describe_tensor( + self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False + ): is_leaf = safe_is_leaf(t) is_view = t._is_view() is_sparse = t.is_sparse @@ -218,7 +236,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): ): # NB: We actually don't use storage to do views, but might as well # put it in for accuracy - storage = self.describe_storage(t.untyped_storage()) + storage = self.describe_storage(t.untyped_storage(), trace=trace) storage_offset = t.storage_offset() stride = None @@ -239,7 +257,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): autograd_meta_from = None current_level = None if is_batchedtensor_v or is_gradtrackingtensor_v: - unwrapped = self.describe_tensor(get_unwrapped(t)) + unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) # xla and lazy tensors present as functional tensors, but we want them # to be handled specially elif is_functional and t.device.type not in ("xla", "lazy"): @@ -249,13 +267,15 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): ) if not is_functorch_wrapped: torch._sync(t) - unwrapped = self.describe_tensor(torch._from_functional_tensor(t)) + unwrapped = self.describe_tensor( + torch._from_functional_tensor(t), trace=trace + ) autograd_meta_from = t else: reapply_views = torch._C._functionalization_reapply_views_tls() # NB: has side effects! 
unwrapped = self.describe_tensor( - _unwrap_functional_tensor(t, reapply_views) + _unwrap_functional_tensor(t, reapply_views), trace=trace ) # TODO: It's pretty suspicious that functional tensors don't have # valid level and thus we just grab whatever the current level @@ -273,12 +293,15 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): if is_traceable_wrapper_subclass_v: assert hasattr(t, "__tensor_flatten__") raw_attrs, ctx = t.__tensor_flatten__() - attrs = {attr: self.describe_tensor(getattr(t, attr)) for attr in raw_attrs} + attrs = { + attr: self.describe_tensor(getattr(t, attr), trace=trace) + for attr in raw_attrs + } type_v = type(t) # TODO: Is it important to enable torch.inference_mode before querying # these values? - return MetaTensorDesc( + r = MetaTensorDesc( id=self.get_tensor_id(t), storage=storage, is_inference=t.is_inference(), @@ -301,6 +324,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): is_view=is_view, is_conj=t.is_conj(), is_neg=t.is_neg(), + is_parameter=isinstance(t, torch.nn.Parameter), is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, is_nested=is_nested, is_functional=is_functional, @@ -318,22 +342,30 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): # TODO: I actually think recursing here is correct, but we have at # least an infinite cycle from base -> values -> base # https://github.com/pytorch/pytorch/issues/122089 - crow_indices=self.describe_tensor(t.crow_indices(), recurse=False) + crow_indices=self.describe_tensor( + t.crow_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} else None, - col_indices=self.describe_tensor(t.col_indices(), recurse=False) + col_indices=self.describe_tensor( + t.col_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} else None, - ccol_indices=self.describe_tensor(t.ccol_indices(), recurse=False) + ccol_indices=self.describe_tensor( + t.ccol_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} else None, - row_indices=self.describe_tensor(t.row_indices(), recurse=False) + row_indices=self.describe_tensor( + t.row_indices(), recurse=False, trace=trace + ) if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} else None, - values=self.describe_tensor(t.values(), recurse=False) + values=self.describe_tensor(t.values(), recurse=False, trace=trace) if recurse and is_sparse_compressed(t) else None, - grad=self.describe_tensor(safe_grad(t)) + grad=self.describe_tensor(safe_grad(t), trace=trace) if safe_grad(t) is not None else None, creation_meta=torch._C._autograd._get_creation_meta(t) @@ -344,7 +376,7 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): if is_batchedtensor_v or is_gradtrackingtensor_v else None, bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, - base=self.describe_tensor(t._base) + base=self.describe_tensor(t._base, trace=trace) if recurse and t._is_view() and t._base is not None else None, fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), @@ -360,6 +392,13 @@ def describe_tensor(self, t: torch.Tensor, recurse: bool = True): current_level=current_level, data=t if self.copy_data else None, ) + if trace and r.id not in self.traced_tensors: + trace_structured( + "describe_tensor", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_tensors.add(r.id) + return r @dataclass(frozen=True) @@ -370,43 +409,59 @@ class 
MetaStorageDesc: # serializable in JSON, you want to do something special here anyway data: Optional[torch.UntypedStorage] + def as_json(self, describer_id): + return { + "id": self.id, + "describer_id": describer_id, + "size": self.size if isinstance(self.size, int) else repr(self.size), + } + @dataclass(frozen=True) class MetaTensorDesc: id: MetaTensorId - is_inference: bool - is_leaf: bool - requires_grad: bool ndim: int dtype: torch.dtype - is_sparse: bool - is_mkldnn: bool - is_functorch_wrapped: bool - is_batchedtensor: bool - is_legacy_batchedtensor: bool - is_gradtrackingtensor: bool - is_view: bool - is_nested: bool - is_traceable_wrapper_subclass: bool - is_functional: bool - is_conj: bool - is_neg: bool device: torch.device - layout: torch.layout + # NB: Sometimes, size, stride and storage_offset contain SymInt, in which # case this is NOT serializable. That only happens when you're # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we # can get rid of this use case entirely. Notably, even if we are # fakeifying a real tensor into a fake tensor with symbolic shapes, the # size here is NOT dynamic + # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic + # goes through this codepath. But it really should not LOL. # NB: size could potentially be None as you can override it and make it # throw an error, but we don't currently have any subclasses that do this # except C++ nested tensor but we're going to have nested int to make this # defined on NJT size: Tuple[int, ...] dynamo_dynamic_indices: List[int] + + layout: torch.layout = torch.strided + is_inference: bool = False + is_leaf: bool = False + requires_grad: bool = False + is_sparse: bool = False + is_mkldnn: bool = False + is_functorch_wrapped: bool = False + is_batchedtensor: bool = False + is_legacy_batchedtensor: bool = False + is_gradtrackingtensor: bool = False + is_view: bool = False + is_nested: bool = False + is_traceable_wrapper_subclass: bool = False + is_functional: bool = False + is_conj: bool = False + is_neg: bool = False + is_parameter: bool = False stride: Optional[Tuple[int, ...]] = None storage_offset: int = 0 + # NB: We have a choice whether or not to store the id or a direct pointer + # to the data structure. For ease of use, we store the data structure, + # but this means that when we serialize, we have to swizzle these pointers + # back into ids (so we have accurate aliasing relationships) storage: Optional[MetaStorageDesc] = None sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed @@ -424,6 +479,19 @@ class MetaTensorDesc: grad: Optional[MetaTensorDesc] = None # Everything below is NOT serializable, need some more work + + _UNSERIALIZABLE: ClassVar[List[str]] = [ + "ctx", + "type", + "fake_mode", + "view_func", + "level", + "current_level", + "functorch_stack", + "autograd_meta_from", + "data", + ] + ctx: Optional[object] = None # is_traceable_wrapper_subclass type: Optional[Type] = None # is_traceable_wrapper_subclass fake_mode: Optional[FakeTensorMode] = None @@ -459,6 +527,44 @@ class MetaTensorDesc: # entirely clear how to make it all lexical again, so we haven't done # it for now. 
+ # NB: This will reference numeric IDs, and it is assumed that you've + # already serialized everything this recursively references + def as_json(self, describer_id): + def json(k, v): + # Some best-effort debugging serialization for unserializable + # fields (feel free to add other special cases as appropriate) + if k in ["data", "autograd_meta_from"]: + return None # never repr these + if k in set(MetaTensorDesc._UNSERIALIZABLE): + return repr(v) + if isinstance(v, (torch.device, torch.dtype, torch.layout)): + return repr(v) + if isinstance(v, torch.SymInt): + return repr(v) + if isinstance(v, (tuple, list)): + return [json(k, v1) for v1 in v] + if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): + return v.id + if isinstance(v, CreationMeta): + return str(v) + if k == "attrs" and isinstance(v, dict): + return {k1: v1.id for k1, v1 in v.items()} + return v + + r = { + field.name: json(field.name, getattr(self, field.name)) + for field in dataclasses.fields(self) + if not ( + getattr(self, field.name) is field.default + or ( + field.name == "dynamo_dynamic_indices" + and not getattr(self, field.name) + ) + ) + } + r.update({"describer_id": describer_id}) + return r + @property def shape(self): return self.size @@ -887,9 +993,10 @@ def symint_visitor_fn(s): def tensor_visitor_fn( visited_t: torch.Tensor, + # These arguments are never passed, we just use them to close + # over these relevant values shape_env=shape_env, callback=callback, - source=source, ): # It's possible to close over an undefined tensor (e.g. NJT's lengths). if visited_t is None: @@ -1431,6 +1538,10 @@ def is_c_of_r(complex_dtype, real_dtype): # Need to reflect this in the generated FakeTensor. if t.storage is not None and t.storage.size == 0: r.untyped_storage().resize_(0) + + if t.is_parameter: + r._is_param = True + self.set_tensor_memo(t, r) return self.get_tensor_memo(t) @@ -1443,6 +1554,10 @@ def __call__( callback=lambda t: t(), source=None, symbolic_context=None, + # Controls whether or not we should dump the tensor metadata to structured logs + # when source is not None. Because we refakify after Dynamo is done, + # we don't want to dump info again from AOTAutograd, it is redundant. + trace=True, ): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now @@ -1475,9 +1590,22 @@ def __call__( # non-Tensor types don't count as hit or miss return t + if source is None: + trace = False + # Describe the tensor. NB: do NOT disable ambient modes, we may need # to query them when figuring out what to put in here - t_desc = self.describer.describe_tensor(t) + t_desc = self.describer.describe_tensor(t, trace=trace) + + if trace: + trace_structured( + "describe_source", + metadata_fn=lambda: { + "describer_id": self.describer.id, + "id": t_desc.id, + "source": source.name(), + }, + ) # Do the meta-fication. 
Here, we disable all the ambient modes, to # better simulate what would be like to re-fakeify from a fresh diff --git a/torch/_tensor.py b/torch/_tensor.py index 712cbc3863d8..5ea2985c2d3f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copyreg import enum import functools diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 88cae5b27aa3..07d94b57f791 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to Tensor functions""" import torch._C diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index eddbe4d8b729..461f3a26b58a 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import dataclasses import math diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ab244dab2635..ad44998d92dc 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Adds docstrings to functions defined in the torch._C module.""" import re diff --git a/torch/_utils.py b/torch/_utils.py index d2bb59239a30..5096b62618df 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copyreg import functools import logging @@ -7,7 +8,7 @@ import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Generic, List, Optional -from typing_extensions import deprecated, ParamSpec +from typing_extensions import ParamSpec import torch @@ -852,10 +853,6 @@ def classproperty(func): return _ClassPropertyDescriptor(func) -@deprecated( - "`is_compiling` is deprecated. Use `torch.compiler.is_compiling()` instead.", - category=FutureWarning, -) def is_compiling() -> bool: """ Indicates whether we are tracing/compiling with torch.compile() or torch.export(). 
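
Aside (not part of the patch): the `torch/_utils.py` hunk just above removes the `FutureWarning` decorator from the internal `torch._utils.is_compiling()` helper; the public query that the removed deprecation message pointed to, `torch.compiler.is_compiling()`, is untouched by this diff. A minimal, hedged usage sketch of that public entry point, which per the docstring shown above reports whether tracing/compilation via torch.compile() or torch.export() is in progress:

    import torch

    def scale(x: torch.Tensor) -> torch.Tensor:
        # torch.compiler.is_compiling() returns True while this function is
        # being traced/compiled and False during ordinary eager execution.
        if torch.compiler.is_compiling():
            return x * 2  # path taken under tracing/compilation
        return x + 1      # eager path
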
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 91b7a3722f55..0001888f18ed 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import os diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index 465e5dbdca1b..cc23d7851eb5 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, Callable, List, Optional, Tuple, Union from typing_extensions import deprecated diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 6c9f3b61ae8b..2ca07d15136c 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Unpickler restricted to loading only state dicts # Restrict constructing types to a list defined in _get_allowed_globals() # Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only @@ -23,7 +24,6 @@ import functools as _functools from collections import Counter, OrderedDict -from inspect import getattr_static from pickle import ( APPEND, APPENDS, @@ -64,8 +64,8 @@ UnpicklingError, ) from struct import unpack -from sys import maxsize, modules -from typing import Any, Dict, List, Type +from sys import maxsize +from typing import Any, Dict, List import torch @@ -170,11 +170,6 @@ def __init__(self, file, *, encoding: str = "bytes"): self.readline = file.readline self.read = file.read self.memo: Dict[int, Any] = {} - # tensor subclass types found from GLOBAL instructions that have passed the criteria - # to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2` - # This enables rebuilding of tensor subclasses defined outside the `torch` package. - # See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria. - self.tensor_subclasses_found: Dict[str, Type] = {} def load(self): """Read a pickled object representation from the open file. @@ -201,121 +196,11 @@ def load(self): elif full_path in _get_user_allowed_globals(): self.append(_get_user_allowed_globals()[full_path]) else: - # The logic in this branch handles user-defined tensor subclasses. - # We can automatically allow and raise and error for anything that is not provably safe. - # [Note: Criteria for allowing out-of-core tensor subclasses] - # GLOBAL '.' instructions will get the class and - # push the string (not the actual type) while adding the type to the dictionary keyed - # by the string onto the unpickler's stack if they satisfy the following conditions: - # (1) The that defines them is in `sys.modules` - # (we will use getattr_static to access it to ensure no code execution) - # (2) They inherit from `torch.Tensor` - # (2) The class is not overriding any of the `torch.Tensor` methods listed here: - # `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, `__set__`, - # and `tp_alloc` - # The methods that we ban overriding were selected in a test-driven manner - # by overriding every callable method on a tensor subclass and determinining - # which might get called during unpickling. - # When executing REDUCE, the string will be appropriately converted back to the type only - # for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods - # we didn't audit. - if module == "__builtin__": - raise RuntimeError( - f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. 
" - "Please use `torch.serialization.add_safe_globals` to allowlist this global " - "if you trust this class/function." - ) - elif module not in modules: - # TODO: add a link here to a doc that explains to users what we mean by trust - raise RuntimeError( - f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was " - f"not in the pre-defined list of allowed globals that are considered safe by the " - "weights_only unpickler for rebuilding state_dicts. This is the expected behavior if " - f"`{full_path}` is a class or function that is not in the list of allowed globals " - f"If `{full_path}` is NOT a tensor subclass, you might consider" - "`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a " - "user-defined tensor subclass not defined in the `torch` package, this error might arise " - f"as we expect `{module}` to be present in `sys.modules` (i.e. it " - "must be imported in the current environment), but this was not the case. " - f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from " - f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to " - "be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should " - "enable the tensor subclass to be unpickled without any arbitrary code execution as long " - # If the user imports and these are overridden the next error will prompt them to use - # torch.serialization.add_safe_globals. - "a sa pre-defined list of methods called when unpickling are not overridden. In " - "particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, " - "`__set__`, as well as the implementation of `tp_alloc`." - ) - else: - try: - class_type = getattr_static(modules[module], name) - except AttributeError as e: - raise AttributeError( - "For safety during weights_only loading, we use inspect.getattr_state to " - f"get {name} from {module}, if {module} implements the descriptor protocol, " - "__getattr__ or __getattribute__ these will not be called." 
- ) from e - # None of the objects here contain any data from the pickle so this is safe - if isinstance(class_type, type) and issubclass( - class_type, torch.Tensor - ): - # getattr is called by the getattr call in `_rebuild_from_type_v2` - custom_get_attribute = ( - class_type.__getattribute__ - is not torch.Tensor.__getattribute__ - ) - custom_get = ( - getattr_static(class_type, "__get__", None) is not None - ) - custom_get_attr = ( - getattr_static(class_type, "__getattr__", None) - is not None - ) - # Tensor.__setstate__ might be called in `_rebuild_from_type_v2` - custom_set_state = ( - class_type.__setstate__ is not torch.Tensor.__setstate__ - ) - # setattr is called in `torch._utils._set_obj_state` - custom_set_attr = ( - class_type.__setattr__ is not object.__setattr__ - ) - custom_set = ( - getattr_static(class_type, "__set__", None) is not None - ) - # tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass` - has_custom_tp_alloc = ( - not torch._C._check_tp_alloc_is_default(class_type) - ) - custom_methods = { - "__getattribute__": custom_get_attribute, - "__getattr__": custom_get_attr, - "__get__": custom_get, - "__setattr__": custom_set_attr, - "__set__": custom_set, - "__setstate__": custom_set_state, - "tp_alloc": has_custom_tp_alloc, - } - if any(custom_methods.values()): - error = "" - for k, v in custom_methods.items(): - error += f" {k}={v}" - raise RuntimeError( - f"Trying to unpickle tensor subclass `{full_path}` that has defined a custom " - f"version for one of these methods:{error}. Please check whether you trust these " - "methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so." - ) - # push the string full_path onto the stack (in REBUILD, there is special logic to - # access this from tensor_subclasses_found for rebuild_from_type_v2) - self.tensor_subclasses_found[full_path] = class_type - self.append(full_path) - else: - raise RuntimeError( - f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " - "Please use `torch.serialization.add_safe_globals` to allowlist this global " - "if you trust this class/function." - ) - + raise RuntimeError( + f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. " + "Please use `torch.serialization.add_safe_globals` to allowlist this global " + "if you trust this class/function." + ) elif key[0] == NEWOBJ[0]: args = self.stack.pop() cls = self.stack.pop() @@ -332,26 +217,6 @@ def load(self): raise RuntimeError( f"Trying to call reduce for unrecognized function {func}" ) - # Special handling for tensor subclass type found in GLOBAL that is pushed - # onto stack as str to prevent it from being used anywhere except the - # second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass - # _rebuild_from_type_v2 is called with args (func, type, func_args, state) - # where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type - # Since we pushed these subclass types onto the stack as strings, convert them to the actual - # type here. 
- if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str: - args_after = args[2:] - if ( - args[0] is torch._utils._rebuild_wrapper_subclass - and type(args[2][0]) is str - ): - new_arg_tuple = ( - self.tensor_subclasses_found[args[2][0]], - ) + args[2][1:] - args_after = (new_arg_tuple,) + args[3:] - args = ( - args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after - ) self.stack[-1] = func(*args) elif key[0] == BUILD[0]: state = self.stack.pop() diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index e33533d2c833..ad8892a3099d 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import warnings diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index a72c6246c99e..bb5cf8204c08 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/ao/__init__.py b/torch/ao/__init__.py index fe6f3a460316..32b1048ad35d 100644 --- a/torch/ao/__init__.py +++ b/torch/ao/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # torch.ao is a package with a lot of interdependencies. # We will use lazy import to avoid cyclic dependencies here. diff --git a/torch/ao/nn/__init__.py b/torch/ao/nn/__init__.py index 88a5a03af1cc..4041508e0b9b 100644 --- a/torch/ao/nn/__init__.py +++ b/torch/ao/nn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # We are exposing all subpackages to the end-user. # Because of possible inter-dependency, we want to avoid # the cyclic imports, thus implementing lazy version diff --git a/torch/ao/nn/intrinsic/__init__.py b/torch/ao/nn/intrinsic/__init__.py index a18bae3eaa38..ca446141106f 100644 --- a/torch/ao/nn/intrinsic/__init__.py +++ b/torch/ao/nn/intrinsic/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .modules import * # noqa: F403 from .modules.fused import _FusedModule # noqa: F403 diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index 4fff70cd76b2..a02365318104 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.nn.utils.parametrize import type_before_parametrizations diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 906206e18e64..91a25a11d50b 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch import torch.nn as nn @@ -289,7 +290,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss missing_keys, unexpected_keys, error_msgs) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.ao.quantization utilities @@ -453,8 +454,8 @@ def forward(self, input): return F.relu(ConvBn1d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): r"""A 
ConvReLU1d module is a fused module of Conv1d and ReLU, attached with @@ -490,8 +491,8 @@ def forward(self, input): self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvBn2d(_ConvBnNd, nn.Conv2d): r""" @@ -585,8 +586,8 @@ def forward(self, input): return F.relu(ConvBn2d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with @@ -622,8 +623,8 @@ def forward(self, input): self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvBn3d(_ConvBnNd, nn.Conv3d): r""" @@ -758,8 +759,8 @@ def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with @@ -813,8 +814,8 @@ def forward(self, input): ) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) def update_bn_stats(mod): if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}: diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py index 5b67283dce4b..89b3a55ff7d2 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.intrinsic as nni @@ -133,7 +134,7 @@ def train(self, mode=True): return self @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod' a float module, either produced by torch.ao.quantization diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 97f7a1dbc339..49cea103982f 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.qat as nnqat import torch.ao.nn.intrinsic as nni @@ -36,8 +37,8 @@ def forward(self, input): return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) def to_float(self): linear = torch.nn.Linear(self.in_features, self.out_features, 
self.bias is not None) diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index a0bccdc0e3d3..b8bff1f5e3a9 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized.dynamic as nnqd import torch.ao.nn.intrinsic as nni @@ -47,8 +48,8 @@ def _get_name(self): return 'DynamicQuantizedLinearReLU' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qlinear_relu): diff --git a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py index 856fa43aac99..eb5104d8c409 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic @@ -37,9 +38,9 @@ def _get_name(self): return 'QuantizedBNReLU2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): # TODO: Add qat support for BNReLU2d - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, bn_relu, output_scale, output_zero_point): @@ -73,9 +74,9 @@ def _get_name(self): return 'QuantizedBNReLU3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): # TODO: Add qat support for BNReLU3d - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, bn_relu, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 6e46aa8915e4..e7df10597331 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic import torch.ao.nn.intrinsic.qat @@ -42,8 +43,8 @@ def _get_name(self): return 'QuantizedConvAdd2d' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -85,8 +86,8 @@ def _get_name(self): return 'QuantizedConvAddReLU2d' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py index 5cdc9004c99c..1ff34f9f5f20 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic @@ -53,13 +54,13 @@ def 
_get_name(self): return 'QuantizedConvReLU1d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -103,13 +104,13 @@ def _get_name(self): return 'QuantizedConvReLU2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): @@ -154,7 +155,7 @@ def _get_name(self): return 'QuantizedConvReLU3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: assert mod.bn.running_var is not None and mod.bn.running_mean is not None mod.weight, mod.bias = fuse_conv_bn_weights( @@ -166,7 +167,7 @@ def from_float(cls, mod): mod.bn.weight, mod.bn.bias, ) - return super().from_float(mod) + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_qconv, output_scale, output_zero_point): diff --git a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py index e774a72dc822..38cb543f4001 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized as nnq import torch.ao.nn.intrinsic as nni @@ -40,8 +41,8 @@ def _get_name(self): return 'QuantizedLinearReLU' @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): @@ -77,7 +78,7 @@ def _get_name(self): return 'QuantizedLinearLeakyReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU' assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process @@ -144,7 +145,7 @@ def _get_name(self): return 'QuantizedLinearTanh' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh' assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' activation_post_process = mod.activation_post_process diff --git a/torch/ao/nn/qat/dynamic/modules/linear.py b/torch/ao/nn/qat/dynamic/modules/linear.py index 
c93dfab1f15b..dd3c06953597 100644 --- a/torch/ao/nn/qat/dynamic/modules/linear.py +++ b/torch/ao/nn/qat/dynamic/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ["Linear"] diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 2b588d84a74e..896bb2d243bd 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch.nn.modules.utils import _single, _pair, _triple @@ -44,7 +45,7 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: @@ -150,8 +151,8 @@ def __init__(self, dtype=dtype) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv2d(_ConvNd, nn.Conv2d): r""" @@ -208,8 +209,8 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv3d(_ConvNd, nn.Conv3d): r""" @@ -266,5 +267,5 @@ def forward(self, input): return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): - return super().from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/qat/modules/embedding_ops.py b/torch/ao/nn/qat/modules/embedding_ops.py index da7f33363742..4269db4abed5 100644 --- a/torch/ao/nn/qat/modules/embedding_ops.py +++ b/torch/ao/nn/qat/modules/embedding_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor import torch.nn as nn @@ -42,7 +43,7 @@ def forward(self, input) -> Tensor: self.sparse) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: `mod` a float module, either produced by torch.ao.quantization utilities @@ -112,7 +113,7 @@ def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: self.padding_idx) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module Args: `mod` a float module, either produced by torch.ao.quantization utilities diff --git a/torch/ao/nn/qat/modules/linear.py b/torch/ao/nn/qat/modules/linear.py index 99d43ed3f6c2..67573a427bae 100644 --- a/torch/ao/nn/qat/modules/linear.py +++ b/torch/ao/nn/qat/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F @@ -41,7 +42,7 @@ def forward(self, input): return F.linear(input, self.weight_fake_quant(self.weight), self.bias) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.ao.quantization utilities or 
directly from user diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index 0023faaaa162..8a45499fd80f 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.jit # this is needed to avoid a circular import from torch import nn @@ -338,6 +339,7 @@ def _forward_impl(self, warnings.warn( "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. " "Use bool tensor instead.", + stacklevel=3, ) attn_mask = attn_mask.to(torch.bool) assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ @@ -359,6 +361,7 @@ def _forward_impl(self, warnings.warn( "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " "Use bool tensor instead.", + stacklevel=3, ) key_padding_mask = key_padding_mask.to(torch.bool) if self.bias_k is not None and self.bias_v is not None: diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index 2c57d1ae9bc5..a311587bd984 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numbers from typing import Optional, Tuple import warnings @@ -122,7 +123,7 @@ def from_params(cls, wi, wh, bi=None, bh=None): return cell @classmethod - def from_float(cls, other): + def from_float(cls, other, use_precomputed_fake_quant=False): assert type(other) == cls._FLOAT_MODULE assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" observed = cls.from_params(other.weight_ih, other.weight_hh, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index 54d2b7e83fed..d47c898efa6a 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Dynamically quantized convolution modules.""" import torch diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py index bf77aa04f0cb..0b8bf245af43 100644 --- a/torch/ao/nn/quantized/dynamic/modules/linear.py +++ b/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.quantized as nnq from torch.ao.nn.quantized.modules.utils import _quantize_weight @@ -77,7 +78,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a dynamic quantized module from a float module or qparams_dict Args: diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index dac1b820d50a..9afab93d1a55 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numbers import warnings from typing_extensions import deprecated @@ -268,7 +269,7 @@ def weight_bias_name(ihhh, layer, suffix): self._all_weight_values = torch.nn.ModuleList(_all_weight_values) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) in {torch.nn.LSTM, torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU' assert hasattr( @@ -495,8 +496,8 @@ def forward(self, 
input, hx=None): return self.forward_tensor(input, hx) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_mod): @@ -747,8 +748,8 @@ def forward(self, input, hx=None): return self.forward_tensor(input, hx) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) @classmethod def from_reference(cls, ref_mod): @@ -839,7 +840,7 @@ def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '' f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}") @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) in {torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \ @@ -1012,8 +1013,8 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: return ret @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class LSTMCell(RNNCellBase): @@ -1055,8 +1056,8 @@ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> self.bias_ih, self.bias_hh) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class GRUCell(RNNCellBase): @@ -1096,5 +1097,5 @@ def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: ) @classmethod - def from_float(cls, mod): - return super().from_float(mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/quantized/functional.py b/torch/ao/nn/quantized/functional.py index 72218184fcfa..ccb450bdd834 100644 --- a/torch/ao/nn/quantized/functional.py +++ b/torch/ao/nn/quantized/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" Functional interface (quantized).""" from typing import List, Optional import warnings diff --git a/torch/ao/nn/quantized/modules/__init__.py b/torch/ao/nn/quantized/modules/__init__.py index 668f765fe3ef..2b87be71fd73 100644 --- a/torch/ao/nn/quantized/modules/__init__.py +++ b/torch/ao/nn/quantized/modules/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch # The quantized modules use `torch.nn` and `torch.ao.nn.quantizable` @@ -98,7 +99,7 @@ def forward(self, X): int(self.zero_point), self.dtype) @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): assert hasattr(mod, 'activation_post_process') scale, zero_point = mod.activation_post_process.calculate_qparams() return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype) @@ -127,5 +128,5 @@ def forward(self, Xq): return Xq.dequantize() @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): return DeQuantize() diff --git a/torch/ao/nn/quantized/modules/activation.py b/torch/ao/nn/quantized/modules/activation.py 
index 6fcd223e5049..3288c84555c4 100644 --- a/torch/ao/nn/quantized/modules/activation.py +++ b/torch/ao/nn/quantized/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from warnings import warn __all__ = [ @@ -46,7 +47,7 @@ def _get_name(self): return 'QuantizedReLU6' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): return ReLU6(mod.inplace) class Hardswish(torch.nn.Hardswish): @@ -69,7 +70,7 @@ def _get_name(self): return 'QuantizedHardswish' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return Hardswish(float(scale), int(zero_point)) @@ -98,7 +99,7 @@ def _get_name(self): return 'QuantizedELU' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return ELU(float(scale), int(zero_point), mod.alpha) @@ -129,7 +130,7 @@ def _get_name(self): return 'QuantizedLeakyReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) @@ -154,7 +155,7 @@ def forward(self, input): return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): output_scale, output_zero_point = mod.activation_post_process.calculate_qparams() return cls(float(output_scale), int(output_zero_point)) @@ -187,7 +188,7 @@ def _get_name(self): return 'QuantizedSoftmax' @staticmethod - def from_float(mod): + def from_float(mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() return Softmax(mod.dim, float(scale), int(zero_point)) @@ -269,7 +270,7 @@ def _get_name(self): return 'QuantizedPReLU' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() qprelu = cls(float(scale), int(zero_point), mod.num_parameters) float_wt = mod.weight.float() diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py index bfef31268cff..975697936d1e 100644 --- a/torch/ao/nn/quantized/modules/batchnorm.py +++ b/torch/ao/nn/quantized/modules/batchnorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.ao.nn.intrinsic as nni @@ -14,7 +15,7 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs)) @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): activation_post_process = mod.activation_post_process if type(mod) == cls._NNI_BN_RELU_MODULE: mod = mod[0] @@ -72,8 +73,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.running_var, self.eps, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): - return _BatchNorm.from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class BatchNorm3d(_BatchNorm): r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`. 
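The hunks above all make the same mechanical change: `from_float` gains a `use_precomputed_fake_quant` keyword with a default of `False`, and modules that delegate to shared base logic forward it along. A minimal sketch of the pattern (class names here are illustrative, not the actual torch.ao classes):

```python
class _QuantizedBase:
    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # Base logic decides whether to reuse precomputed fake-quant
        # statistics from `mod` or to re-run observation.
        return cls()

class QuantizedFoo(_QuantizedBase):
    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # Accept the keyword with a default so existing call sites that
        # omit it keep working, and forward it to the shared base logic.
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
```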
@@ -102,5 +103,5 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.running_var, self.eps, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): - return _BatchNorm.from_float(cls, mod) + def from_float(cls, mod, use_precomputed_fake_quant=False): + return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index ad1a51ee9c3b..ee0bceb336b7 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Quantized convolution modules.""" from typing import Optional, List, TypeVar @@ -215,7 +216,7 @@ def get_qconv(cls, mod, activation_post_process, weight_post_process=None): return qconv @staticmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): if hasattr(mod, "weight_fake_quant"): # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ # ".from_float only works for " + cls.__QAT_MODULE.__name__ @@ -368,14 +369,14 @@ def forward(self, input): return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv2d(_ConvNd): @@ -469,14 +470,14 @@ def forward(self, input): input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) class Conv3d(_ConvNd): @@ -571,14 +572,14 @@ def forward(self, input): input, self._packed_params, self.scale, self.zero_point) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user """ - return _ConvNd.from_float(cls, mod) + return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant) # === Transposed Convolutions === MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) @@ -609,7 +610,7 @@ def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: L return res @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Creates a quantized module from a float module or qparams_dict. 
Args: mod (Module): a float module, either produced by torch.ao.quantization diff --git a/torch/ao/nn/quantized/modules/dropout.py b/torch/ao/nn/quantized/modules/dropout.py index 64110ab53bed..ac934111c7f6 100644 --- a/torch/ao/nn/quantized/modules/dropout.py +++ b/torch/ao/nn/quantized/modules/dropout.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ['Dropout'] @@ -19,7 +20,7 @@ def _get_name(self): return 'QuantizedDropout' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): return cls(mod.p, mod.inplace) @classmethod diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index 25de7fa9b3cf..43b8d65063a4 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch import Tensor # noqa: F401 @@ -137,7 +138,7 @@ def weight(self): return self._packed_params._weight() @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized embedding module from a float module Args: @@ -241,7 +242,7 @@ def _get_name(self): return 'QuantizedEmbeddingBag' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized embedding_bag module from a float module Args: diff --git a/torch/ao/nn/quantized/modules/functional_modules.py b/torch/ao/nn/quantized/modules/functional_modules.py index 96408457a449..77b366c1f6d0 100644 --- a/torch/ao/nn/quantized/modules/functional_modules.py +++ b/torch/ao/nn/quantized/modules/functional_modules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List import torch @@ -239,7 +240,7 @@ def matmul(self, x: Tensor, y: Tensor) -> Tensor: return r @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): assert type(mod) == FloatFunctional, \ "QFunctional.from_float expects an instance of FloatFunctional" scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] diff --git a/torch/ao/nn/quantized/modules/linear.py b/torch/ao/nn/quantized/modules/linear.py index 9d988104a71d..52b0a80a1c90 100644 --- a/torch/ao/nn/quantized/modules/linear.py +++ b/torch/ao/nn/quantized/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections.abc import Iterable import torch @@ -240,12 +241,14 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: self._packed_params.set_weight_bias(w, b) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized module from an observed float module Args: mod (Module): a float module, either produced by torch.ao.quantization utilities or provided by the user + use_precomputed_fake_quant (bool): if True, the module will reuse min/max + values from the precomputed fake quant module. 
""" if hasattr(mod, 'weight_fake_quant'): if type_before_parametrizations(mod) == nniqat.LinearBn1d: @@ -267,8 +270,12 @@ def from_float(cls, mod): activation_post_process = mod.activation_post_process if type_before_parametrizations(mod) == nni.LinearReLU: mod = mod[0] - weight_post_process = mod.qconfig.weight() - weight_post_process(mod.weight) + weight_post_process = mod.qconfig.weight() if not hasattr(mod, "weight_fake_quant") else mod.weight_fake_quant + + if not use_precomputed_fake_quant: + # Observer may not have been called yet + # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound + weight_post_process(mod.weight) dtype = weight_post_process.dtype act_scale, act_zp = activation_post_process.calculate_qparams() assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' diff --git a/torch/ao/nn/quantized/modules/normalization.py b/torch/ao/nn/quantized/modules/normalization.py index f798a241e324..46a18c4e2853 100644 --- a/torch/ao/nn/quantized/modules/normalization.py +++ b/torch/ao/nn/quantized/modules/normalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d'] @@ -30,7 +31,7 @@ def _get_name(self): return 'QuantizedLayerNorm' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.normalized_shape, mod.weight, mod.bias, float(scale), @@ -71,7 +72,7 @@ def _get_name(self): return 'QuantizedGroupNorm' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point), @@ -105,7 +106,7 @@ def _get_name(self): return 'QuantizedInstanceNorm1d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -145,7 +146,7 @@ def _get_name(self): return 'QuantizedInstanceNorm2d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -185,7 +186,7 @@ def _get_name(self): return 'QuantizedInstanceNorm3d' @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py index deb14856a9ef..b75ad0e6b34d 100644 --- a/torch/ao/nn/quantized/modules/rnn.py +++ b/torch/ao/nn/quantized/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = [ diff --git a/torch/ao/nn/quantized/modules/utils.py b/torch/ao/nn/quantized/modules/utils.py index 7c24c0ca31dc..83f478b57ff3 100644 --- a/torch/ao/nn/quantized/modules/utils.py +++ b/torch/ao/nn/quantized/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import torch import itertools diff --git a/torch/ao/nn/quantized/reference/modules/conv.py 
b/torch/ao/nn/quantized/reference/modules/conv.py index 910223056fba..a7c285bc7f67 100644 --- a/torch/ao/nn/quantized/reference/modules/conv.py +++ b/torch/ao/nn/quantized/reference/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torch/ao/nn/quantized/reference/modules/linear.py b/torch/ao/nn/quantized/reference/modules/linear.py index 378fe0eb6eee..9dcba1f4bacd 100644 --- a/torch/ao/nn/quantized/reference/modules/linear.py +++ b/torch/ao/nn/quantized/reference/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index 4120338ce271..f5a53d0ceb3e 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch import Tensor @@ -213,7 +214,7 @@ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> return ret @classmethod - def from_float(cls, mod, weight_qparams_dict): + def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False): ref_mod = cls( mod.input_size, mod.hidden_size, diff --git a/torch/ao/nn/quantized/reference/modules/sparse.py b/torch/ao/nn/quantized/reference/modules/sparse.py index 4890402b875a..8db3f14b08ce 100644 --- a/torch/ao/nn/quantized/reference/modules/sparse.py +++ b/torch/ao/nn/quantized/reference/modules/sparse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn import torch.nn.functional as F from torch import Tensor @@ -76,7 +77,7 @@ def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_we self.padding_idx) @classmethod - def from_float(cls, mod, weight_qparams): + def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False): return cls( mod.num_embeddings, mod.embedding_dim, diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index c4f4d0b46efd..87acd1901f0c 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import typing diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py index 5347b682fb5a..7a28142e4b0d 100644 --- a/torch/ao/nn/sparse/quantized/dynamic/linear.py +++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch @@ -92,7 +93,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized sparse dynamic module from a float module. We only care about the convert at this stage, no need for observers just yet. 
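Many of the files touched from here on gain only the module-level `# mypy: allow-untyped-defs` pragma. A small illustration of what that changes, assuming mypy runs with `disallow_untyped_defs = True` (or `--disallow-untyped-defs`) in a strict configuration:

```python
# mypy: allow-untyped-defs
# Under a strict mypy config, an unannotated function like this would be an
# error ("Function is missing a type annotation"). The per-file pragma above
# suppresses that check for this module only, letting untyped legacy code
# coexist with a repo-wide strict mypy setup.
def scale_and_shift(x, scale, shift):
    return x * scale + shift
```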
diff --git a/torch/ao/nn/sparse/quantized/linear.py b/torch/ao/nn/sparse/quantized/linear.py index 71caa8cbab61..26388e2e2c7b 100644 --- a/torch/ao/nn/sparse/quantized/linear.py +++ b/torch/ao/nn/sparse/quantized/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch @@ -146,7 +147,7 @@ def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, use_precomputed_fake_quant=False): r"""Create a quantized sparse module from a float module. We only care about the convert at this stage, no need for observers just yet. diff --git a/torch/ao/nn/sparse/quantized/utils.py b/torch/ao/nn/sparse/quantized/utils.py index 3d934f578574..46b1cb1e5b71 100644 --- a/torch/ao/nn/sparse/quantized/utils.py +++ b/torch/ao/nn/sparse/quantized/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import threading __all__ = [ diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 3f0df31dfd2a..d6df04bbb5e6 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.quantized as nnq diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index ec5fdaede073..bd827ea16368 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module contains tooling to compare weights and activations across models. Example usage:: diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 8db946ec707a..8b542a3a0b81 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import enum diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index fbd03426790d..ba977eed9962 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import GraphModule, map_arg from torch.fx.graph import Graph, Node diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 1fd6f069ac83..fc96a0da5a2b 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx from torch.fx import ( diff --git a/torch/ao/ns/fx/qconfig_multi_mapping.py b/torch/ao/ns/fx/qconfig_multi_mapping.py index 33efe21e3fe0..915fdb3e7830 100644 --- a/torch/ao/ns/fx/qconfig_multi_mapping.py +++ b/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index bf35a7e531e1..16ac0c9c1504 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import enum import operator diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index 7c03a9f6e36a..0f4ace3de206 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, 
List, Optional import torch from collections import defaultdict diff --git a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py index ad4df426c8e1..76514b19f93c 100644 --- a/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py +++ b/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import wraps import weakref import abc diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index 7f4fcb461e22..f56fa511f991 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import torch from typing import Optional, Tuple, List, Any, Dict diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 20919c140a4d..a90ed9bae523 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from dlrm_s_pytorch import DLRM_Net # type: ignore[import] import numpy as np # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py index 3813f01c0975..1780b68540aa 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_disk_savings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch import time diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py index 4f205312e181..69ddce634237 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch from dlrm_s_pytorch import unpack_batch # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py index 31600118f662..79d5093d5098 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List import torch from dlrm_s_pytorch import unpack_batch # type: ignore[import] diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index 448c9377cc55..f1281729a74b 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn import functional as F from functools 
import reduce diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py index 922c81322cfe..704391268985 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import SUPPORTED_TYPES diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py index 77ca61d599cb..554ad27dd357 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from copy import deepcopy from typing import Any, Optional, Dict, TYPE_CHECKING diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py index 252405de4968..957254284215 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier from torch.ao.pruning._experimental.data_scheduler.base_data_scheduler import BaseDataScheduler import torch diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index 1e76cfc345ac..0e907f42d3bf 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index d8c3d20052ba..fe874c6effc7 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Optional, Union import torch diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py index 357421fb5529..b380ae00adce 100644 --- a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from itertools import chain from operator import getitem import torch diff --git a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py index 9e569c14a6c8..3b65ce59fecc 100644 --- a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast import torch diff --git 
a/torch/ao/pruning/_experimental/pruner/parametrization.py b/torch/ao/pruning/_experimental/pruner/parametrization.py index df94f7093b53..c5aa74e3bc52 100644 --- a/torch/ao/pruning/_experimental/pruner/parametrization.py +++ b/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nn from torch.nn.utils.parametrize import is_parametrized diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index 2b16d4b327a0..f7dcf120f9c3 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Collection of conversion functions for linear / conv2d structured pruning Also contains utilities for bias propagation diff --git a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py index 7f96f0865d30..cf932c272005 100644 --- a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .base_structured_sparsifier import BaseStructuredSparsifier diff --git a/torch/ao/pruning/_mappings.py b/torch/ao/pruning/_mappings.py index 726cbc6b0fc8..70a0c785190f 100644 --- a/torch/ao/pruning/_mappings.py +++ b/torch/ao/pruning/_mappings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = [ "get_static_sparse_quantized_mapping", "get_dynamic_sparse_quantized_mapping", diff --git a/torch/ao/pruning/scheduler/base_scheduler.py b/torch/ao/pruning/scheduler/base_scheduler.py index 3391d3e73cd6..82f02399b7ec 100644 --- a/torch/ao/pruning/scheduler/base_scheduler.py +++ b/torch/ao/pruning/scheduler/base_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.ao.pruning import BaseSparsifier diff --git a/torch/ao/pruning/scheduler/cubic_scheduler.py b/torch/ao/pruning/scheduler/cubic_scheduler.py index 76fc61daa288..1a883059f569 100644 --- a/torch/ao/pruning/scheduler/cubic_scheduler.py +++ b/torch/ao/pruning/scheduler/cubic_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from .base_scheduler import BaseScheduler diff --git a/torch/ao/pruning/scheduler/lambda_scheduler.py b/torch/ao/pruning/scheduler/lambda_scheduler.py index a88d99a1f83b..5236ebc33a26 100644 --- a/torch/ao/pruning/scheduler/lambda_scheduler.py +++ b/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from .base_scheduler import BaseScheduler diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 1c210ace344d..8afed4d68945 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import copy from collections import defaultdict diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index 4f44e81485df..419323e68f93 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from . 
import base_sparsifier diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 98f489904cc4..7fd93e4d9da7 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Optional, Type from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized from itertools import chain diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 2b24ca3d82e3..2f50d51f2a38 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce from typing import Callable, Optional, Tuple, Union diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index e2b8ee5c810a..f77969b32149 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # flake8: noqa: F403 from .fake_quantize import * # noqa: F403 diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 83cc81bb6b00..bf6b42a4a0dc 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.quantized as nnq diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index 7d39dbcf1ca8..4fed532c56f0 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import copy from typing import Dict, Any diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index 6827ae35533c..ce23e80de150 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.nn.parameter import Parameter from typing import List @@ -8,9 +9,8 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): r"""Generalized extension of the FakeQuantize module in fake_quantize.py. This is an extension of the FakeQuantize module in fake_quantize.py, which - supports more generalized lower-bit quantization and support learning of the scale - and zero point parameters through backpropagation. For literature references, - please see the class _LearnableFakeQuantizePerTensorOp. + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize module also includes the following attributes to support quantization parameter learning. 
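The reworded `_LearnableFakeQuantize` docstring above describes learning the scale and zero point through backpropagation. As a conceptual toy only (not the module's actual implementation), a per-tensor symmetric fake quantizer with a learnable scale and a straight-through estimator for the rounding step could look like this; the zero point would be handled analogously:

```python
import torch
import torch.nn as nn

class ToyLearnableFakeQuant(nn.Module):
    """Toy fake quantizer: `scale` is an nn.Parameter, so it receives
    gradients and can be trained jointly with the model weights."""
    def __init__(self, init_scale=0.1, quant_min=-128, quant_max=127):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(float(init_scale)))
        self.quant_min, self.quant_max = quant_min, quant_max

    def forward(self, x):
        # Apply the straight-through estimator only to round(), so gradients
        # still reach both `x` and `self.scale` through the division and the
        # final multiplication.
        x_s = x / self.scale
        q = x_s + (torch.round(x_s) - x_s).detach()
        q = torch.clamp(q, self.quant_min, self.quant_max)
        return q * self.scale
```

Placed in a model, an ordinary optimizer over `model.parameters()` would then update `scale` alongside the weights, which is the behavior the docstring describes.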
diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 4e946a25ffbb..d76bdfddddaf 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import operator import torch diff --git a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py index 01e112b688c0..871d26dd9ff7 100644 --- a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py +++ b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch from torch.ao.quantization.backend_config import ( diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 2288aced0995..96fb66662d6f 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Type, Union, TYPE_CHECKING diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py index 81cfc928adb5..84e0fbc45c62 100644 --- a/torch/ao/quantization/backend_config/native.py +++ b/torch/ao/quantization/backend_config/native.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from ._common_operator_config_utils import ( _get_binary_op_configs, diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 6eab945f7d74..88dffedfd81b 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn as nn import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/quantization/backend_config/tensorrt.py b/torch/ao/quantization/backend_config/tensorrt.py index 1c5f761508bb..7a80d1883cfd 100644 --- a/torch/ao/quantization/backend_config/tensorrt.py +++ b/torch/ao/quantization/backend_config/tensorrt.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from .backend_config import ( BackendConfig, diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index 2e7382274079..13bf632e251a 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Any, List, Callable, Union, Tuple, Type import torch diff --git a/torch/ao/quantization/experimental/APoT_tensor.py b/torch/ao/quantization/experimental/APoT_tensor.py index debda7aea8c0..6caa2334be07 100644 --- a/torch/ao/quantization/experimental/APoT_tensor.py +++ b/torch/ao/quantization/experimental/APoT_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.experimental.quantizer import APoTQuantizer diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py index 4d988bbb25bb..d035a02b047a 100644 --- a/torch/ao/quantization/experimental/adaround_fake_quantize.py +++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: 
allow-untyped-defs from typing import Tuple import torch diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index 7304f885a6f3..f7eedd9fef12 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -1,5 +1,5 @@ +# mypy: allow-untyped-defs import copy -import logging from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch @@ -12,16 +12,21 @@ from torch.nn.parallel import DataParallel from torch.utils.data import DataLoader, TensorDataset -logger: logging.Logger = logging.getLogger(__name__) - class AdaptiveRoundingOptimizer: def __init__( self, model: Union[torch.nn.Module, torch.nn.DataParallel], - callback: Callable[[torch.nn.Module, List[Any]], None], + callback: Callable[ + [ + Union[torch.nn.Module, torch.nn.DataParallel], + Any, + Optional[torch.nn.Module], + ], + None, + ], forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable], - data: List[Any], + data: Any, observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver, max_iter=10000, dtype: torch.dtype = torch.qint8, @@ -29,8 +34,14 @@ def __init__( quant_max=127, qscheme: torch.qscheme = torch.per_tensor_symmetric, batch_size: int = 256, + feed_forward_wrapper: Optional[torch.nn.Module] = None, ): - self.model = model + if torch.cuda.is_available(): + self.model = model.cuda() + if torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(model) + else: + self.model = model self.q_model = copy.deepcopy(self.model) self.device = torch.device("cuda") if torch.cuda.is_available() else None self.callback = callback @@ -47,20 +58,27 @@ def __init__( self.quant_min = quant_min self.quant_max = quant_max self.qscheme = qscheme + self.feed_forward_wrapper = feed_forward_wrapper def run_adaround(self) -> torch.nn.Module: layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = [] for (name, module), q_module in zip( self.model.named_modules(), self.q_model.modules() ): + if isinstance(module, torch.nn.ReLU): + # Disable all inplace operations + module.inplace = False + if isinstance(q_module, torch.nn.ReLU): + # Disable all inplace operations + q_module.inplace = False if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)): # Knowing activation ahead-of-time would be helpful for asymmetric formulation # But this is challenging in eager mode, but graph module. 
layer_list.append((name, module, q_module)) - logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004 + print(f"Total number of layers : {len(layer_list)}") # noqa: G004 for name, module, q_module in layer_list: - logger.info( + print( f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 ) self.optimize_adaptive_rounding( @@ -87,10 +105,15 @@ def get_data_inp_out( handler2 = q_module.register_forward_hook( self.forward_hook_wrapper(quant_fetcher) ) + if torch.cuda.is_available(): + # Somehow, we need to move the model continuously + # Otherwise, the model will be lowered to CPU misteriously + self.model = self.model.cuda() + self.q_model = self.q_model.cuda() for data_ in data: with torch.no_grad(): - self.callback(self.model, data_) - self.callback(self.q_model, data_) + self.callback(self.model, data_, self.feed_forward_wrapper) + self.callback(self.q_model, data_, self.feed_forward_wrapper) fp32_output = fp32_fetcher[1] quant_input = quant_fetcher[0] fp_out.append(fp32_output) @@ -137,7 +160,7 @@ def _compute_and_display_local_losses( out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module) soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) - logger.info( + print( f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 ) @@ -162,13 +185,9 @@ def optimize_adaptive_rounding( optimizer = torch.optim.Adam([ada_quantizer.V]) inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) - logger.info("==================== Before adaround ====================") - test_in, test_out, fp_test_in = self.get_data_inp_out( - module, q_module, self.data[0] - ) - + print("==================== Before adaround ====================") assert ( - torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0 + torch.abs(out[0] - module(fp_in[0])).sum().item() == 0 ), "In-placed activation is detected, please do not use activation in-placed" # Stack the tensors in each list into a single tensor # Assuming inp and out are your lists of tensors @@ -177,9 +196,7 @@ def optimize_adaptive_rounding( dataset = TensorDataset(inp_tensor, out_tensor) dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) - self._compute_and_display_local_losses( - ada_quantizer, q_module, test_in[0], test_out[0] - ) + self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0]) global_idx = 0 one_iter = len(out) // self.batch_size for iteration in range(self.max_iter // one_iter): @@ -191,6 +208,7 @@ def optimize_adaptive_rounding( q_out = torch.nn.functional.conv1d( q_inp, q_weight, + bias=q_module.bias, stride=q_module.stride, padding=q_module.padding, dilation=q_module.dilation, @@ -219,14 +237,12 @@ def optimize_adaptive_rounding( if global_idx >= self.max_iter: break if iteration % 30 == 0: - logger.info( + print( f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 ) - logger.info("==================== After adaround ====================") - self._compute_and_display_local_losses( - ada_quantizer, q_module, test_in[0], test_out[0] - ) + print("==================== After adaround ====================") + self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0]) ada_quantizer.use_soft_rounding = True ada_quantizer.V.requires_grad = False diff --git a/torch/ao/quantization/experimental/apot_utils.py 
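The `AdaptiveRoundingOptimizer` changes above widen the callback to also receive the optional `feed_forward_wrapper`, move the model to CUDA (wrapping it in `DataParallel` when several GPUs are available), and switch the logging to `print`. A usage sketch under stated assumptions: `float_model` and `calibration_batches` are placeholders, and the forward-hook contract (input stored at index 0, output at index 1 of the fetcher list) is inferred from `get_data_inp_out` above rather than documented:

```python
from torch.ao.quantization.experimental.adaround_optimization import (
    AdaptiveRoundingOptimizer,
)

def run_batch(model, batch, feed_forward_wrapper=None):
    # Matches the widened callback signature:
    # (model or DataParallel, one calibration batch, optional wrapper).
    model(batch)

def forward_hook_wrapper(fetcher):
    # Returns a forward hook that records the hooked layer's input and
    # output into `fetcher` as [input, output].
    def hook(module, inputs, output):
        fetcher[:] = [inputs[0], output]
    return hook

ada = AdaptiveRoundingOptimizer(
    model=float_model,               # placeholder: a calibrated float model
    callback=run_batch,
    forward_hook_wrapper=forward_hook_wrapper,
    data=calibration_batches,        # placeholder: iterable of input batches
)
adarounded_model = ada.run_adaround()
```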
b/torch/ao/quantization/experimental/apot_utils.py index ad7a7bed1fbe..c2f2f0746ca5 100644 --- a/torch/ao/quantization/experimental/apot_utils.py +++ b/torch/ao/quantization/experimental/apot_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This file contains utility functions to convert values using APoT nonuniform quantization methods. diff --git a/torch/ao/quantization/experimental/fake_quantize.py b/torch/ao/quantization/experimental/fake_quantize.py index 7541106a61c8..6b4da74541f2 100644 --- a/torch/ao/quantization/experimental/fake_quantize.py +++ b/torch/ao/quantization/experimental/fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor from torch.ao.quantization.experimental.observer import APoTObserver diff --git a/torch/ao/quantization/experimental/fake_quantize_function.py b/torch/ao/quantization/experimental/fake_quantize_function.py index cac01fd8c002..924c81fc08df 100644 --- a/torch/ao/quantization/experimental/fake_quantize_function.py +++ b/torch/ao/quantization/experimental/fake_quantize_function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor from torch.ao.quantization.experimental.quantizer import quantize_APoT, dequantize_APoT diff --git a/torch/ao/quantization/experimental/linear.py b/torch/ao/quantization/experimental/linear.py index 154023b16183..cb46c99b01af 100644 --- a/torch/ao/quantization/experimental/linear.py +++ b/torch/ao/quantization/experimental/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import numpy as np diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 76a63815bdc6..8474f69c26a2 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module implements nonuniform observers used to collect statistics about the values observed during calibration (PTQ) or training (QAT). 
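The observer docstring above describes collecting statistics during calibration and turning them into quantization parameters. The uniform observers used throughout this patch follow a simple observe-then-`calculate_qparams` flow, shown below with `MinMaxObserver` (which this patch already imports elsewhere); the experimental nonuniform (APoT) observers gather statistics the same way but derive a nonuniform level set rather than a single scale and zero point:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(8):
    obs(torch.randn(32, 16))                 # calibration: record running min/max
scale, zero_point = obs.calculate_qparams()  # derive qparams from the statistics
print(scale.item(), zero_point.item())
```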
diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py index df9c0f27847e..b386ce20bbd3 100644 --- a/torch/ao/quantization/experimental/quantizer.py +++ b/torch/ao/quantization/experimental/quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor import numpy as np diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index 9f0503cf06a5..b921df39217a 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Implements modules used to perform fake quantization.""" import torch diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py index 2caa0a2b7f2d..b9447ff37e39 100644 --- a/torch/ao/quantization/fuse_modules.py +++ b/torch/ao/quantization/fuse_modules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch.nn as nn diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 16c0c3a85b8f..a989ae298825 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.nn as nn import torch.ao.nn.intrinsic as nni diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index f2e774590be3..72ce4b2471f5 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Optional, Tuple diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index b0965b9a7051..40a7e7bbff3b 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from collections import namedtuple diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 049f4e3135d9..92620a169383 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import map_arg, Node from torch.fx.graph import Graph diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index b5c7f9fd2976..8e59df51c6ff 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Set, Tuple, Callable, List import torch diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index 724e76ad576f..3370d8c9baf6 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Set, Tuple, Callable from collections import OrderedDict import torch diff --git a/torch/ao/quantization/fx/_model_report/model_report_observer.py b/torch/ao/quantization/fx/_model_report/model_report_observer.py index eaa45264be7e..f04d6da8a054 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_observer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -1,3 +1,4 @@ +# mypy: 
allow-untyped-defs import torch from torch.ao.quantization.observer import ObserverBase diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index 5463862aa1cd..e6288c6f71d9 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Any, Set, Dict, List, Tuple, OrderedDict from collections import OrderedDict as OrdDict diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 5aa095b49b65..6ca622cc4171 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -954,6 +954,7 @@ def convert( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " "in a future version. Please pass in a ConvertCustomConfig instead.", FutureWarning, + stacklevel=2, ) convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) @@ -962,6 +963,7 @@ def convert( "Passing a QConfig dictionary to convert is deprecated and will not be supported " "in a future version. Please pass in a QConfigMapping instead.", FutureWarning, + stacklevel=2, ) qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None qconfig_mapping = copy.deepcopy(qconfig_mapping) @@ -972,6 +974,7 @@ def convert( "Passing a backend_config_dict to prepare is deprecated and will not be supported " "in a future version. Please pass in a BackendConfig instead.", FutureWarning, + stacklevel=2, ) backend_config = BackendConfig.from_dict(backend_config) diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 4fb2c3a28cb0..72f28ddbc777 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 17b934efc6be..b555789f673a 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx import ( GraphModule, Node, @@ -57,6 +58,7 @@ def fuse( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " "in a future version. Please pass in a FuseCustomConfig instead.", FutureWarning, + stacklevel=2, ) fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) @@ -65,6 +67,7 @@ def fuse( "Passing a backend_config_dict to prepare is deprecated and will not be supported " "in a future version. 
Please pass in a BackendConfig instead.", FutureWarning, + stacklevel=2, ) backend_config = BackendConfig.from_dict(backend_config) diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index 718cc561bfa0..2766211e8e1b 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.backend_config import BackendConfig from torch.fx.graph import Node, Graph diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py index cc9187285ae6..224f71745157 100644 --- a/torch/ao/quantization/fx/graph_module.py +++ b/torch/ao/quantization/fx/graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import copy from torch.fx import GraphModule diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index cf287db8c524..b5a6657103fc 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import torch from torch.fx.graph import ( diff --git a/torch/ao/quantization/fx/pattern_utils.py b/torch/ao/quantization/fx/pattern_utils.py index d8648a0aed5e..3665f75f7567 100644 --- a/torch/ao/quantization/fx/pattern_utils.py +++ b/torch/ao/quantization/fx/pattern_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict from typing import Dict, Any from torch.ao.quantization.utils import Pattern diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index ce99fc757efb..80f50581cc72 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch import warnings @@ -1754,6 +1755,7 @@ def prepare( "Passing a QConfig dictionary to prepare is deprecated and will not be supported " "in a future version. Please pass in a QConfigMapping instead.", FutureWarning, + stacklevel=2, ) qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) @@ -1762,6 +1764,7 @@ def prepare( "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " "be supported in a future version. Please pass in a QConfigMapping instead.", FutureWarning, + stacklevel=2, ) _equalization_config = QConfigMapping.from_dict(_equalization_config) @@ -1770,6 +1773,7 @@ def prepare( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " "in a future version. Please pass in a PrepareCustomConfig instead.", FutureWarning, + stacklevel=2, ) prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) @@ -1778,6 +1782,7 @@ def prepare( "Passing a backend_config_dict to prepare is deprecated and will not be supported " "in a future version. 
Please pass in a BackendConfig instead.", FutureWarning, + stacklevel=2, ) backend_config = BackendConfig.from_dict(backend_config) diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 0b906a1777de..378c51b6805d 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import re from collections import defaultdict, OrderedDict diff --git a/torch/ao/quantization/fx/quantize_handler.py b/torch/ao/quantization/fx/quantize_handler.py index e70040f7e649..83fee8efcd99 100644 --- a/torch/ao/quantization/fx/quantize_handler.py +++ b/torch/ao/quantization/fx/quantize_handler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC from typing import Callable, Dict, List, Optional, Type diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index be26332b2485..5029db47961f 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch import torch.nn as nn @@ -837,7 +838,7 @@ def _activation_post_process_satisfies_dtype_config_constraints( suggestion_str = ( "Please use torch.ao.quantization.get_default_qconfig_mapping or " "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n" - " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n" + ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n' " model = prepare_fx(model, qconfig_mapping, example_inputs)" ) if not isinstance(activation_post_process, FixedQParamsObserver) and \ diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 5f075df1cd83..656372d37555 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module implements observers which are used to collect statistics about the values observed during calibration (PTQ) or training (QAT). 
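Several of the deprecation warnings in `convert`, `fuse`, and `prepare` above gain `stacklevel=2`. A short self-contained illustration of what that changes: with the default `stacklevel=1` the warning is attributed to the `warnings.warn` line inside the library, whereas `stacklevel=2` attributes it to the user code that called the deprecated API:

```python
import warnings

def fuse(fuse_custom_config_dict=None):
    if fuse_custom_config_dict is not None:
        # stacklevel=2 makes the warning point at the *caller* of fuse()
        # rather than at this warnings.warn line inside the library.
        warnings.warn(
            "Passing a fuse_custom_config_dict is deprecated.",
            FutureWarning,
            stacklevel=2,
        )

fuse(fuse_custom_config_dict={})   # the warning is reported at this line
```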
diff --git a/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/torch/ao/quantization/pt2e/duplicate_dq_pass.py index 48c7d7247b99..a6cfbce611fa 100644 --- a/torch/ao/quantization/pt2e/duplicate_dq_pass.py +++ b/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator diff --git a/torch/ao/quantization/pt2e/export_utils.py b/torch/ao/quantization/pt2e/export_utils.py index 139042c326b8..78c69b718d7d 100644 --- a/torch/ao/quantization/pt2e/export_utils.py +++ b/torch/ao/quantization/pt2e/export_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types import torch diff --git a/torch/ao/quantization/pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py index bacb4d8a28f1..6ae93ba1d260 100644 --- a/torch/ao/quantization/pt2e/graph_utils.py +++ b/torch/ao/quantization/pt2e/graph_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools from typing import Any, List, OrderedDict, Set, Optional, Callable import operator diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index 5ea1f939a3b6..313b420e7a22 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Optional diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 169a982f62ce..162ee45623ee 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._subclasses import FakeTensor from torch.ao.quantization.fx.prepare import ( diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index 45f5c265d2cb..c4c1f804d41c 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import itertools import operator diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 7f5cb2eeb13b..40801344740b 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import GraphModule from ..export_utils import _WrapperModule diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 25f82f04e4e3..cde22426ae5b 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import types diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 88e7b47aff2b..dc93d7938f0c 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Optional, Any, Union, Type from typing_extensions import deprecated diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 6bf4b41c724a..37f71465afea 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from collections import OrderedDict from typing import Any, Callable, Dict, Tuple, Union, List diff --git 
a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 794cb142220d..be00be0e295b 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -1,7 +1,8 @@ +# mypy: allow-untyped-defs import copy import itertools import warnings - +import inspect import torch import torch.nn as nn import torch.ao.nn.quantized as nnq @@ -235,6 +236,13 @@ def insert_activation_post_process(m, special_act_post_process=None): if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \ and type_before_parametrizations(module) in qconfig_propagation_list: insert_activation_post_process(module) + # This is a special case for AdaRound eager mode + # AdaRound contains weight_fake_quant to be propagated from API to convert + # leaf node check with a number of children looks naive assumption that blocks + # Adding an exception case for AdaRound + if hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) \ + and type_before_parametrizations(module) in qconfig_propagation_list: + insert_activation_post_process(module) def _get_unique_devices_(module): return {p.device for p in module.parameters()} | \ @@ -520,7 +528,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False): def convert( module, mapping=None, inplace=False, remove_qconfig=True, - is_reference=False, convert_custom_config_dict=None): + is_reference=False, convert_custom_config_dict=None, + use_precomputed_fake_quant=False): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class. And remove qconfig at the end if remove_qconfig is set to True. @@ -533,6 +542,7 @@ def convert( `inplace`: carry out model transformations in-place, the original module is mutated `convert_custom_config_dict`: custom configuration dictionary for convert function + `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant .. 
code-block:: python @@ -552,14 +562,16 @@ def convert( module = copy.deepcopy(module) _convert( module, mapping, inplace=True, is_reference=is_reference, - convert_custom_config_dict=convert_custom_config_dict) + convert_custom_config_dict=convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant) if remove_qconfig: _remove_qconfig(module) return module def _convert( module, mapping=None, inplace=False, - is_reference=False, convert_custom_config_dict=None): + is_reference=False, convert_custom_config_dict=None, + use_precomputed_fake_quant=False): r"""Converts submodules in input module to a different module according to `mapping` by calling `from_float` method on the target module class @@ -571,6 +583,7 @@ def _convert( inplace: carry out model transformations in-place, the original module is mutated is_reference: a flag to enable quantized reference module + use_precomputed_fake_quant: a flag to enable use of precomputed fake quant """ if mapping is None: @@ -589,15 +602,16 @@ def _convert( if not isinstance(mod, _FusedModule) and \ type_before_parametrizations(mod) not in custom_module_class_mapping: _convert(mod, mapping, True, # inplace - is_reference, convert_custom_config_dict) - reassign[name] = swap_module(mod, mapping, custom_module_class_mapping) + is_reference, convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant) + reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant) for key, value in reassign.items(): module._modules[key] = value return module -def swap_module(mod, mapping, custom_module_class_mapping): +def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False): r"""Swaps the module if it has a quantized counterpart and it has an `observer` attached. @@ -623,7 +637,11 @@ def swap_module(mod, mapping, custom_module_class_mapping): weight_qparams = get_qparam_dict(weight_post_process) new_mod = qmod.from_float(mod, weight_qparams) else: - new_mod = qmod.from_float(mod) + sig = inspect.signature(qmod.from_float) + if 'use_precomputed_fake_quant' in sig.parameters: + new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant) + else: + new_mod = qmod.from_float(mod) swapped = True if swapped: diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 453c0511e4d9..5767a525342e 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -122,6 +122,7 @@ def _prepare_fx( "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " "in a future version. Please pass in a PrepareCustomConfig instead.", FutureWarning, + stacklevel=3, ) prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) @@ -229,6 +230,7 @@ def fuse_fx( "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " "in a future version. Please pass in a FuseCustomConfig instead.", FutureWarning, + stacklevel=2, ) fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) @@ -520,6 +522,7 @@ def _convert_fx( "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " "in a future version. 
Please pass in a ConvertCustomConfig instead.", FutureWarning, + stacklevel=3, ) convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 632fc1db2327..3001deb6ab9c 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.ao.quantization.qconfig import QConfig diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index d9919aa2e9c5..b312d89911a5 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -26,7 +26,7 @@ from torch.fx.passes.infra.pass_manager import PassManager from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ -from torch._inductor.constant_folding import constant_fold +from torch._export.passes.constant_folding import constant_fold __all__ = [ "prepare_pt2e", diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index 81306943264b..bd3d2773e628 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index a521ff56c34c..4cecfee28f2b 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index f25d0916018b..68c90f5cf57f 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch.ao.quantization.pt2e.utils import _is_sym_size_node @@ -47,3 +48,37 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): ((user not in partition_nodes) or _is_sym_size_node(user)) for user in node.users ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. 
+ if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 4cc05e46c6a7..6eecabb6fee0 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import itertools @@ -14,8 +15,11 @@ Set, Tuple, TYPE_CHECKING, + Union, ) +from typing_extensions import TypeAlias + import torch import torch.nn.functional as F from torch.ao.quantization.fake_quantize import ( @@ -36,8 +40,9 @@ Quantizer, SharedQuantizationSpec, ) + +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - _is_annotated, get_bias_qspec, get_input_act_qspec, get_output_act_qspec, @@ -52,6 +57,9 @@ SourcePartition, ) +FilterFn: TypeAlias = Callable[[List[Node]], bool] + + if TYPE_CHECKING: from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor @@ -67,6 +75,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): # * Node as output node of a fusion pattern. # * The fusion pattern supports int8 data type. # * The fusion pattern has inputs annotated to insert observer. + # * The quantization_config is not `None`. _is_output_of_quantized_pattern: bool = False @@ -101,6 +110,91 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): QUANT_ANNOTATION_KEY = "quantization_annotation" +def _skip_annotate(nodes: List[Node], filter_fn: Optional[FilterFn] = None) -> bool: + """Determine whether to skip annotation for a list of nodes.""" + + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + return True + + # 2) Proceed annotate if a) a filter function is provided + # and b) the given nodes list passes the filter function check. + if filter_fn and filter_fn(nodes): + return False + + return True + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: List[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) 
# comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. + """ + + def operator_type_filter(nodes: List[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: List[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( @@ -293,16 +387,63 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_x86_inductor_config_and_operators() +def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + +def _config_checker(method: Callable) -> Callable: + @functools.wraps(method) + def wrapper( + quantizer: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if quantizer._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name}.", + ) + return quantizer + return method(quantizer, name, quantization_config) + + return wrapper + + +@dataclass +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. 
+ + All possible current quantization modes are listed below: + ---------------------------------------------------------------------------------------------------------- + | dynamic_state + qat_state |--------------------------------------------------------------------------------------------- + | None | True | False + ---------------------------------------------------------------------------------------------------------- + None | quantizer does not receive a non-None `quantization_config` | \ | \ + False | quantizer will not do QAT | dynamic | static + True | quantizer will do QAT | QAT + dynamic | QAT + static + """ + + qat_state: Optional[bool] + dynamic_state: Optional[bool] + + class X86InductorQuantizer(Quantizer): supported_config_and_operators = _get_supported_config_and_operators() module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() def __init__(self): super().__init__() - self.global_config: QuantizationConfig = None # type: ignore[assignment] + self.global_config: Optional[QuantizationConfig] = None self.operator_type_qconfig: Dict[ torch._ops.OpOverloadPacket, Optional[QuantizationConfig] ] = {} + self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {} @classmethod def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: @@ -326,7 +467,78 @@ def get_supported_operator_for_quantization_config( return ops return [] + def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: + """Retrieves the current quantization mode based on all configurations.""" + qat_state = None + dynamic_state = None + + # As we use `_need_skip_config` to skip all invalid configurations, + # we can safely assume that the all existing non-None configurations + # have the same quantization mode. + for qconfig in ( + list(self.module_name_qconfig.values()) + + list(self.operator_type_qconfig.values()) + + [self.global_config] + ): + if qconfig is not None: + # Query the `is_qat` state + if qat_state is None: + qat_state = qconfig.is_qat + else: + assert qat_state == qconfig.is_qat, ( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." + ) + # Query the `is_dynamic` state + input_activation_spec = qconfig.input_activation + if input_activation_spec is not None: + if dynamic_state is None: + dynamic_state = input_activation_spec.is_dynamic + else: + assert dynamic_state == input_activation_spec.is_dynamic, ( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) + return _CurrentQuantizationMode( + qat_state=qat_state, dynamic_state=dynamic_state + ) + + def _need_skip_config( + self, quantization_config: Optional[QuantizationConfig] + ) -> bool: + """Check if the provided quantization config is valid for X86InductorQuantizer. + + Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + To avoid such a mix, we compare the incoming configuration with current configuration status. + Refer the `_CurrentQuantizationMode` definition for all possible modes. 
+ """ + if quantization_config is None: + return False + + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat + ): + warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") + need_skip = True + if current_mode.dynamic_state is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.dynamic_state != input_activation_spec.is_dynamic + ): + warnings.warn( + "Mixed dynamic and static quantization config is not supported." + ) + need_skip = True + return need_skip + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.") + return self self.global_config = quantization_config return self @@ -338,6 +550,7 @@ def get_global_quantization_config(self): ) return self.global_config + @_config_checker def set_function_type_qconfig( self, function_type: Callable, @@ -356,6 +569,7 @@ def set_function_type_qconfig( ) return self + @_config_checker def set_module_type_qconfig( self, module_type: torch.nn.Module, @@ -372,6 +586,19 @@ def set_module_type_qconfig( ) return self + @_config_checker + def set_module_name_qconfig( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. 
+ """ + self.module_name_qconfig[module_name] = quantization_config + return self + def _set_aten_operator_qconfig( self, operator_type: torch._ops.OpOverloadPacket, @@ -385,22 +612,16 @@ def _set_aten_operator_qconfig( ) return self - def _get_aten_operator_qconfig( - self, - operator_type: torch._ops.OpOverloadPacket, - ) -> Optional[QuantizationConfig]: - if operator_type in self.operator_type_qconfig: - assert operator_type in quantizable_ops - return self.operator_type_qconfig[operator_type] - return self.global_config if operator_type in default_quantizable_ops else None - def _annotate_conv_node_helper( self, conv_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return input_qspec_map = {} input_node = conv_node.args[0] assert isinstance(input_node, Node) @@ -427,9 +648,12 @@ def _annotate_linear_node_helper( self, linear_node: torch.fx.Node, annotate_output: bool, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], ) -> None: """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return input_qspec_map = {} assert linear_node.target in (torch.ops.aten.linear.default,) has_bias = len(linear_node.args) == 3 @@ -503,65 +727,92 @@ def _get_input_idx_for_binary_node( return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """just handling global spec for now""" - if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] - model = self._annotate_for_dynamic_quantization_config(model) - else: - model = self._annotate_for_static_quantization_config(model) + """Annotate the given model with quantization configurations. + + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. 
+ # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + return model - def _annotate_for_static_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - r""" + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. + High-level description of quantization recipe for X86 Inductor Backend: Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model from start to the end. If a pattern supports computation with int8 data type and inputs connected to quantized patterns, annotate its inputs as quantized pattern. - Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns, - such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, - we need to annotate the output of this pattern. """ # Step1: Recipe of fusion patterns like conv/linear. - self._annotate_conv2d_fusion_pattern(model) - self._annotate_linear_fusion_pattern(model) - self._annotate_matmul(model) + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) # Step2: Recipe to propagate annotation for patterns beside conv/linear. # Go through all the nodes from start to end. # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 - for node in model.graph.nodes: - self._annotate_propagation_quantizable_pattern(node) - - # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized - # in inputs. So, we can fuse dq-operator-q into a quantized op. 
- # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ - # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 - for node in model.graph.nodes: - self._annotate_output_for_int8_in_int8_out_pattern(node) - - return model - def _annotate_for_dynamic_quantization_config( - self, model: torch.fx.GraphModule - ) -> torch.fx.GraphModule: - self._annotate_linear_fusion_pattern(model) - return model + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) def _annotate_qat_conv2d_fusion_pattern( - self, model: torch.fx.GraphModule, config: QuantizationConfig + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ): # Annotate QAT Specific patterns - self._annotate_qat_conv2d_bn_binary_unary(model, config) - self._annotate_qat_conv2d_bn_binary(model, config) - self._annotate_qat_conv2d_bn_unary(model, config) - self._annotate_qat_conv2d_bn(model, config) + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) def _annotate_qat_conv2d_bn_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] @@ -601,25 +852,34 @@ def _annotate_qat_conv2d_bn_binary_unary( ): continue - if _is_annotated([unary_node, binary_node, bn_output_node, conv_node]): + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - _annotated=True, - ) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) @@ -627,7 +887,10 @@ def _annotate_qat_conv2d_bn_binary_unary( _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] @@ -661,29 +924,37 @@ def _annotate_qat_conv2d_bn_binary( ): continue - if _is_annotated([binary_node, bn_output_node, conv_node]): + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - binary_node_input_qspec_map = {} - binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( - quantization_config - ) - binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - input_qspec_map=binary_node_input_qspec_map, - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(binary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(binary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -715,23 +986,31 @@ def _annotate_qat_conv2d_bn_unary( ): continue - if _is_annotated([unary_node, bn_output_node, conv_node]): + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + unary_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(unary_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) nodes_to_mark_annotated.extend(list(unary_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_qat_conv2d_bn( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = find_sequential_partitions( gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] @@ -748,60 +1027,87 @@ def _annotate_qat_conv2d_bn( ): continue - if _is_annotated([bn_output_node, conv_node]): + if _skip_annotate([bn_output_node, conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, False, quantization_config) - bn_output_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - # TODO Remove the annotate of output in QAT when qat util support pattern matcher. - output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + if quantization_config is not None: + bn_output_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + _annotate_nodes_not_quantize(bn_output_node) nodes_to_mark_annotated = list(conv_partition.nodes) nodes_to_mark_annotated.extend(list(bn_partition.nodes)) _mark_nodes_as_annotated(nodes_to_mark_annotated) - def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.conv2d.default): - if config.is_qat: - # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat - self._annotate_qat_conv2d_fusion_pattern(model, config) - self._annotate_conv2d_binary_unary(model, config) - self._annotate_conv2d_binary(model, config) - self._annotate_conv2d_unary(model, config) - self._annotate_conv2d(model, config) - - def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default): - if config.input_activation and not config.input_activation.is_dynamic: - # Weiwen: Dynamic Quant of linear unary will be supported in next step - self._annotate_linear_binary_unary(model, config) - self._annotate_linear_unary(model, config) - self._annotate_linear(model, config) - - def _annotate_matmul(self, model: torch.fx.GraphModule): - if config := self._get_aten_operator_qconfig(torch.ops.aten.matmul.default): - for node in model.graph.nodes: - if node.target == torch.ops.aten.matmul.default and not _is_annotated( - [node] - ): - input_qspec_map = {} - matmul_node = node - for input_node in matmul_node.args: - input_qspec_map[input_node] = get_input_act_qspec(config) - matmul_node.meta[ - QUANT_ANNOTATION_KEY - ] = _X86InductorQuantizationAnnotation( - input_qspec_map=input_qspec_map, - _annotated=True, - _is_output_of_quantized_pattern=True, - ) + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: 
Optional[FilterFn] = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + if (quantization_config is None) or ( + quantization_config.input_activation + and not quantization_config.input_activation.is_dynamic + ): + # Weiwen: Dynamic Quant of linear unary will be supported in next step + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) + + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) def _annotate_conv2d_binary_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add + unary op fused_partitions = find_sequential_partitions( @@ -829,8 +1135,13 @@ def _annotate_conv2d_binary_unary( ): # No conv node found to be fused with add continue - if _is_annotated([unary_node, binary_node, conv_node]): + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -846,7 +1157,10 @@ def _annotate_conv2d_binary_unary( ) def _annotate_conv2d_binary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # Conv2d + add fused_partitions = find_sequential_partitions( @@ -875,8 +1189,13 @@ def _annotate_conv2d_binary( ): # No conv node found to be fused with add continue - if _is_annotated([binary_node, conv_node]): + if _skip_annotate([binary_node, conv_node], filter_fn): continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) binary_node_input_qspec_map = {} 
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( @@ -889,7 +1208,10 @@ def _annotate_conv2d_binary( ) def _annotate_conv2d_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: fused_partitions = [] unary_patterns = [ @@ -915,8 +1237,13 @@ def _annotate_conv2d_unary( or conv_node.target != torch.ops.aten.conv2d.default ): continue - if _is_annotated([unary_node, conv_node]): + if _skip_annotate([unary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, @@ -924,7 +1251,10 @@ def _annotate_conv2d_unary( ) def _annotate_conv2d( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: conv_partitions = get_source_partitions( gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] @@ -940,15 +1270,21 @@ def _annotate_conv2d( ): raise ValueError(f"{conv_node} is not an aten conv2d operator") # skip annotation if it is already annotated - if _is_annotated([conv_node]): + if _skip_annotate([conv_node], filter_fn): continue self._annotate_conv_node_helper(conv_node, True, quantization_config) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig + self, + node: Node, + quantization_config: Optional[QuantizationConfig], ) -> None: if node.target is not torch.ops.aten.max_pool2d.default: return + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + maxpool_node = node if _is_any_annotated( [ @@ -956,6 +1292,7 @@ def _annotate_maxpool2d( ] ): return + input_node = maxpool_node.args[0] assert isinstance(input_node, Node) input_qspec_map = {} @@ -969,6 +1306,9 @@ def _annotate_maxpool2d( def _annotate_cat( self, node: Node, quantization_config: QuantizationConfig ) -> None: + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return cat_node = node input_nodes = cat_node.args[0] assert isinstance(input_nodes, Sequence) @@ -993,13 +1333,25 @@ def _annotate_cat( _is_output_of_quantized_pattern=True, ) - def _annotate_propagation_quantizable_pattern(self, node: Node) -> None: + def _annotate_propagation_quantizable_pattern_entry( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in gm.graph.nodes: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: # Propagate annotation to quantizable patterns. 
if ( (node.target in propagation_quantizable_ops) and (not _is_any_annotated([node])) and (node.op == "call_function") - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] ): def is_all_inputs_connected_to_quantized_op(input_nodes): @@ -1009,11 +1361,23 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): return False return True + if _skip_annotate([node], filter_fn): + return + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + if node.target is torch.ops.aten.max_pool2d.default: # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not input_nodes_to_check = [node.all_input_nodes[0]] if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." + ) return + self._annotate_maxpool2d(node, quantization_config) return elif node.target is torch.ops.aten.cat.default: @@ -1056,18 +1420,24 @@ def _annotate_output_share_observer_as_input( ) return - def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, + model: torch.fx.GraphModule, + ): + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: r""" Check and insert observer at output of node in int8_in_int8_out_ops if needed. Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 """ edge_or_node: Tuple[Node, Node] - if ( - (node.target in int8_in_int8_out_ops) - and (_is_any_annotated([node])) - and (quantization_config := self._get_aten_operator_qconfig(node.target)) # type: ignore[arg-type] - ): + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): if node.target == torch.ops.aten.max_pool2d.default: maxpool_node = node if not _is_all_annotated( @@ -1076,6 +1446,7 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: ] ): return + # Get the quantization_annotation from getitem_node maxpool_node_quantization_annotation = ( maxpool_node.meta[QUANT_ANNOTATION_KEY] @@ -1100,7 +1471,10 @@ def _annotate_output_for_int8_in_int8_out_pattern(self, node: Node) -> None: return def _annotate_linear( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: linear_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] @@ -1119,12 +1493,15 @@ def _annotate_linear( ): raise ValueError(f"{linear_node} is not an aten linear operator") # skip annotation if it is already annotated - if _is_annotated([linear_node]): + if _skip_annotate([linear_node], filter_fn): continue self._annotate_linear_node_helper(linear_node, True, quantization_config) def _annotate_linear_unary( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: postop_list = [ torch.nn.ReLU, @@ -1146,8 +1523,13 @@ def _annotate_linear_unary( torch.ops.aten.linear.default, ): continue - if 
_is_annotated([unary_node, linear_node]): + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( _annotated=True, @@ -1157,7 +1539,8 @@ def _annotate_linear_unary( def _annotate_linear_binary_unary( self, gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, ) -> None: # linear + binary_op + (optional) unary op binary_op_list = [operator.add] @@ -1214,8 +1597,13 @@ def _annotate_linear_binary_unary( if unary_node is None else [unary_node, binary_node, linear_node] ) - if _is_annotated(node_list): + if _skip_annotate(node_list, filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) continue + self._annotate_linear_node_helper( linear_node, False, quantization_config ) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index f3d1b6ca8b39..88ccc1454f44 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import copy @@ -22,6 +23,7 @@ ) from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import _get_module_name_filter from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( _convert_scalars_to_attrs, @@ -192,40 +194,6 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]: return _get_supported_symmetric_config_and_operators() -def _get_module_name_filter(module_name: str): - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - def module_name_filter(n: Node) -> bool: - # example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - - def _normalize_path(n): - prefix = 0 - # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. 
- if n.startswith("L['self']."): - prefix = len("L['self'].") - return n[prefix:] - - names = [_normalize_path(n) for n, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter - - def _get_module_type_filter(tp: Callable): """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index 9f1732e57370..928ee0d3ac45 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import operator from dataclasses import dataclass diff --git a/torch/ao/quantization/stubs.py b/torch/ao/quantization/stubs.py index 10a63fb8f0ee..f62a227f1d77 100644 --- a/torch/ao/quantization/stubs.py +++ b/torch/ao/quantization/stubs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch import nn diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index d0de50bbeb57..fadbf33a70b6 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Utils shared by different modes of quantization (eager/graph) """ @@ -121,6 +122,25 @@ def check_node(node, modules): return is_call_function, is_call_method, is_call_module def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ d = default_dict.copy() d.update(additional_dict) return d diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 4cefb143dcc0..aca9abb24070 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ ``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. It requires minimal @@ -255,6 +256,7 @@ def backward( warnings.warn( "`grad_variables` is deprecated. Use `grad_tensors` instead.", FutureWarning, + stacklevel=2, ) if grad_tensors is None: grad_tensors = grad_variables @@ -400,6 +402,7 @@ def grad( "(defaults to True). 
To accumulate gradient for other " "parts of the graph, please use torch.autograd.backward.", FutureWarning, + stacklevel=2, ) grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs)) diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py index d2b3149bfc81..9c982b074b65 100644 --- a/torch/autograd/_functions/tensor.py +++ b/torch/autograd/_functions/tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from functools import reduce from typing_extensions import deprecated diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py index 7111d893400f..56baae4aae3b 100644 --- a/torch/autograd/_functions/utils.py +++ b/torch/autograd/_functions/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from functools import reduce diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 80a2526a81de..7e73ad4ef2c3 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 747b18f0f369..4187e220ceab 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from collections import namedtuple diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 9aca2b2a1b32..62ec1183a365 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import itertools diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 6701efbedac1..8cf3955a6927 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple import torch diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index be173c9b9de0..1c97ab58298b 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any import torch diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index a0d874038761..5bf74afacb66 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import warnings diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 19938c183557..cde56a6f26c7 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import collections import contextlib diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 38cc0e3a3b35..0392a8769846 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from dataclasses import dataclass from time import perf_counter_ns @@ -216,6 +217,7 @@ def __init__( "The attribute `use_cuda` will be deprecated soon, " "please use ``use_device = 'cuda'`` instead.", FutureWarning, + stacklevel=2, ) self.use_device: Optional[str] = "cuda" else: diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index e8b2b62019bc..40baafd441ae 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import warnings from typing_extensions import 
deprecated @@ -57,7 +58,10 @@ def __init__( self.with_modules = with_modules if self.use_cuda and not torch.cuda.is_available(): - warnings.warn("CUDA is not available, disabling CUDA profiling") + warnings.warn( + "CUDA is not available, disabling CUDA profiling", + stacklevel=2, + ) self.use_cuda = False if self.use_cuda: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 23243733aaa8..a5cff1ea12a8 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import itertools import math diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py index ed841d4da7d4..84b504a9c82c 100644 --- a/torch/autograd/variable.py +++ b/torch/autograd/variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._C import _ImperativeEngine as ImperativeEngine diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index 2236230e8c6d..086147b87a81 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import types from contextlib import contextmanager diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py index f393929bb7c2..18cb8229db9a 100644 --- a/torch/backends/_coreml/preprocess.py +++ b/torch/backends/_coreml/preprocess.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import hashlib import json from typing import Dict, Tuple diff --git a/torch/backends/_nnapi/prepare.py b/torch/backends/_nnapi/prepare.py index 8b07c3d6e0c6..6ba389902c9f 100644 --- a/torch/backends/_nnapi/prepare.py +++ b/torch/backends/_nnapi/prepare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index 551fa821df68..34bcc42f8927 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import array import enum import functools diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index c35a962ba693..00f511a544e6 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Union @@ -26,6 +27,7 @@ "enable_math_sdp", "can_use_flash_attention", "can_use_efficient_attention", + "can_use_cudnn_attention", "sdp_kernel", ] @@ -358,6 +360,26 @@ def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool return torch._C._can_use_mem_efficient_attention(params, debug) +def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why cuDNN attention could not be run. + Defaults to False. + + Returns: + True if cuDNN can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_cudnn_attention(params, debug) + + def cudnn_sdp_enabled(): r""" .. warning:: This flag is beta and subject to change. 
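A hedged usage sketch for the newly added `torch.backends.cuda.can_use_cudnn_attention`, which mirrors the existing `can_use_flash_attention` / `can_use_efficient_attention` checks. The `SDPAParams` constructor argument order shown below is an assumption and may differ between releases; the call also requires a CUDA-enabled build.

    import torch
    import torch.backends.cuda as cuda_backends
    import torch.nn.functional as F

    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    # Assumed constructor order: (query, key, value, attn_mask, dropout_p, is_causal);
    # check the SDPAParams definition in your build before relying on it.
    params = cuda_backends.SDPAParams(q, k, v, None, 0.0, True)
    if cuda_backends.can_use_cudnn_attention(params, debug=True):
        # cuDNN attention is eligible for these inputs on this build/hardware.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)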
diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e00d92f44b28..e528ac68552d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import sys import warnings diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index aaf0bd02e8af..f2e9d4321a02 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.cuda try: diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index 261ee764485b..9f96d692ae02 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 9cdee1cbd565..669ed59a1132 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys from contextlib import contextmanager diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 8d5e70f06a0a..06eda58e82f9 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import lru_cache as _lru_cache from typing import Optional diff --git a/torch/backends/nnpack/__init__.py b/torch/backends/nnpack/__init__.py index 892dfa022cfc..1a30e977cab3 100644 --- a/torch/backends/nnpack/__init__.py +++ b/torch/backends/nnpack/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/backends/openmp/__init__.py b/torch/backends/openmp/__init__.py index 4a7fcca12d0c..aff8d46cd4ac 100644 --- a/torch/backends/openmp/__init__.py +++ b/torch/backends/openmp/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 2e66cd37542d..993a219fa9aa 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import warnings from contextlib import contextmanager diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py index 85009753e0ae..3cb795dd39fc 100644 --- a/torch/backends/quantized/__init__.py +++ b/torch/backends/quantized/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types from typing import List diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 0344631ee6b4..bdf07e286174 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations. 
diff --git a/torch/backends/xnnpack/__init__.py b/torch/backends/xnnpack/__init__.py index c26dc11deb47..31e69876927d 100644 --- a/torch/backends/xnnpack/__init__.py +++ b/torch/backends/xnnpack/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index cf0b544e929a..812bbaa4c660 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import List @@ -32,22 +33,77 @@ def reset() -> None: def allow_in_graph(fn): """ - Customize which functions compilation will include in the generated graph. - It bypasses all introspection of the symbolic python code in favor of - directly writing it to the graph. - If fn is a list or tuple of callables it recursively applies :func:`allow_in_graph()` - to each function and returns a new list or tuple containing the modified functions + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + If you are using :func:`torch.compile` (with backend="inductor" (the default)), or + :func:`torch.export.export`, and trying to black-box a Python function throughout + all tracing, do not use this API. + Instead, please create a custom operator (see :ref:`custom-ops-landing-page`) + + .. warning:: + + If you're a typical torch.compile user (e.g. you're applying torch.compile to + a model to make it run faster), you probably don't want to use this function. + :func:`allow_in_graph` is a footgun because it skips the compiler frontend + (Dynamo) that is responsible for doing safety checks (graph breaks, handling + closures, etc). Incorrect usage will lead to difficult-to-debug silent + incorrectness issues. + + Given a Python function with no allow_in_graph decorator, regular execution + of torch.compile traces through the function. :func:`allow_in_graph` changes + it so that the frontend does not trace inside the function, but the compiler + backend still traces through it. Compare this to custom operators, which + treats a function as a black box throughout the torch.compile stack. The following + table compares these mechanisms. + + +------------------------+-----------------------+--------------------------------+ + | Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) | + +========================+=======================+================================+ + | no decorator | trace inside | trace inside | + +------------------------+-----------------------+--------------------------------+ + | allow_in_graph | opaque callable | trace inside | + +------------------------+-----------------------+--------------------------------+ + | custom op | opaque callable | opaque callable | + +------------------------+-----------------------+--------------------------------+ + + One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler + frontend: if you know the function works w.r.t. to the downstream components of the + compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from + symbolically introspecting the function properly (or if your code is in C/C++ and + therefore cannot be introspected with Dynamo), then one can decorate said function + with :func:`allow_in_graph` to bypass Dynamo. + + We require that ``fn`` adhere to the following restrictions. 
Failure to adhere + results in undefined behavior: + + - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include: + Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?] + Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device + - The outputs of ``fn`` must be Proxy-able types in the FX graph (see previous bullet) + - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn`` + (as opposed to being captured variables). Args: fn: A callable representing the function to be included in the graph. + If ``fn`` is a list or tuple of callables it recursively applies + :func:`allow_in_graph()` to each function and returns a new list or + tuple containing the modified functions. - .. warning:: + Example:: + + torch.compiler.allow_in_graph(my_custom_function) + + @torch.compile(...) + def fn(a): + x = torch.add(a, 1) + x = my_custom_function(x) + x = torch.add(x, 1) + return x + + fn(...) - :func:`allow_in_graph` skips TorchDynamo completely on the decorated function - skipping all TorchDynamo safety checks (graph breaks, handling closures, etc). - Therefore, one has to be very careful with :func:`allow_in_graph` since subsystems - like AOT Autograd rely on torchdynamo - If not careful, this could lead to soundness and really hard-to-debug issues. + Will capture a single graph containing ``my_custom_function()``. """ import torch._dynamo diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index 87c325948a8b..ed1445dd7bce 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import time from collections import defaultdict from functools import partial diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 2f2561b69c1c..d404ad4ba3b9 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package implements abstractions found in ``torch.cuda`` to facilitate writing device-agnostic code. @@ -28,9 +29,20 @@ _device_t = Union[_device, str, int, None] +def _is_cpu_support_avx2() -> bool: + r"""Returns a bool indicating if CPU supports AVX2.""" + return torch._C._cpu._is_cpu_support_avx2() + + +def _is_cpu_support_avx512() -> bool: + r"""Returns a bool indicating if CPU supports AVX512.""" + return torch._C._cpu._is_cpu_support_avx512() + + + def _is_cpu_support_vnni() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" - return torch._C._cpu._is_cpu_support_vnni() + # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later.
+ return torch._C._cpu._is_cpu_support_avx512_vnni() def is_available() -> bool: diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index b545e91dd6f4..b61e9b542dba 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any from typing_extensions import deprecated diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 4f8d614e16dc..e4779ff984bc 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -19,7 +19,7 @@ #include #endif -static inline void PyErr_SetString(PyObject* type, const std::string& message) { +inline void PyErr_SetString(PyObject* type, const std::string& message) { PyErr_SetString(type, message.c_str()); } /// NOTE [ Conversion Cpp Python Warning ] diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 9ff9131435f4..57b28d676484 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include #include @@ -375,22 +376,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { THPVariable* a = reinterpret_cast(a_); THPVariable* b = reinterpret_cast(b_); - TORCH_CHECK( - a->cdata->use_count() == 1, - "Expected single reference to a's Tensor object but got ", - a->cdata->use_count()); - TORCH_CHECK( - b->cdata->use_count() == 1, - "Expected single reference to b's Tensor object but got ", - b->cdata->use_count()); // weak_use_count() adds 1 if use_count is non-zero TORCH_CHECK( a->cdata->weak_use_count() == 1, - "Expected no weakrefs to a's Tensor object but got ", + "Expected no weakrefs to t1's Tensor object but got ", a->cdata->weak_use_count() - 1); TORCH_CHECK( b->cdata->weak_use_count() == 1, - "Expected no weakrefs to b's Tensor object but got ", + "Expected no weakrefs to t2's Tensor object but got ", b->cdata->weak_use_count() - 1); // Swap the Tensor Impl @@ -422,19 +415,6 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { END_HANDLE_TH_ERRORS } -PyObject* THPModule_check_tp_alloc_is_default( - PyObject* _unused, - PyObject* cls) { - HANDLE_TH_ERRORS - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); - return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc); - END_HANDLE_TH_ERRORS -} - PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { // adds a __doc__ string to a function, similar to numpy's arr_add_docstring static std::vector all_docs; @@ -1281,10 +1261,6 @@ static PyMethodDef TorchMethods[] = { // NOLINT {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr}, - {"_check_tp_alloc_is_default", - THPModule_check_tp_alloc_is_default, - METH_O, - nullptr}, {"_init_names", THPModule_initNames, METH_O, nullptr}, {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr}, {"_set_default_tensor_type", @@ -1627,6 +1603,8 @@ PyObject* initModule() { THPDevice_init(module); THPStream_init(module); THPEvent_init(module); + NodeBase_init(module); + NodeIter_init(module); ASSERT_TRUE(THPVariable_initModule(module)); ASSERT_TRUE(THPFunction_initModule(module)); ASSERT_TRUE(THPEngine_initModule(module)); @@ -1943,6 +1921,15 @@ Call this whenever a new thread is created in order to propagate values from return sdp::can_use_mem_efficient_attention(params, debug); #else return 
false; +#endif + }); + py_module.def( + "_can_use_cudnn_attention", + [](const sdp::sdp_params& params, bool debug) { +#ifdef USE_CUDA + return sdp::can_use_cudnn_attention(params, debug); +#else + return false; #endif }); @@ -2162,50 +2149,13 @@ Call this whenever a new thread is created in order to propagate values from return torch::should_allow_numbers_as_tensors(name); }); - // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype` - // Currently I see the second item of the key is displayed as - // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0` - // I thought adding an appropriate type_caster of `at::ScalarType` to - // torch/csrc/pybind.h` would solve this but it caused segmentation fault in - // my environment. - using _DeviceDtypeKey = std::pair; - // Custom hasher is necessary to make unordered_map compilable for Windows - // debug targets. As `at::native::ParamsHash` only works on structs with - // standard layout, but std::string isn't one in Visual C++ debug builds, - // which one can easily verify by running something like: - // #define _DEBUG - // #include - // #include - // static_assert(std::is_standard_layout_v, "Oh noes"); - // If above condition is not met, VC++ raises a very cryptic compilation - // error. See - // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for - // more detail - struct _DeviceDtypeHasher { - std::size_t operator()(const _DeviceDtypeKey& k) const noexcept { - static at::native::ParamsHash device_hasher; - static std::hash string_hasher; - return device_hasher(k.first) ^ string_hasher(k.second); - } - }; - using _FlatMap = std::unordered_map< - _DeviceDtypeKey, - at::native::TensorsAndIndicesT, - _DeviceDtypeHasher>; py_module.def( "_group_tensors_by_device_and_dtype", [](const std::vector>>& nested_tensorlist, const bool with_indices) { - _FlatMap map; - for (const auto& iter : - at::native::_group_tensors_by_first_tensors_device_and_dtype( - nested_tensorlist, with_indices)) { - const auto scalar_type_name = - torch::utils::getDtypeNames(iter.first.second).first; - map.insert({{iter.first.first, scalar_type_name}, iter.second}); - } - return map; + return at::native::_group_tensors_by_first_tensors_device_and_dtype( + nested_tensorlist, with_indices); }); py_module.def( diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 16bf87bbcc2e..55deb18892bb 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -23,11 +23,11 @@ TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( bool allow_preexisting_pyobj = false); extern PyTypeObject* THPStorageClass; -static inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { +inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { return tp == THPStorageClass; } -static inline bool THPStorage_CheckExact(PyObject* obj) { +inline bool THPStorage_CheckExact(PyObject* obj) { return THPStorage_CheckTypeExact(Py_TYPE(obj)); } diff --git a/torch/csrc/api/include/torch/data/dataloader.h b/torch/csrc/api/include/torch/data/dataloader.h index 06ea83d8a232..a7bbdcb27d84 100644 --- a/torch/csrc/api/include/torch/data/dataloader.h +++ b/torch/csrc/api/include/torch/data/dataloader.h @@ -18,8 +18,8 @@ namespace data { /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and /// some `options`. 
template -torch::disable_if_t< - Dataset::is_stateful, +std::enable_if_t< + !Dataset::is_stateful, std::unique_ptr>> make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { return std::make_unique>( @@ -30,8 +30,8 @@ make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { /// `options`. A sampler (by default a `RandomSampler`) will be constructed from /// the size of the dataset. template -torch::disable_if_t< - Dataset::is_stateful || !std::is_constructible::value, +std::enable_if_t< + !Dataset::is_stateful && std::is_constructible_v, std::unique_ptr>> make_data_loader( Dataset dataset, @@ -46,7 +46,7 @@ make_data_loader( } /// Creates a `DataLoader` for a stateful `dataset` and some `options`. -template > +template > std::unique_ptr> make_data_loader( Dataset dataset, DataLoaderOptions options = DataLoaderOptions()) { diff --git a/torch/csrc/api/include/torch/data/dataloader/stateful.h b/torch/csrc/api/include/torch/data/dataloader/stateful.h index e8eb85861f77..22d584ce4a00 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateful.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -36,10 +36,8 @@ class StatefulDataLoader : public DataLoaderBase< /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. StatefulDataLoader(Dataset dataset, DataLoaderOptions options) - : super( - std::move(options), - std::make_unique(std::move(dataset))) { - for (const auto w : c10::irange(this->options_.workers)) { + : super(options, std::make_unique(std::move(dataset))) { + for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) { // As opposed to the stateless case, here all worker threads access the // same underlying dataset. this->workers_.emplace_back( diff --git a/torch/csrc/api/include/torch/data/datasets/map.h b/torch/csrc/api/include/torch/data/datasets/map.h index 7b8b8febd222..facd4fe28705 100644 --- a/torch/csrc/api/include/torch/data/datasets/map.h +++ b/torch/csrc/api/include/torch/data/datasets/map.h @@ -71,7 +71,7 @@ class MapDataset : public BatchDataset< /// applies the transform to the output of `get_batch()` from the dataset. template < typename D = SourceDataset, - typename = torch::disable_if_t> + typename = std::enable_if_t> OutputBatchType get_batch_impl(BatchRequestType indices) { return transform_.apply_batch(dataset_.get_batch(std::move(indices))); } @@ -82,7 +82,7 @@ class MapDataset : public BatchDataset< /// contains a value, and returns a new optional (of a different type) if the /// original optional returned by `get_batch()` was empty. 
template - torch::enable_if_t get_batch_impl( + std::enable_if_t get_batch_impl( BatchRequestType indices) { if (auto batch = dataset_.get_batch(std::move(indices))) { return transform_.apply_batch(std::move(*batch)); diff --git a/torch/csrc/api/include/torch/nn/modules/container/any.h b/torch/csrc/api/include/torch/nn/modules/container/any.h index 05983b1ea106..35d9c91b8ca3 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -340,7 +340,7 @@ std::unique_ptr AnyModule::make_holder( "AnyModule cannot store modules that return void " "(you can return a dummy value)."); return std::make_unique< - AnyModuleHolder, ArgumentTypes...>>( + AnyModuleHolder, ArgumentTypes...>>( std::move(module)); } diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h index cd1dca9ff7a0..4d1e69650035 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -40,10 +40,10 @@ struct AnyModuleHolder : public AnyModulePlaceholder { /// \internal struct CheckedGetter { template - decay_t&& operator()(size_t index) { + std::decay_t&& operator()(size_t index) { AT_ASSERT(index < arguments_.size()); auto& value = arguments_[index]; - if (auto* maybe_value = value.template try_get>()) { + if (auto* maybe_value = value.template try_get>()) { return std::move(*maybe_value); } AT_ERROR( diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/torch/csrc/api/include/torch/nn/modules/container/any_value.h index 3e6c23ef977c..d154130618f2 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_value.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -40,7 +40,8 @@ class AnyValue { template // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) explicit AnyValue(T&& value) - : content_(std::make_unique>>(std::forward(value))) { + : content_( + std::make_unique>>(std::forward(value))) { } /// Returns a pointer to the value contained in the `AnyValue` if the type diff --git a/torch/csrc/api/include/torch/nn/modules/container/functional.h b/torch/csrc/api/include/torch/nn/modules/container/functional.h index dbd2b0aaebdc..3f381a63944f 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/functional.h +++ b/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -65,7 +65,7 @@ class TORCH_API FunctionalImpl : public torch::nn::Cloneable { template < typename SomeFunction, typename... Args, - typename = torch::enable_if_t<(sizeof...(Args) > 0)>> + typename = std::enable_if_t<(sizeof...(Args) > 0)>> explicit FunctionalImpl(SomeFunction original_function, Args&&... 
args) // NOLINTNEXTLINE(modernize-avoid-bind) : function_(std::bind( diff --git a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h index 72a76163ac03..683b6416b04f 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h @@ -91,7 +91,7 @@ class ModuleListImpl : public Cloneable { void push_back(std::shared_ptr module) { modules_.push_back(std::move(module)); const auto index = modules_.size() - 1; - register_module(c10::to_string(index), modules_[index]); + register_module(std::to_string(index), modules_[index]); } /// Adds a new `Module` to the `ModuleList` container, moving or copying @@ -224,9 +224,9 @@ class ModuleListImpl : public Cloneable { for (const auto i : c10::irange(index, size() - 1)) { (void)i; // Suppress unused variable warning - replace_module(c10::to_string(index), modules_[index]); + replace_module(std::to_string(index), modules_[index]); } - register_module(c10::to_string(size() - 1), modules_.back()); + register_module(std::to_string(size() - 1), modules_.back()); } } diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h index 30b7eb89e48b..cb816d1bb2a1 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -50,14 +50,14 @@ class ParameterListImpl : public Cloneable { void append(torch::Tensor&& param) { bool requires_grad = param.requires_grad(); register_parameter( - c10::to_string(parameters_.size()), std::move(param), requires_grad); + std::to_string(parameters_.size()), std::move(param), requires_grad); } /// push the a given parameter at the end of the list void append(const torch::Tensor& param) { bool requires_grad = param.requires_grad(); register_parameter( - c10::to_string(parameters_.size()), param, requires_grad); + std::to_string(parameters_.size()), param, requires_grad); } /// push the a given parameter at the end of the list @@ -65,7 +65,7 @@ class ParameterListImpl : public Cloneable { /// will be added into the `ParameterList` void append(const OrderedDict::Item& pair) { register_parameter( - c10::to_string(parameters_.size()), + std::to_string(parameters_.size()), pair.value(), pair.value().requires_grad()); } @@ -111,7 +111,7 @@ class ParameterListImpl : public Cloneable { /// for a non-throwing way of access at::Tensor& at(size_t idx) { TORCH_CHECK(idx < size(), "Index out of range"); - return parameters_[c10::to_string(idx)]; + return parameters_[std::to_string(idx)]; } /// Returns the value associated with the given `key`. Throws an exception if @@ -119,7 +119,7 @@ class ParameterListImpl : public Cloneable { /// for a non-throwing way of access const at::Tensor& at(size_t idx) const { TORCH_CHECK(idx < size(), "Index out of range"); - return parameters_[c10::to_string(idx)]; + return parameters_[std::to_string(idx)]; } /// Returns the value associated with the given `key`. 
Throws an exception if diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index 9494926eef3c..acefa23d49e5 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -195,7 +195,7 @@ class SequentialImpl : public Cloneable { /// Adds a new (boxed) `Module` to the `Sequential` container. template void push_back(std::shared_ptr module_ptr) { - push_back(c10::to_string(modules_.size()), std::move(module_ptr)); + push_back(std::to_string(modules_.size()), std::move(module_ptr)); } /// Adds a new named (boxed) `Module` to the `Sequential` container. @@ -211,7 +211,7 @@ class SequentialImpl : public Cloneable { /// `Sequential(std::make_shared(3, 4))`. template > void push_back(M&& module) { - push_back(c10::to_string(modules_.size()), std::forward(module)); + push_back(std::to_string(modules_.size()), std::forward(module)); } /// Adds a new named `Module` to the `Sequential` container, moving or copying @@ -219,7 +219,7 @@ class SequentialImpl : public Cloneable { /// and letting the container deal with the boxing. template > void push_back(std::string name, M&& module) { - using Type = typename std::remove_reference::type; + using Type = typename std::remove_reference_t; push_back(std::move(name), std::make_shared(std::forward(module))); } @@ -227,7 +227,7 @@ class SequentialImpl : public Cloneable { /// `Sequential`. template void push_back(const ModuleHolder& module_holder) { - push_back(c10::to_string(modules_.size()), module_holder); + push_back(std::to_string(modules_.size()), module_holder); } /// Unwraps the contained named module of a `ModuleHolder` and adds it to the @@ -247,7 +247,7 @@ class SequentialImpl : public Cloneable { /// Adds a type-erased `AnyModule` to the `Sequential`. void push_back(AnyModule any_module) { - push_back(c10::to_string(modules_.size()), std::move(any_module)); + push_back(std::to_string(modules_.size()), std::move(any_module)); } void push_back(std::string name, AnyModule any_module) { @@ -348,12 +348,10 @@ class SequentialImpl : public Cloneable { typename First, typename Second, typename... Rest, - typename = torch::disable_if_t< - std::is_same::value || + typename = std::enable_if_t< + !std::is_same_v && // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - std::is_same< - typename std::decay::type, - std::decay::type>::value>> + !std::is_same_v, std::decay_t>>> void push_back(First&& first, Second&& second, Rest&&... rest) { push_back(std::forward(first)); // Recursively calls this method, until the parameter pack only thas this diff --git a/torch/csrc/api/include/torch/nn/pimpl-inl.h b/torch/csrc/api/include/torch/nn/pimpl-inl.h index b38e6cf2c0ff..cea53b6562bd 100644 --- a/torch/csrc/api/include/torch/nn/pimpl-inl.h +++ b/torch/csrc/api/include/torch/nn/pimpl-inl.h @@ -6,10 +6,12 @@ struct ModuleHolderIndicator {}; // A type trait that is true for types that are `ModuleHolder`s. template -using is_module_holder = std::is_base_of>; +using is_module_holder = + std::is_base_of>; template -using disable_if_module_holder_t = disable_if_t::value>; +using disable_if_module_holder_t = + std::enable_if_t::value>; // A collection of templates that answer the question whether a type `T` is a // `ModuleHolder`, and if so whether its contained type is of type `C`. 
This is @@ -43,8 +45,8 @@ struct is_module_holder_of_impl template struct is_module_holder_of : is_module_holder_of_impl< is_module_holder::value, - decay_t, - decay_t> {}; + std::decay_t, + std::decay_t> {}; // A collection of templates that allow deducing the return type of the // `forward()` method, but only if a module actually has a `forward()` method, diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h index 15902a026cf5..cc9d6a51a6de 100644 --- a/torch/csrc/api/include/torch/python.h +++ b/torch/csrc/api/include/torch/python.h @@ -212,8 +212,8 @@ py::class_ add_module_bindings( /// } /// \endrst template -torch::disable_if_t< - torch::detail::has_forward::value && !force_enable, +std::enable_if_t< + !torch::detail::has_forward::value || force_enable, detail::PyModuleClass> bind_module(py::module module, const char* name) { py::module cpp = module.def_submodule("cpp"); @@ -249,8 +249,7 @@ bind_module(py::module module, const char* name) { /// \endrst template < typename ModuleType, - typename = - torch::enable_if_t::value>> + typename = std::enable_if_t::value>> detail::PyModuleClass bind_module( py::module module, const char* name) { diff --git a/torch/csrc/autograd/cpp_hook.cpp b/torch/csrc/autograd/cpp_hook.cpp index 36f4671ee2e6..b851078b5280 100644 --- a/torch/csrc/autograd/cpp_hook.cpp +++ b/torch/csrc/autograd/cpp_hook.cpp @@ -41,7 +41,7 @@ variable_list CppFunctionTensorPreHook::operator()( // Don't change gradient continue; } - check_single_result(value, res, c10::to_string(i)); + check_single_result(value, res, std::to_string(i)); value = std::move(res); } variable_list results(values); diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 8c20bd807820..aed3eaa3e558 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -444,8 +444,8 @@ variable_list CppNode::apply(variable_list&& inputs) { if (num_outputs != num_forward_inputs) { std::string msg("function "); msg += name() + " returned an incorrect number of gradients (expected "; - msg += c10::to_string(num_forward_inputs) + ", got "; - msg += c10::to_string(num_outputs) + ")"; + msg += std::to_string(num_forward_inputs) + ", got "; + msg += std::to_string(num_outputs) + ")"; throw std::runtime_error(msg); } @@ -458,8 +458,8 @@ variable_list CppNode::apply(variable_list&& inputs) { std::string msg("function "); msg += name() + " returned a gradient different that is defined at position "; - msg += c10::to_string(i + 1) + - ", but the corresponding forward input was not a Variable"; + msg += std::to_string(i + 1) + + ", but the corresponding forward input was not a Variable"; throw std::runtime_error(msg); } continue; diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 799188be9a68..5fcc7b86a2fa 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -20,7 +19,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 078b0f92124c..65f4b0efd3c1 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1615,6 +1615,13 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) { END_HANDLE_TH_ERRORS_RET(-1) } +PyObject* THPVariable__use_count(PyObject* self,
PyObject* noargs) { + HANDLE_TH_ERRORS + const auto& t = THPVariable_Unpack(self); + return THPUtils_packUInt64(t.use_count()); + END_HANDLE_TH_ERRORS +} + // properties are registered here because we are currently only able to bind // them manually. TODO: make declarable in native_functions // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) @@ -1766,6 +1773,7 @@ static PyMethodDef extra_methods[] = { THPVariable_rev_view_func_unsafe, METH_O, nullptr}, + {"_use_count", THPVariable__use_count, METH_NOARGS, nullptr}, {nullptr}}; struct THPVariableMeta { diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index d0cb13e9f33e..51ade77f03ec 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -39,7 +39,7 @@ TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); -static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { +inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. // (A subclass could have different semantics.) The one exception is // Parameter, which is used for Python bookkeeping but is equivalent to @@ -49,7 +49,7 @@ static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { tp == (PyTypeObject*)ParameterClass); } -static inline bool THPVariable_CheckExact(PyObject* obj) { +inline bool THPVariable_CheckExact(PyObject* obj) { return THPVariable_CheckTypeExact(Py_TYPE(obj)); } diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h index a0e35a6e9eff..78c4a546ddbe 100644 --- a/torch/csrc/autograd/python_variable_indexing.h +++ b/torch/csrc/autograd/python_variable_indexing.h @@ -15,7 +15,7 @@ struct UnpackedSlice { }; // This mirrors Cpython's PySlice_Unpack method -static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { +inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { PySliceObject* r = (PySliceObject*)_r; /* this is harder to get right than you might think */ diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index f577c0c0dae1..3485f2a991cb 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -2,15 +2,15 @@ #include #include -namespace torch { -namespace cpu { +namespace torch::cpu { void initModule(PyObject* module) { auto m = py::handle(module).cast(); auto cpu = m.def_submodule("_cpu", "cpu related pybind."); - cpu.def("_is_cpu_support_vnni", at::cpu::is_cpu_support_vnni); + cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); + cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); + cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni); } -} // namespace cpu -} // namespace torch +} // namespace torch::cpu diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 030c5a2b5ccf..4197c2aa5e81 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -1403,6 +1404,275 @@ PyObject* THCPModule_rocm_is_backward_pass( END_HANDLE_TH_ERRORS } +PyObject* THCPModule_cuda_tunableop_enable(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_enable expects a bool, but got ", + THPUtils_typename(arg)); + 
at::cuda::tunable::getTuningContext()->EnableTunableOp( + THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_is_enabled( + PyObject* _unused, + PyObject* noarg) { + HANDLE_TH_ERRORS + if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_tuning_enable( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_tuning_enable expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->EnableTuning(THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_tuning_is_enabled( + PyObject* _unused, + PyObject* noarg) { + HANDLE_TH_ERRORS + if (at::cuda::tunable::getTuningContext()->IsTuningEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_write_file_on_exit( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_tunableop_write_file_on_exit expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->WriteFileOnExit( + THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_max_tuning_duration( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "cuda_tunableop_set_max_tuning_duration expects an int, but got ", + THPUtils_typename(arg)); + auto duration = static_cast(THPUtils_unpackLong(arg)); + at::cuda::tunable::getTuningContext()->SetMaxTuningDurationMs(duration); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_max_tuning_duration( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packInt32( + at::cuda::tunable::getTuningContext()->GetMaxTuningDurationMs()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_max_tuning_iterations( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "cuda_tunableop_set_max_tuning_iterations expects an int, but got ", + THPUtils_typename(arg)); + auto iterations = static_cast(THPUtils_unpackLong(arg)); + at::cuda::tunable::getTuningContext()->SetMaxTuningIterations(iterations); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_max_tuning_iterations( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packInt32( + at::cuda::tunable::getTuningContext()->GetMaxTuningIterations()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_set_filename( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* obj_str = nullptr; + PyObject* obj_ord = nullptr; + if (!PyArg_ParseTuple(args, "O|O", &obj_str, &obj_ord)) { + } + TORCH_CHECK( + THPUtils_checkString(obj_str), + "cuda_tunableop_set_filename expects a string, but got ", + THPUtils_typename(obj_str)); + auto filename = THPUtils_unpackString(obj_str); + bool dev = false; + if (obj_ord) { + TORCH_CHECK( + THPUtils_checkBool(obj_ord), + "cuda_tunableop_set_filename expects a bool, but got ", + THPUtils_typename(obj_ord)); + dev = THPUtils_unpackBool(obj_ord); + } + at::cuda::tunable::getTuningContext()->SetFilename(filename, dev); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* 
THCPModule_cuda_tunableop_get_filename( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packString( + at::cuda::tunable::getTuningContext()->GetFilename()); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_write_file( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* str = nullptr; + bool success = false; + if (!PyArg_ParseTuple(args, "|O", &str)) { + } + if (str) { + TORCH_CHECK( + THPUtils_checkString(str), + "cuda_tunableop_write_file expects a string, but got ", + THPUtils_typename(str)); + auto filename = THPUtils_unpackString(str); + success = at::cuda::tunable::getTuningContext()->WriteFile(filename); + } else { + success = at::cuda::tunable::getTuningContext()->WriteFile(); + } + if (success) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_read_file( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* str = nullptr; + bool success = false; + if (!PyArg_ParseTuple(args, "|O", &str)) { + } + if (str) { + TORCH_CHECK( + THPUtils_checkString(str), + "cuda_tunableop_read_file expects a string, but got ", + THPUtils_typename(str)); + auto filename = THPUtils_unpackString(str); + success = at::cuda::tunable::getTuningContext()->ReadFile(filename); + } else { + success = at::cuda::tunable::getTuningContext()->ReadFile(); + } + if (success) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_results( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto results = + at::cuda::tunable::getTuningContext()->GetTuningResultsManager().Dump(); + size_t result_size = 0; + for (const auto& [op_sig, kernelmap] : results) { + result_size += kernelmap.size(); + } + THPObjectPtr outer_tuple(PyTuple_New(result_size)); + if (!outer_tuple) + throw python_error(); + size_t result_index = 0; + for (const auto& [op_sig, kernelmap] : results) { + for (const auto& [param_sig, result] : kernelmap) { + THPObjectPtr inner_tuple(PyTuple_New(4)); + if (!inner_tuple) + throw python_error(); + PyObject* obj_op_sig = THPUtils_packString(op_sig); + if (!obj_op_sig) + throw python_error(); + PyObject* obj_param_sig = THPUtils_packString(param_sig); + if (!obj_param_sig) + throw python_error(); + PyObject* obj_result_key = THPUtils_packString(result.GetKey()); + if (!obj_result_key) + throw python_error(); + PyObject* obj_result_time = PyFloat_FromDouble(result.GetTime()); + if (!obj_result_time) + throw python_error(); + PyTuple_SET_ITEM(inner_tuple.get(), 0, obj_op_sig); + PyTuple_SET_ITEM(inner_tuple.get(), 1, obj_param_sig); + PyTuple_SET_ITEM(inner_tuple.get(), 2, obj_result_key); + PyTuple_SET_ITEM(inner_tuple.get(), 3, obj_result_time); + PyTuple_SET_ITEM( + outer_tuple.get(), result_index++, inner_tuple.release()); + } + } + return outer_tuple.release(); + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_tunableop_get_validators( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto validators = at::cuda::tunable::getTuningContext() + ->GetTuningResultsValidator() + .GetAllValidators(); + THPObjectPtr outer_tuple(PyTuple_New(validators.size())); + if (!outer_tuple) + throw python_error(); + size_t validator_index = 0; + for (const auto& [key, val] : validators) { + THPObjectPtr inner_tuple(PyTuple_New(2)); + if (!inner_tuple) + throw python_error(); + PyObject* obj_key = THPUtils_packString(key); + if (!obj_key) + throw python_error(); + PyObject* 
obj_val = THPUtils_packString(val); + if (!obj_val) + throw python_error(); + PyTuple_SET_ITEM(inner_tuple.get(), 0, obj_key); + PyTuple_SET_ITEM(inner_tuple.get(), 1, obj_val); + PyTuple_SET_ITEM( + outer_tuple.get(), validator_index++, inner_tuple.release()); + } + return outer_tuple.release(); + END_HANDLE_TH_ERRORS +} + static PyObject* THCPModule_isCurrentStreamCapturing_wrap( PyObject* self, PyObject* noargs) { @@ -1576,6 +1846,66 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_rocm_is_backward_pass, METH_NOARGS, nullptr}, + {"_cuda_tunableop_enable", + THCPModule_cuda_tunableop_enable, + METH_O, + nullptr}, + {"_cuda_tunableop_is_enabled", + THCPModule_cuda_tunableop_is_enabled, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_tuning_enable", + THCPModule_cuda_tunableop_tuning_enable, + METH_O, + nullptr}, + {"_cuda_tunableop_tuning_is_enabled", + THCPModule_cuda_tunableop_tuning_is_enabled, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_write_file_on_exit", + THCPModule_cuda_tunableop_write_file_on_exit, + METH_O, + nullptr}, + {"_cuda_tunableop_set_max_tuning_duration", + THCPModule_cuda_tunableop_set_max_tuning_duration, + METH_O, + nullptr}, + {"_cuda_tunableop_get_max_tuning_duration", + THCPModule_cuda_tunableop_get_max_tuning_duration, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_set_max_tuning_iterations", + THCPModule_cuda_tunableop_set_max_tuning_iterations, + METH_O, + nullptr}, + {"_cuda_tunableop_get_max_tuning_iterations", + THCPModule_cuda_tunableop_get_max_tuning_iterations, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_set_filename", + THCPModule_cuda_tunableop_set_filename, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_get_filename", + THCPModule_cuda_tunableop_get_filename, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_write_file", + THCPModule_cuda_tunableop_write_file, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_read_file", + THCPModule_cuda_tunableop_read_file, + METH_VARARGS, + nullptr}, + {"_cuda_tunableop_get_results", + THCPModule_cuda_tunableop_get_results, + METH_NOARGS, + nullptr}, + {"_cuda_tunableop_get_validators", + THCPModule_cuda_tunableop_get_validators, + METH_NOARGS, + nullptr}, {nullptr}}; PyMethodDef* THCPModule_methods() { diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index 65ea8a600b57..cbfa64af2523 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -84,12 +84,6 @@ static void THCPStream_dealloc(THCPStream* self) { Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject* THCPStream_get_device(THCPStream* self, void* unused) { - HANDLE_TH_ERRORS - return THPDevice_New(self->cuda_stream.device()); - END_HANDLE_TH_ERRORS -} - static PyObject* THCPStream_get_cuda_stream(THCPStream* self, void* unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_stream.stream()); diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index b118bd4600a5..37d1be15cbd7 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -88,7 +88,7 @@ namespace detail { TORCH_CUDA_CPP_API void throw_nccl_error(ncclResult status); -static inline void NCCL_CHECK(ncclResult status) { +inline void NCCL_CHECK(ncclResult status) { if (status != ncclResult::Success) { throw_nccl_error(status); } diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 9d525f0d5640..2485999e7a00 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -199,7 +199,7 @@ at::Tensor 
all_gather_into_tensor( at::Tensor& all_gather_into_tensor_out( at::Tensor& input, int64_t group_size, - std::string group_name, + const std::string& group_name, at::Tensor& output) { c10d::AllgatherOptions opts; @@ -463,9 +463,9 @@ class ReduceScatterTensor static torch::autograd::Variable forward( torch::autograd::AutogradContext* ctx, const at::Tensor& input, - std::string reduce_op, + const std::string& reduce_op, int64_t group_size, - std::string group_name) { + const std::string& group_name) { TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported"); ctx->saved_data["group_size"] = group_size; @@ -510,9 +510,9 @@ class ReduceScatterTensor at::Tensor reduce_scatter_tensor_autograd( const at::Tensor& input, - std::string reduce_op, + const std::string& reduce_op, int64_t group_size, - std::string group_name) { + const std::string& group_name) { return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name); } @@ -523,7 +523,7 @@ class AllGatherIntoTensor torch::autograd::AutogradContext* ctx, const at::Tensor& input, int64_t group_size, - std::string group_name) { + const std::string& group_name) { ctx->saved_data["group_size"] = group_size; ctx->saved_data["group_name"] = group_name; @@ -566,7 +566,7 @@ class AllGatherIntoTensor at::Tensor all_gather_into_tensor_autograd( const at::Tensor& input, int64_t group_size, - std::string group_name) { + const std::string& group_name) { return AllGatherIntoTensor::apply(input, group_size, group_name); } @@ -607,7 +607,7 @@ at::Tensor shard_dim_alltoall( const at::Tensor& input, int64_t gather_dim, int64_t shard_dim, - std::string group_name) { + const std::string& group_name) { auto group = c10d::resolve_process_group(group_name); auto group_size = group->getSize(); std::vector output_sizes = input.sizes().vec(); @@ -619,12 +619,14 @@ at::Tensor shard_dim_alltoall( } output_sizes[shard_dim] = output_sizes[shard_dim] / group_size; std::vector inputs; + inputs.reserve(group_size); auto length = output_sizes[shard_dim]; for (int i = 0; i < group_size; i++) { inputs.push_back(input.narrow(shard_dim, i * length, length).contiguous()); } // allocate outputs std::vector outputs; + outputs.reserve(group_size); for (int i = 0; i < group_size; i++) { outputs.push_back(input.new_empty(output_sizes).contiguous()); } diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index e26ab22f1a9f..bc820fc1c8d5 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -63,6 +63,33 @@ void NCCLComm::waitUntilInitialized(int timeoutSecs) { } } +#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) +// last argument to split() API is not used to support +// multiple implementations +std::shared_ptr NCCLComm::split( + NCCLComm* source, + int color_id, + int rank, + ncclConfig_t& config, + std::vector& ranks_ull) { + auto comm = std::make_shared(); + C10D_NCCL_CHECK( + ncclCommSplit( + source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), + c10::nullopt); + ++source->ncclCommSplitCounter_; + ncclCommUserRank(comm->ncclComm_, &comm->rank_); + return comm; +} +#endif + +#ifndef FBCODE_CAFFE2 +bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf) { + // For point-to-point communication on the same process, don't need broadcast. 
+ return !isSendRecvSelf; +} +#endif + std::string getNcclVersion() { static c10::once_flag ncclGetVersionFlag; static std::string versionString; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 5690c0591a7a..9ce25b55dc13 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -173,14 +173,15 @@ namespace c10d { TORCH_API size_t hashTensors(const std::vector& tensors); -std::string getNcclVersion(); -std::string ncclGetErrorWithVersion(ncclResult_t error); +TORCH_API std::string getNcclVersion(); +TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); bool nccl_use_nonblocking(); int nccl_nonblocking_timeout(); +bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf); // Provides additional detail into NCCL error codes based on when these are // thrown in the NCCL codebase. -std::string getNcclErrorDetailStr( +TORCH_API std::string getNcclErrorDetailStr( ncclResult_t error, std::optional processGroupFailureReason = c10::nullopt); @@ -286,22 +287,12 @@ class NCCLComm { } #endif -#ifdef NCCL_HAS_COMM_SPLIT static std::shared_ptr split( NCCLComm* source, int color_id, int rank, - ncclConfig_t& config) { - auto comm = std::make_shared(); - C10D_NCCL_CHECK( - ncclCommSplit( - source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), - c10::nullopt); - ++source->ncclCommSplitCounter_; - comm->rank_ = rank; - return comm; - } -#endif + ncclConfig_t& config, + std::vector& ranks_ull); #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump() { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 91ce50a4183f..bb9198f22200 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -197,6 +198,20 @@ inline std::string getKeyFromDevice(at::Device& device) { return std::to_string(device.index()); } +inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { + // initialize the device index to -1, which is an invalid value. + int index = -1; + try { + index = std::stoi(deviceKey); + } catch (const std::invalid_argument& e) { + LOG(WARNING) << c10::str( + "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); + } catch (const std::out_of_range& e) { + LOG(ERROR) << "Out of range: " << e.what(); + } + return static_cast(index); +} + std::string getKeySendRecv(int myRank, int peer) { int lowRank = myRank < peer ? myRank : peer; int highRank = myRank < peer ? 
peer : myRank; @@ -327,7 +342,10 @@ void cacheAllocatorDeregisterHook( } #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) -std::string dump_nccl_trace() { +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { std::unordered_map< std::string /* ncclUniqueID */, std::unordered_map /* dump from this comm */> @@ -347,14 +365,29 @@ std::string dump_nccl_trace() { std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); } - return NCCLTraceBuffer::get()->dump(ncclDumpMap); + return NCCLTraceBuffer::get()->dump( + ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); } + #else -std::string dump_nccl_trace() { - return NCCLTraceBuffer::get()->dump(c10::nullopt); +std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + return NCCLTraceBuffer::get()->dump( + c10::nullopt, includeCollectives, includeStackTraces, onlyActive); } #endif +// TODO(c-p-i-o): add a JSON endpoint. +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. + res.setContent( + dump_nccl_trace(true, true, false), "application/octet-stream"); + }}; + std::optional)>>& get_cpp_trace_dumper() { static std::optional< @@ -911,7 +944,12 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " << device << ", key " << key << ", i am " << this; auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); - NCCLComm::split(comm.get(), NCCL_SPLIT_NOCOLOR, rank_, options_->config); + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); #endif } @@ -1050,7 +1088,11 @@ void ProcessGroupNCCL::abortCommsFromMap( for (auto& it : ncclCommsMap) { auto& devName = it.first; auto& ncclComm = it.second; - + at::cuda::OptionalCUDAGuard gpuGuard; + at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); + if (deviceIndex >= 0) { + gpuGuard.set_index(deviceIndex); + } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " << ncclComm->ncclComm_ << " on CUDA device: " << devName; ncclComm->ncclCommAbort(abortReason); @@ -1166,7 +1208,7 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // We dump nccl trace into local disk by default and users can register // their customized writer by inheriting `DebugInfoWriter` via // `registerDebugInfoWriter`. - auto ncclTrace = dump_nccl_trace(); + auto ncclTrace = dump_nccl_trace(true, true, false); DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank()); LOG(INFO) << logPrefix() << "ProcessGroupNCCL dumping nccl trace to " << writer.getWriterTarget(); @@ -1240,9 +1282,9 @@ void ProcessGroupNCCL::heartbeatMonitor() { "Received a dump signal from this local rank and will ", "start to dump the debug info. ", "Last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); exitMsg = c10::str( "ProcessGroupNCCL's watchdog detected an exception from the local rank. ", @@ -1302,9 +1344,9 @@ void ProcessGroupNCCL::heartbeatMonitor() { timeOutRank, ", and will start to dump the debug info. 
", "Last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); exitMsg = c10::str( "ProcessGroupNCCL's watchdog detected a dump signal from rank ", @@ -1570,9 +1612,9 @@ void ProcessGroupNCCL::watchdogHandler() { logPrefix(), "NCCL Work update periodically: ", "last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); #endif auto logger = ::c10d::C10dLogger::getLogger(); @@ -1585,13 +1627,19 @@ void ProcessGroupNCCL::watchdogHandler() { data.integers["pg_id"] = uid_; data.integers["rank"] = rank_; data.integers["global_rank"] = globalRank(); - data.integers["last_enqueued_work"] = lastEnqueuedSeq_; - data.integers["last_started_work"] = lastStartedSeq_; - data.integers["last_completed_work"] = lastCompletedSeq_; + data.integers["last_enqueued_work"] = pgStatus_.lastEnqueuedSeq; + data.integers["last_started_work"] = pgStatus_.lastStartedSeq; + data.integers["last_completed_work"] = pgStatus_.lastCompletedSeq; + data.integers["last_enqueued_numel_in"] = pgStatus_.lastEnqueuedNumelIn; + data.integers["last_enqueued_numel_out"] = pgStatus_.lastEnqueuedNumelOut; + data.integers["last_completed_numel_in"] = pgStatus_.lastCompletedNumelIn; + data.integers["last_completed_numel_out"] = + pgStatus_.lastCompletedNumelOut; // logging strings - data.strings["last_enqueued_work_name"] = lastEnqueuedWorkName_; - data.strings["last_started_work_name"] = lastStartedWorkName_; - data.strings["last_completed_work_name"] = lastCompletedWorkName_; + data.strings["last_enqueued_work_name"] = pgStatus_.lastEnqueuedWorkName; + data.strings["last_started_work_name"] = pgStatus_.lastStartedWorkName; + data.strings["last_completed_work_name"] = + pgStatus_.lastCompletedWorkName; data.strings["pg_name"] = pg_name_; data.strings["pg_desc"] = pg_desc_; logger->log(data); @@ -1618,9 +1666,9 @@ void ProcessGroupNCCL::watchdogHandler() { "Exception (either an error or timeout) detected by watchdog at work: ", work.seq_, ", last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); // try to dump flight records if exception happens. 
// Flight recorder behavior should be independent of desync Debug @@ -1663,9 +1711,9 @@ void ProcessGroupNCCL::watchdogHandler() { "Timeout at NCCL work: ", work.seq_, ", last enqueued NCCL work: ", - lastEnqueuedSeq_, + pgStatus_.lastEnqueuedSeq, ", last completed NCCL work: ", - lastCompletedSeq_, + pgStatus_.lastCompletedSeq, "."); if (desyncDebug_) { try { @@ -1700,18 +1748,20 @@ void ProcessGroupNCCL::watchdogHandler() { } // a work could be started but not completed, so we should not update - // lastStartedSeq_ and lastStartedOpName_ if the work state is checked + // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start - if (lastStartedSeq_ < static_cast(work.seq_) && + if (pgStatus_.lastStartedSeq < static_cast(work.seq_) && work.isStarted()) { - lastStartedSeq_ = work.seq_; - lastStartedWorkName_ = opTypeToString(work.opType_); + pgStatus_.lastStartedSeq = work.seq_; + pgStatus_.lastStartedWorkName = opTypeToString(work.opType_); } // Clean up completed work if (work.isCompleted()) { - lastCompletedSeq_ = work.seq_; - lastCompletedWorkName_ = opTypeToString(work.opType_); + pgStatus_.lastCompletedSeq = work.seq_; + pgStatus_.lastCompletedWorkName = opTypeToString(work.opType_); + pgStatus_.lastCompletedNumelIn = work.numelIn_; + pgStatus_.lastCompletedNumelOut = work.numelOut_; NCCLTraceBuffer::get()->retire_id(work.trace_id_, true); if (onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook @@ -1994,8 +2044,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt); } - // For point-to-point communication on the same process, don't need broadcast. - if (!isSendRecvSelf) { + if (shouldBroadcastNCCLUniqueID(isSendRecvSelf)) { // Broadcast so that each process can have a unique NCCL ID auto timeStarted = std::chrono::steady_clock::now(); broadcastUniqueNCCLID(&ncclID, singleP2POp, deviceKey, p2pRank); @@ -2048,6 +2097,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( numRanks = 2; rank = p2pRank; } + // Get the device index auto deviceIndex = device.index(); gpuGuard.set_index(deviceIndex); @@ -2064,13 +2114,17 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( auto& parentComm = dit->second; if (parentComm != nullptr && !parentComm->isAborted()) { ncclComm = NCCLComm::split( - parentComm.get(), options_->split_color, rank, options_->config); + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); } } } #endif - // To simplify conditioonal nesting, just create the ncclComms[i] + // To simplify conditional nesting, just create the ncclComms[i] // entry if it hasn't been yet rather than untangling the // conditions that might have resulted in a split above. if (!ncclComm) { @@ -2301,6 +2355,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( outputs, r->ncclStartEvent_.get(), r->ncclEndEvent_.get(), + options_->timeout, isP2P); } return r; @@ -2340,8 +2395,11 @@ void ProcessGroupNCCL::workEnqueue( // needs to be destructed in user thread. Otherwise will // get deadlock. Here we enqueue work without outputs_. 
workMetaList_.emplace_back(*work); - lastEnqueuedSeq_ = work->seq_; - lastEnqueuedWorkName_ = opTypeToString(work->opType_); + // update the PG status related to the last enqueued work + pgStatus_.lastEnqueuedSeq = work->seq_; + pgStatus_.lastEnqueuedWorkName = opTypeToString(work->opType_); + pgStatus_.lastEnqueuedNumelIn = work->numelIn_; + pgStatus_.lastEnqueuedNumelOut = work->numelOut_; lastWorkListUpdateTime_ = std::chrono::steady_clock::now(); } } @@ -2908,6 +2966,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, nullptr, nullptr, + options_->timeout, /*isP2P=*/true); // TODO(whc) if we want to make the per-p2p-op flightrecorder entries get // their timings/states updated by proxy when the Work obj representing the @@ -2941,6 +3000,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( {tensor}, work->ncclStartEvent_.get(), work->ncclEndEvent_.get(), + options_->timeout, /*isP2P=*/true); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 1655de8a7848..faaabe411bfc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -437,6 +437,34 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::string group_name; }; + // A struct to hold the latest status of the process group. + struct ProcessGroupStatus { + // the sequential number of the last collective enqueued into workMetaList_ + // This is useful for identifying a rank that has not joined a collective + // initialized to be -1 to indicate no collective has been enqueued + int64_t lastEnqueuedSeq{-1}; + // the sequential number of the last collective started as the kernel + int64_t lastStartedSeq{-1}; + // the sequential number of the last collective completed marked by + // the watchdog thread + // initialized to be -1 to indicate no collective has been completed + int64_t lastCompletedSeq{-1}; + + // the name of the last collective enqueued into workMetaList_ + std::string lastEnqueuedWorkName; + // the name of the last collective started as the kernel + std::string lastStartedWorkName; + // the name of the last collective completed + std::string lastCompletedWorkName; + + // the sizes of the last work enqueued + size_t lastEnqueuedNumelIn; + size_t lastEnqueuedNumelOut; + // the sizes of the last work completed + size_t lastCompletedNumelIn; + size_t lastCompletedNumelOut; + }; + // If you wish to create multiple process groups, each with a potentially // different rank and size, you can do so by passing a new store instance // to each one. If you have only a single store object, you can @@ -615,7 +643,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { uint64_t getSequenceNumberForGroup() override; // Return the total number of splits the communicators held by this process - // group have performed. + // group have performed. Counts ncclCommCreateFromRanks() for ncclx v2.21.5+ uint64_t getCommSplitCounter() const; void registerOnCompletionHook( @@ -880,6 +908,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communication, the key will be "1:2" on both processes. Note: this is for // the scenario where there is only 1 GPU per process. When it comes to // multiple GPUs per process, this part may need to redesigned. + // TODO: we probably need a separate map for P2P comms std::unordered_map> devNCCLCommMap_; // The NCCL communicators currently in process of being initialized. 
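The ProcessGroupStatus bookkeeping above feeds the NCCL flight recorder, and dump_nccl_trace() (declared in the header hunk that follows) now takes three filter flags. As an illustration only, not part of this patch, here is a minimal sketch of a debugging hook that writes a filtered dump to disk, assuming the c10d namespace and the header path shown in this diff:

#include <fstream>
#include <string>

#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

// Keep collective entries, drop symbolized stack traces, and report only
// still-active (not yet completed) work, then write the pickled dump to disk.
void writeActiveNcclTrace(const std::string& path) {
  std::string pickled = c10d::dump_nccl_trace(
      /*includeCollectives=*/true,
      /*includeStackTraces=*/false,
      /*onlyActive=*/true);
  std::ofstream out(path, std::ios::binary);
  out.write(pickled.data(), static_cast<std::streamsize>(pickled.size()));
}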
@@ -1071,28 +1100,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the ProcessGroup uint64_t op_id_{0}; - // the sequential number of the last collective enqueued into workMetaList_ - // This is useful for indentifying a rank that has not join a collective - // initialized to be -1 to indicate no collective has been enqueued - int64_t lastEnqueuedSeq_{-1}; - - // the name of the last collective enqueued into workMetaList_ - std::string lastEnqueuedWorkName_; - - // the sequential number of the last collective started as the kernel - int64_t lastStartedSeq_{-1}; - - // the name of the last collective started as the kernel - std::string lastStartedWorkName_; - - // the sequential number of the last colletive completed marked by - // the watchdog thread - // initialized to be -1 to indicate no collective has been completed - int64_t lastCompletedSeq_{-1}; - - // the name of the last collective completed - std::string lastCompletedWorkName_; - std::exception_ptr watchDogException_ = nullptr; size_t uid_; @@ -1103,13 +1110,20 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Number of devices on this node. int localDeviceCount_{0}; + + ProcessGroupStatus pgStatus_; }; -TORCH_API std::string dump_nccl_trace(); +// Dumps the NCCL comm traces and additional information about the Process +// Group. +TORCH_API std::string dump_nccl_trace( + bool includeCollectives, + bool includeStackTraces, + bool onlyActive); -// Gets a mutable reference to a global optional function. Heartbeat Monitor -// will use this function to dump traces, if available. Inside fbcode, we store -// a function here that uses an internal tool for process tracing +// Gets a mutable reference to a global optional function. Heartbeat Monitor +// will use this function to dump traces, if available. Inside fbcode, we +// store a function here that uses an internal tool for process tracing TORCH_API std::optional< std::function)>>& get_cpp_trace_dumper(); diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 626a9e3b688b..061d0ed620ba 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -106,8 +106,7 @@ class StoreTimeoutGuard { explicit StoreTimeoutGuard( Store& store, const std::chrono::milliseconds& timeout) - : store_(store) { - oldTimeout_ = store.getTimeout(); + : store_(store), oldTimeout_(store.getTimeout()) { store.setTimeout(timeout); } @@ -123,7 +122,7 @@ class StoreTimeoutGuard { private: Store& store_; - std::chrono::milliseconds oldTimeout_; + std::chrono::milliseconds oldTimeout_{}; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index aee1d7677dc4..a716bf666755 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -291,6 +291,17 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) TORCH_CHECK( ::c10d::detail::is_libuv_tcpstore_backend_available(), "use_libuv was requested but PyTorch was build without libuv support"); + + if (opts.masterListenFd.has_value()) { + // TODO(xilunwu): support this init method after testing + constexpr auto* msg = + "The libuv TCPStore backend does not support initialization with a listen fd. 
" + "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV " + "to \"0\"."; + C10D_ERROR(msg); + C10_THROW_ERROR(NotImplementedError, msg); + return; + } } Socket::initialize(); diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 7080d50136e9..25783f2d2ace 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -63,7 +63,7 @@ struct TCPStoreOptions { std::optional masterListenFd = c10::nullopt; // A boolean value indicating whether to use the experimental libUV backend. - bool useLibUV = false; + bool useLibUV = true; }; class TORCH_API TCPStore : public Store { @@ -158,7 +158,7 @@ class TORCH_API TCPStore : public Store { const std::string keyPrefix_ = "/"; std::mutex activeOpLock_; std::unordered_map clientCounters_; - bool usingLibUv_ = false; + bool usingLibUv_ = true; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index f33cbb019401..d162149ed3a4 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -186,10 +187,14 @@ class UvTcpServer : public UvTcpSocket { int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket); TORCH_CHECK( uv_res == 0, - "Failed to open existing socket. socket:{} code:{} name:{} message:{}", + "Failed to open existing socket. ", + "socket: ", socket, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); @@ -221,30 +226,42 @@ class UvTcpServer : public UvTcpSocket { } TORCH_CHECK( uv_res == 0, - "UV Store addr parsing failure. useIpv6:{} code:{} name:{} message:{}", + "UV Store addr parsing failure. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); uv_res = uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0); TORCH_CHECK( uv_res == 0, - "UV Store bind failed. useIpv6:{} code:{} name:{} message:{}", + "The server socket has failed to bind. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); uv_res = uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection); TORCH_CHECK( uv_res == 0, - "UV Store listen failed. useIpv6:{} code:{} name:{} message:{}", + "The server socket has failed to listen on any local network address. ", + "useIpv6: ", useIpv6, + ", code: ", uv_res, + ", name: ", uv_err_name(uv_res), + ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); @@ -265,9 +282,12 @@ class UvTcpServer : public UvTcpSocket { uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle()); TORCH_CHECK( res == 0, - "Failed to accept socket. code:{} name:{} desc:{}.", + "Failed to accept socket. ", + "code: ", res, + ", name: ", uv_err_name(res), + ", message: ", uv_strerror(res)); } @@ -458,9 +478,12 @@ class ChunkedStream { if (buff_idx >= buffers.size() && remaining > 0) { TORCH_CHECK( false, - "Trying to read past end of buffer buffer_idx:{} available:{} remaining:{}", + "Trying to read past end of buffer. ", + "buffer_idx: ", buff_idx, + ", available: ", buffers.size(), + ", remaining: ", remaining); } } @@ -498,8 +521,10 @@ class ChunkedStream { return false; TORCH_CHECK( size <= MAX_STRING_LEN, - "Invalid string size. 
size:{} max:{}", + "Invalid string size. ", + "size: ", size, + ", max: ", MAX_STRING_LEN); if (available() < size) @@ -515,8 +540,10 @@ class ChunkedStream { auto size_in_bytes = size * sizeof(uint8_t); TORCH_CHECK( size_in_bytes <= MAX_PAYLOAD_LEN, - "Invalid payload size. size: {} max:{}", + "Invalid payload size. ", + "size: ", size_in_bytes, + ", max: ", MAX_PAYLOAD_LEN); if (available() < size_in_bytes) @@ -752,7 +779,7 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; - auto data = store->get(key); + const auto& data = store->get(key); StreamWriter sw(iptr()); sw.write_vector(data); sw.send(); @@ -782,8 +809,10 @@ class UvClient : public UvTcpSocket { return false; TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys being waited. keys:{} max:{}", + "Too many keys being waited. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); @@ -810,8 +839,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys being waited. keys:{} max:{}", + "Too many keys being waited. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); @@ -872,8 +903,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys with multi_get. keys:{} max:{}", + "Too many keys with multi_get. ", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); StreamWriter sw(iptr()); @@ -884,8 +917,7 @@ class UvClient : public UvTcpSocket { return false; } - auto data = store->get(key); - sw.write_vector(data); + sw.write_vector(store->get(key)); } sw.send(); @@ -899,8 +931,10 @@ class UvClient : public UvTcpSocket { } TORCH_CHECK( key_count <= MAX_KEY_COUNT, - "Too many keys with multi_get. keys:{} max:{}", + "Too many keys with multi_get. 
", + "keys: ", key_count, + ", max: ", MAX_KEY_COUNT); for (const auto _ : c10::irange(key_count)) { @@ -989,9 +1023,11 @@ void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { port_ = tcpServer->port(); TORCH_CHECK( port_ == opts.port || opts.port == 0, // zero means use any port - "listen fd {} is bound to port {}, expected to be bound to port {}", + "listen fd ", *opts.masterListenFd, + " is bound to port ", port_, + ", expected to be bound to port ", opts.port); } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 575bb0451f18..de623d77fe9e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -1,15 +1,19 @@ #pragma once -#include #include #include #include #include -#include #include #include #include #include #include +#include + +#ifdef USE_C10D_NCCL +#include +#include +#endif #include #include @@ -25,7 +29,7 @@ static c10::IValue nccl_comm_key = "nccl_comm_state"; static c10::IValue version_key = "version"; // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) -static c10::IValue version_val = "2.1"; +static c10::IValue version_val = "2.2"; static c10::IValue pg_config_key = "pg_config"; static c10::IValue record_id_key = "record_id"; static c10::IValue pg_id_key = "pg_id"; @@ -41,6 +45,7 @@ static c10::IValue output_sizes_key = "output_sizes"; static c10::IValue output_dtypes_key = "output_dtypes"; static c10::IValue time_created_key = "time_created_ns"; static c10::IValue duration_key = "duration_ms"; +static c10::IValue timeout_key = "timeout_ms"; static c10::IValue frames_key = "frames"; static c10::IValue state_key = "state"; @@ -458,6 +463,9 @@ struct NCCLTraceBuffer { // was 'enqueued'- not necessarily started c10::time_t time_created_; + // configured timeout for this entry + c10::time_t timeout_ms_; + // Is this a P2P event? 
bool isP2P_; @@ -505,6 +513,7 @@ struct NCCLTraceBuffer { const std::vector& outputs, Event* start, Event* end, + std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { return c10::nullopt; @@ -525,6 +534,7 @@ struct NCCLTraceBuffer { std::move(start), std::move(end), c10::getTime(), + timeout_ms.count(), isP2P}; for (const auto& input : inputs) { @@ -652,31 +662,44 @@ struct NCCLTraceBuffer { entry->start_ = entry->end_ = nullptr; } - std::string dump( - const std::optional>>& ncclDumpMap) { - auto result = dump_entries(); + const c10::List getCollectiveTrace( + bool includeStacktraces, + bool onlyActive) { auto entries = new_list(); - + auto result = dump_entries(); std::vector tracebacks; - for (auto& e : result) { - tracebacks.push_back(e.traceback_.get()); - } - torch::SymbolizedTracebacks stracebacks = torch::symbolize(tracebacks); + torch::SymbolizedTracebacks stracebacks; std::vector all_frames; - for (const auto& f : stracebacks.all_frames) { - auto d = new_dict(); - d.insert(name_key, f.funcname); - d.insert(filename_key, f.filename); - d.insert(line_key, int64_t(f.lineno)); - all_frames.emplace_back(std::move(d)); + if (includeStacktraces) { + for (auto& e : result) { + tracebacks.push_back(e.traceback_.get()); + } + stracebacks = torch::symbolize(tracebacks); + for (const auto& f : stracebacks.all_frames) { + auto d = new_dict(); + d.insert(name_key, f.funcname); + d.insert(filename_key, f.filename); + d.insert(line_key, int64_t(f.lineno)); + all_frames.emplace_back(std::move(d)); + } } - for (auto i : c10::irange(result.size())) { - auto& e = result.at(i); - auto& tb = stracebacks.tracebacks.at(i); auto dict = new_dict(); + auto& e = result.at(i); + // Skip completed events + if (onlyActive && e.time_discovered_completed_.has_value()) { + continue; + } + + if (includeStacktraces) { + auto& tb = stracebacks.tracebacks.at(i); + auto frames = new_list(); + for (int64_t frame : tb) { + frames.push_back(all_frames.at(frame)); + } + dict.insert(frames_key, frames); + } + dict.insert(record_id_key, int64_t(e.id_)); dict.insert(pg_id_key, int64_t(e.pg_id_)); dict.insert(pg_name_key, e.pg_name_); @@ -736,15 +759,16 @@ struct NCCLTraceBuffer { ? 
int64_t(*e.time_discovered_completed_) : c10::IValue()); dict.insert(retired_key, e.retired_); + dict.insert(timeout_key, e.timeout_ms_); dict.insert(is_p2p_key, e.isP2P_); - auto frames = new_list(); - for (int64_t frame : tb) { - frames.push_back(all_frames.at(frame)); - } - dict.insert(frames_key, frames); entries.push_back(dict); } + return entries; + } + + // dump pg_entries + const c10::Dict getPgConfig() { auto pg_config = new_dict(); for (const auto& [pg_name, ranks] : pg_name_to_ranks_) { auto pg_info = new_dict(); @@ -753,6 +777,27 @@ struct NCCLTraceBuffer { pg_info.insert("ranks", ranks_str(ranks)); pg_config.insert(std::get<0>(pg_name), pg_info); } + return pg_config; + } + + // dump all collectives + ncclDumpMap + std::string dump( + const std::optional>>& ncclDumpMap, + bool includeCollectives, + bool includeStackTraces, + bool onlyActive) { + auto result = new_dict(); + // common values + result.insert(version_key, version_val); + result.insert(pg_config_key, getPgConfig()); + + // collective trace + if (includeCollectives) { + result.insert( + entries_key, getCollectiveTrace(includeStackTraces, onlyActive)); + } // convert ncclDumpMap into a dictionary auto per_comm_dict = new_dict(); @@ -765,16 +810,10 @@ struct NCCLTraceBuffer { per_comm_dict.insert(ncclId, inner_dict); } } - - auto dict = new_dict(); - dict.insert(entries_key, entries); - dict.insert(version_key, version_val); if (per_comm_dict.size() > 0) { - dict.insert(nccl_comm_key, per_comm_dict); + result.insert(nccl_comm_key, per_comm_dict); } - dict.insert(pg_config_key, pg_config); - - return pickle_str(dict); + return pickle_str(result); } }; diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index b193c8971b57..a03337e97514 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -38,6 +38,11 @@ TORCH_API std::vector getTensorShapes( // Use -2 to represent unset state of env vars #define C10D_ENV_NOT_SET -2 +#define WARN_ENV_VAR_ONCE(deprecated_env, new_env) \ + TORCH_WARN_ONCE( \ + "Environment variable " + deprecated_env + " is deprecated; use " + \ + new_env + " instead"); + // Turns at::IntArrayRef into "(1, 2, 3, 4)". 
inline std::string toString(at::IntArrayRef l) { std::stringstream ss; @@ -102,9 +107,7 @@ inline std::string getCvarString( if (val == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } ret = val; @@ -129,9 +132,7 @@ inline int getCvarInt(const std::vector& env, int def) { if (val == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } try { @@ -160,9 +161,7 @@ inline bool getCvarBool(const std::vector& env, bool def) { if (val_ == nullptr) { continue; } else if (i) { - TORCH_WARN( - "Environment variable " + env[i] + " is deprecated; use " + env[0] + - " instead"); + WARN_ENV_VAR_ONCE(env[i], env[0]); } std::string val = std::string(val_); diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp new file mode 100644 index 000000000000..e29f1e3a2ac3 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -0,0 +1,75 @@ +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { + +class HandlerRegistry { + public: + void registerHandler(const std::string& name, HandlerFunc f) { + std::unique_lock lock(handlersMutex_); + + if (handlers_.find(name) != handlers_.end()) { + throw std::runtime_error( + fmt::format("Handler {} already registered", name)); + } + + handlers_[name] = f; + } + + HandlerFunc getHandler(const std::string& name) { + std::shared_lock lock(handlersMutex_); + + auto it = handlers_.find(name); + if (it == handlers_.end()) { + throw std::runtime_error(fmt::format("Failed to find handler {}", name)); + } + return handlers_[name]; + } + + std::vector getHandlerNames() { + std::shared_lock lock(handlersMutex_); + + std::vector names; + for (const auto& [name, _] : handlers_) { + names.push_back(name); + } + return names; + } + + private: + std::shared_mutex handlersMutex_{}; + std::unordered_map handlers_{}; +}; + +HandlerRegistry& getHandlerRegistry() { + static HandlerRegistry registry; + return registry; +} + +RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { + res.setContent("pong", "text/plain"); + }}; + +} // namespace + +void registerHandler(const std::string& name, HandlerFunc f) { + return getHandlerRegistry().registerHandler(name, f); +} + +HandlerFunc getHandler(const std::string& name) { + return getHandlerRegistry().getHandler(name); +} + +std::vector getHandlerNames() { + return getHandlerRegistry().getHandlerNames(); +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp new file mode 100644 index 000000000000..0c1063054931 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace c10d { +namespace control_plane { + +// Request represents a request to the handler. This conceptually maps to an +// HTTP request but could be called via other transports. +class TORCH_API Request { + public: + virtual ~Request() = default; + + virtual const std::string& body() = 0; +}; + +// Response represents a response to the handler. This conceptually maps to an +// HTTP response but could be called via other transports. 
+class TORCH_API Response { + public: + virtual ~Response() = default; + + // Set the response body to the provided string. + // TODO: add support for chunked responses + virtual void setContent( + std::string&& content, + const std::string& content_type) = 0; + + // Set the response status code. + // These should match standard HTTP status codes. + virtual void setStatus(int status) = 0; +}; + +using HandlerFunc = std::function; + +// Registers a handler. The name needs to be unique and can be called by using +// getHandler directly or via WorkerServer for remote requests. +// These handlers are called from a background C++ thread concurrently with the +// main thread. These handlers need to be thread safe and not cause issues +// during Python training. +TORCH_API void registerHandler(const std::string& name, HandlerFunc f); + +// Fetches a handler by name. +TORCH_API HandlerFunc getHandler(const std::string& name); + +TORCH_API std::vector getHandlerNames(); + +// Registers a handler statically. +// See registerHandler for more details. +class TORCH_API RegisterHandler { + public: + RegisterHandler(const std::string& name, HandlerFunc f) { + registerHandler(name, f); + } + + // disable move, copy + RegisterHandler(const RegisterHandler&) = delete; + RegisterHandler(RegisterHandler&&) = delete; + RegisterHandler& operator=(const RegisterHandler&) = delete; + RegisterHandler& operator=(RegisterHandler&&) = delete; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp new file mode 100644 index 000000000000..e4b649d888dd --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -0,0 +1,186 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10d { +namespace control_plane { + +namespace { +class RequestImpl : public Request { + public: + RequestImpl(const httplib::Request& req) : req_(req) {} + + const std::string& body() override { + return req_.body; + } + + private: + const httplib::Request& req_; +}; + +class ResponseImpl : public Response { + public: + ResponseImpl(httplib::Response& res) : res_(res) {} + + void setStatus(int status) override { + res_.status = status; + } + + void setContent(std::string&& content, const std::string& content_type) + override { + res_.set_content(std::move(content), content_type); + } + + private: + httplib::Response& res_; +}; + +std::string jsonStrEscape(const std::string& str) { + std::ostringstream ostream; + for (char ch : str) { + if (ch == '"') { + ostream << "\\\""; + } else if (ch == '\\') { + ostream << "\\\\"; + } else if (ch == '\b') { + ostream << "\\b"; + } else if (ch == '\f') { + ostream << "\\f"; + } else if (ch == '\n') { + ostream << "\\n"; + } else if (ch == '\r') { + ostream << "\\r"; + } else if (ch == '\t') { + ostream << "\\t"; + } else if ('\x00' <= ch && ch <= '\x1f') { + ostream << "\\u" << std::hex << std::setw(4) << std::setfill('0') + << static_cast(ch); + } else { + ostream << ch; + } + } + return ostream.str(); +} +} // namespace + +WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { + server_.Get("/", [](const httplib::Request& req, httplib::Response& res) { + res.set_content( + R"BODY(

torch.distributed.WorkerServer

+Handler names +)BODY", + "text/html"); + }); + server_.Get( + "/handler/", [](const httplib::Request& req, httplib::Response& res) { + std::ostringstream body; + body << "["; + bool first = true; + for (const auto& name : getHandlerNames()) { + if (!first) { + body << ","; + } + first = false; + + body << "\"" << jsonStrEscape(name) << "\""; + } + body << "]"; + + res.set_content(body.str(), "application/json"); + }); + server_.Post( + "/handler/:handler", + [](const httplib::Request& req, httplib::Response& res) { + auto handler_name = req.path_params.at("handler"); + HandlerFunc handler; + try { + handler = getHandler(handler_name); + } catch (const std::exception& e) { + res.status = 404; + res.set_content( + fmt::format("Handler {} not found: {}", handler_name, e.what()), + "text/plain"); + return; + } + RequestImpl torchReq{req}; + ResponseImpl torchRes{res}; + + try { + handler(torchReq, torchRes); + } catch (const std::exception& e) { + res.status = 500; + res.set_content( + fmt::format("Handler {} failed: {}", handler_name, e.what()), + "text/plain"); + return; + } catch (...) { + res.status = 500; + res.set_content( + fmt::format( + "Handler {} failed with unknown exception", handler_name), + "text/plain"); + return; + } + }); + + // adjust keep alives as it stops the server from shutting down quickly + server_.set_keep_alive_timeout(1); // second, default is 5 + server_.set_keep_alive_max_count( + 30); // wait max 30 seconds before closing socket + + if (port == -1) { + // using unix sockets + server_.set_address_family(AF_UNIX); + + if (std::filesystem::exists(hostOrFile)) { + throw std::runtime_error(fmt::format("{} already exists", hostOrFile)); + } + + C10D_WARNING("Server listening to UNIX {}", hostOrFile); + if (!server_.bind_to_port(hostOrFile, 80)) { + throw std::runtime_error(fmt::format("Error binding to {}", hostOrFile)); + } + } else { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + if (!server_.bind_to_port(hostOrFile, port)) { + throw std::runtime_error( + fmt::format("Error binding to {}:{}", hostOrFile, port)); + } + } + + serverThread_ = std::thread([this]() { + try { + if (!server_.listen_after_bind()) { + throw std::runtime_error("failed to listen"); + } + } catch (std::exception& e) { + C10D_ERROR("Error while running server: {}", e.what()); + throw; + } + C10D_WARNING("Server exited"); + }); +} + +void WorkerServer::shutdown() { + C10D_WARNING("Server shutting down"); + server_.stop(); + serverThread_.join(); +} + +WorkerServer::~WorkerServer() { + if (serverThread_.joinable()) { + C10D_WARNING("WorkerServer destructor called without shutdown"); + shutdown(); + } +} + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp new file mode 100644 index 000000000000..a0b16ac192ba --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace c10d { +namespace control_plane { + +class TORCH_API WorkerServer : public c10::intrusive_ptr_target { + public: + WorkerServer(const std::string& hostOrFile, int port = -1); + ~WorkerServer(); + + void shutdown(); + + private: + httplib::Server server_; + std::thread serverThread_; +}; + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 
2aaf9009a246..6f1b28886b98 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifndef _WIN32 #include @@ -1015,11 +1016,13 @@ Example:: const std::string& key, const std::string& expected_value, const std::string& desired_value) -> py::bytes { - auto value = store.compareSet( - key, toVec8(expected_value), toVec8(desired_value)); + auto value = [&]() { + py::gil_scoped_release guard; + return store.compareSet( + key, toVec8(expected_value), toVec8(desired_value)); + }(); return toPyBytes(value); }, - py::call_guard(), R"( Inserts the key-value pair into the store based on the supplied ``key`` and performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value`` @@ -1390,6 +1393,7 @@ the server to establish a connection. wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. multi_tenant (bool, optional): If True, all ``TCPStore`` instances in the current process with the same host/port will use the same underlying ``TCPServer``. Default is False. master_listen_fd (int, optional): If specified, the underlying ``TCPServer`` will listen on this file descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races in some scenarios. Default is None (meaning the server creates a new socket and attempts to bind it to ``port``). + use_libuv (bool, optional): If True, use libuv for ``TCPServer`` backend. Default is True. Example:: >>> import torch.distributed as dist >>> from datetime import timedelta @@ -1439,7 +1443,7 @@ Example:: py::arg("wait_for_workers") = true, py::arg("multi_tenant") = false, py::arg("master_listen_fd") = py::none(), - py::arg("use_libuv") = false, + py::arg("use_libuv") = true, py::call_guard()) .def( "collect_client_counters", @@ -2031,7 +2035,7 @@ communication mechanism. self->registerOnCompletionHook( [hookWrapper = ::c10d::PythonOnCompletionHook(std::move( hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) { - hookWrapper(std::move(workInfo)); + hookWrapper(workInfo); }); }, py::arg("hook"), @@ -3160,10 +3164,41 @@ such as `dist.all_reduce(tensor, async_op=True)`. Arguments: tensors(List[torch.Tensor]): List of tensors we want to hash. )"); - module.def("_dump_nccl_trace", []() { - return py::bytes(::c10d::dump_nccl_trace()); - }); + module.def( + "_dump_nccl_trace", + [](std::optional includeCollectives, + std::optional includeStackTraces, + std::optional onlyActive) { + return py::bytes(::c10d::dump_nccl_trace( + includeCollectives.value_or(true), + includeStackTraces.value_or(true), + onlyActive.value_or(false))); + }, + py::arg("includeCollectives") = std::optional(), + py::arg("includeStackTraces") = std::optional(), + py::arg("onlyActive") = std::optional(), + R"( + Arguments: + includeCollectives(bool, optional): Whether to include collective work traces. Default is True. + includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True. + onlyActive (bool, optional): Whether to only include active collective work traces. Default is False. + Returns: + Stringified pickle work traces. + Default settings return everything - i.e. contains NCCL comm dumps and collective traces. 
+ )"); #endif + + intrusive_ptr_class_<::c10d::control_plane::WorkerServer>( + module, "_WorkerServer", R"( +)") + .def( + py::init([](const std::string& hostOrFile, int port) { + return c10::make_intrusive<::c10d::control_plane::WorkerServer>( + hostOrFile, port); + }), + py::arg("host_or_file"), + py::arg("port") = -1) + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); Py_RETURN_TRUE; } diff --git a/torch/csrc/distributed/c10d/quantization/quantization.cpp b/torch/csrc/distributed/c10d/quantization/quantization.cpp index 8ed6d97d6d80..2d4fa2ba3812 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.cpp +++ b/torch/csrc/distributed/c10d/quantization/quantization.cpp @@ -2,10 +2,7 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { // TODO: The kernels are copied from fbgemm_gpu, we should dedup them later @@ -31,11 +28,9 @@ static void BFloat16QuantizedToFloat_ref( const size_t nrows, const size_t ncols, float* const output) { - const int32_t output_columns = ncols; - for (const auto row : c10::irange(nrows)) { const at::BFloat16* input_row = input + row * ncols; - float* output_row = output + row * output_columns; + float* output_row = output + row * ncols; for (const auto col : c10::irange(ncols)) { uint32_t val_fp32 = static_cast( @@ -52,11 +47,9 @@ at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) { TENSOR_NDIM_EQUALS(input, 2); const auto input_sizes = input.sizes(); - const int32_t nrows = input_sizes[0]; - const int32_t ncols = input_sizes[1]; - const int32_t output_columns = ncols; - auto output = - at::empty({nrows, output_columns}, input.options().dtype(at::kHalf)); + const auto nrows = input_sizes[0]; + const auto ncols = input_sizes[1]; + auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf)); FloatToBFloat16Quantized_ref( input.const_data_ptr(), @@ -73,13 +66,10 @@ at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) { TENSOR_NDIM_EQUALS(input, 2); const auto input_sizes = input.sizes(); - const int32_t nrows = input_sizes[0]; - const int32_t ncols = input_sizes[1]; - const int32_t output_columns = ncols; + const auto nrows = input_sizes[0]; + const auto ncols = input_sizes[1]; - auto output = at::empty( - {nrows, output_columns}, // 4 = sizeof(float) - input.options().dtype(at::kFloat)); // + auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat)); BFloat16QuantizedToFloat_ref( reinterpret_cast(input.const_data_ptr()), nrows, @@ -99,7 +89,4 @@ TORCH_LIBRARY_IMPL(quantization, CPU, m) { m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu); } -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization.h b/torch/csrc/distributed/c10d/quantization/quantization.h index 8cf3455ce79b..3d2f23de421b 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization.h +++ b/torch/csrc/distributed/c10d/quantization/quantization.h @@ -8,15 +8,9 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input); at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input); -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // 
namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu index 48cc7cfc4f3e..480cfb91cfb1 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu @@ -9,16 +9,16 @@ // FP32 -> BF16 kernel __global__ void _float_to_bfloat16_cuda_kernel( const float* __restrict__ input, - const int nrows, - const int ncols, + const size_t nrows, + const size_t ncols, uint16_t* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + const auto row_incre = blockDim.y * gridDim.y; + const auto col_incre = blockDim.x * gridDim.x; + for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; row += row_incre) { const float* input_row = input + row * ncols; uint16_t* output_row = output + row * ncols; - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; col += col_incre) { // Add 2^15 and right shift 16 to do round-nearest output_row[col] = @@ -31,14 +31,14 @@ __global__ void _float_to_bfloat16_cuda_kernel( // BF16 -> FP32 kernel __global__ void _bfloat16_to_float_cuda_kernel( const uint16_t* __restrict__ input, - const int nrows, - const int ncols, + const size_t nrows, + const size_t ncols, float* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + const auto row_incre = blockDim.y * gridDim.y; + const auto col_incre = blockDim.x * gridDim.x; + for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; row += row_incre) { - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; col += col_incre) { const uint16_t* input_row = input + row * ncols; float* output_row = output + row * ncols; @@ -50,10 +50,7 @@ __global__ void _bfloat16_to_float_cuda_kernel( } } -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) { TENSOR_ON_CUDA_GPU(input); @@ -63,27 +60,28 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) { at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(input.get_device()); - const int nrows = input.size(0); - const int ncols = input.size(1); - const int output_columns = ncols; + const auto nrows = input.size(0); + const auto ncols = input.size(1); + const size_t output_columns = ncols; auto output = at::empty( - {nrows, output_columns}, + {nrows, ncols}, #if HAS_NCCL_BF16_DATATYPE input.options().dtype(at::kBFloat16)); #else input.options().dtype(at::kHalf)); #endif - if (nrows == 0 || output_columns == 0) { + if (nrows == 0 || ncols == 0) { return output; } - constexpr int threads_per_block = 256; - const int blockDim_x = std::min(output_columns, threads_per_block); + constexpr size_t threads_per_block = 256; + const auto blockDim_x = std::min(output_columns, threads_per_block); dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; - const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 
65535u); + const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const auto gridDim_y = + std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); dim3 gridDim(gridDim_x, gridDim_y); _float_to_bfloat16_cuda_kernel<<< @@ -113,24 +111,25 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) { at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(input.get_device()); - const int nrows = input.size(0); - const int ncols = input.size(1); - const int output_columns = ncols; + const auto nrows = input.size(0); + const auto ncols = input.size(1); + const size_t output_columns = ncols; auto output = at::empty( - {nrows, output_columns}, // 4 = sizeof(float) + {nrows, ncols}, // 4 = sizeof(float) input.options().dtype(at::kFloat)); // at::kBytes for uint8_t - if (nrows == 0 || output_columns == 0) { + if (nrows == 0 || ncols == 0) { return output; } - constexpr int threads_per_block = 256; + constexpr size_t threads_per_block = 256; - const int blockDim_x = std::min(output_columns, threads_per_block); + const auto blockDim_x = std::min(output_columns, threads_per_block); dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; - const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); + const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const auto gridDim_y = + std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); dim3 gridDim(gridDim_x, gridDim_y); _bfloat16_to_float_cuda_kernel<<< @@ -152,14 +151,11 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) { } #define DISPATCH_TO_CUDA(name, function) \ - m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) + m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) TORCH_LIBRARY_IMPL(quantization, CUDA, m) { - DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda); - DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda); + DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda); + DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda); } -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h index 90bfc083b39d..f865599595d3 100644 --- a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h @@ -8,15 +8,9 @@ #include #include -namespace torch { -namespace distributed { -namespace c10d { -namespace quantization { +namespace torch::distributed::c10d::quantization { at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input); at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input); -} // namespace quantization -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d::quantization diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index d600426192ce..6a2812ab24b9 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -2366,10 +2366,18 @@ void Reducer::reset_state() { // Ensure forward can run despite previous backward not succeeding. 
expect_autograd_hooks_ = false; require_finalize_ = false; + first_autograd_hook_called_ = false; // Unset allreduce division factor, as it may change in next backwards pass // when running with DDP join mode. div_factor_ = kUnsetDivFactor; + + // Reset unused parameter accounting. + // See Note [local_used_map_ -> local_used_map_dev copying] + if (find_unused_parameters_) { + local_used_map_.zero_(); + local_used_map_reduced_ = false; + } } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/sequence_num.cpp b/torch/csrc/distributed/c10d/sequence_num.cpp index 6ea35820179e..fd76247199f6 100644 --- a/torch/csrc/distributed/c10d/sequence_num.cpp +++ b/torch/csrc/distributed/c10d/sequence_num.cpp @@ -1,11 +1,10 @@ #include -#include #include #include namespace c10d { -SequenceNum::SequenceNum() : num_(c10::nullopt) {} +SequenceNum::SequenceNum() = default; SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index 093a47a076b0..6cbaa018762e 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -670,7 +670,7 @@ class SocketConnectOp { static const std::chrono::seconds delay_duration_; - enum class ConnectResult { Success, Error, Retry }; + enum class ConnectResult : uint8_t { Success, Error, Retry }; public: SocketConnectOp( diff --git a/torch/csrc/distributed/rpc/agent_utils.cpp b/torch/csrc/distributed/rpc/agent_utils.cpp index 8eaae18cb209..89cb878755d9 100644 --- a/torch/csrc/distributed/rpc/agent_utils.cpp +++ b/torch/csrc/distributed/rpc/agent_utils.cpp @@ -13,7 +13,7 @@ std::unordered_map collectNames( std::vector selfNameVector( (uint8_t*)selfName.c_str(), (uint8_t*)selfName.c_str() + selfName.length()); - store.set(c10::to_string(selfId), selfNameVector); + store.set(std::to_string(selfId), selfNameVector); std::unordered_map nameToId; nameToId.reserve(worldSize); @@ -22,7 +22,7 @@ std::unordered_map collectNames( if (workerId == selfId) { continue; } - std::vector workerNameVector = store.get(c10::to_string(workerId)); + std::vector workerNameVector = store.get(std::to_string(workerId)); std::string workerName( (char*)workerNameVector.data(), workerNameVector.size()); @@ -69,7 +69,7 @@ std::unordered_map collectCurrentNames( // Check that ID does not already exist and set {ID : NAME} std::vector resultVector = store.compareSet( - c10::to_string(selfId), std::vector(), selfNameVector); + std::to_string(selfId), std::vector(), selfNameVector); TORCH_CHECK( resultVector == selfNameVector, "RPC worker id ", @@ -80,7 +80,7 @@ std::unordered_map collectCurrentNames( selfNameVector, " cannot be added."); - store.set(c10::to_string(selfId), selfNameVector); + store.set(std::to_string(selfId), selfNameVector); std::unordered_map nameToId; nameToId.emplace(selfName, selfId); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index 73b66f954541..bba751e08917 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -143,10 +143,10 @@ std::unordered_map RRefContext::getDebugInfo() { numForks += owner.second.size(); } lock.unlock(); - info[kNumOwnerRRefs] = c10::to_string(ownerSize); - info[kNumPendingFutures] = c10::to_string(numPendingFutures_.load()); - info[kNumPendingUsers] = c10::to_string(numPendingUsers); - info[kNumForks] = c10::to_string(numForks); + info[kNumOwnerRRefs] = std::to_string(ownerSize); + info[kNumPendingFutures] = 
std::to_string(numPendingFutures_.load()); + info[kNumPendingUsers] = std::to_string(numPendingUsers); + info[kNumForks] = std::to_string(numForks); return info; } diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 8af4336c0746..2de6bacb7ee4 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -1279,13 +1279,13 @@ void TensorPipeAgent::updateGroupMembership( } std::unordered_map TensorPipeAgent::getMetrics() { std::unordered_map metrics; - metrics[kThreadPoolSize] = c10::to_string(threadPool_.size()); - metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable()); + metrics[kThreadPoolSize] = std::to_string(threadPool_.size()); + metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable()); { std::unique_lock lock(callCountMutex_); - metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_); - metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_); - metrics[kServerActiveAsyncCalls] = c10::to_string(serverActiveAsyncCalls_); + metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_); + metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_); + metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_); } if (isGILProfilingEnabled()) { { @@ -1295,7 +1295,7 @@ std::unordered_map TensorPipeAgent::getMetrics() { auto averageGilWaitTime = timeSeriesMetrics_[kGilAverageWaitTime].computeAverage(); lock.unlock(); - metrics[kGilAverageWaitTime] = c10::to_string(averageGilWaitTime); + metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime); } } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 822079b12ecf..bde9d1ad61ad 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -389,7 +389,7 @@ std::string wireSerialize( // out of scope of this loop. 
auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]); entries.push_back( - {c10::to_string(i), + {std::to_string(i), writeableTensorData.data(), writeableTensorData.sizeInBytes()}); } @@ -401,7 +401,7 @@ std::string wireSerialize( tot += e.size; header.append(e.name) .append(" ") - .append(c10::to_string(e.size)) + .append(std::to_string(e.size)) .append("\n"); } header.push_back('\n'); diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index c3321b244735..d2eb41f51115 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2,8 +2,11 @@ #include #include #include +#include #include +#include #include +#include #include #include #include @@ -742,6 +745,27 @@ static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } +static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) { + HANDLE_TH_ERRORS; + static PythonArgParser parser( + {"_reinterpret_tensor(Tensor base, IntArrayRef sizes, IntArrayRef strides, int64_t offset_increment=0)"}, + /*traceable=*/true); + + ParsedArgs<4> parsed_args; + auto r = parser.parse(args, /*kwargs=*/nullptr, parsed_args); + + Tensor self = r.tensor(0); + auto sizes = r.intlist(1); + auto strides = r.intlist(2); + auto offset_increment = r.toInt64(3); + + auto res = torch::inductor::_reinterpret_tensor( + self, sizes, strides, offset_increment); + return torch::autograd::utils::wrap(res); + + END_HANDLE_TH_ERRORS; +} + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) static PyMethodDef _methods[] = { {"check_type_id", check_type_id, METH_VARARGS, nullptr}, @@ -750,6 +774,7 @@ static PyMethodDef _methods[] = { {"dict_version", dict_version, METH_VARARGS, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, + {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { @@ -3222,13 +3247,13 @@ void install_tensor_aliasing_guard( void install_no_tensor_aliasing_guard( const py::list& guard_managers, - py::list tensor_names, + const py::list& tensor_names, py::object verbose_code_parts) { // Adds a guard that checks none of tensors alias. This is a an example of // relational guard. There is one guard object that is shared between multiple // guard managers. std::shared_ptr guard = std::make_shared( - std::move(tensor_names), std::move(verbose_code_parts)); + tensor_names, std::move(verbose_code_parts)); // Register the resetter on the toor guard mananger, so that it can reset // the newly added relational guard when the guard eval fails. 
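Looking back at the control_plane registry added earlier in this patch (Handlers.hpp/Handlers.cpp, plus the dump_nccl_trace_pickle registration in ProcessGroupNCCL.cpp): new endpoints are added through static RegisterHandler objects, mirroring the built-in "ping" handler. Below is a minimal sketch of a custom handler, hypothetical and not part of the patch, using only the Request/Response interfaces shown in those files; once a _WorkerServer is running it would be reachable via POST /handler/echo_body:

#include <string>

#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

namespace {

// Echo the request body back to the caller as plain text.
c10d::control_plane::RegisterHandler echoHandler{
    "echo_body",
    [](const c10d::control_plane::Request& req,
       c10d::control_plane::Response& res) {
      res.setStatus(200);
      res.setContent(std::string(req.body()), "text/plain");
    }};

} // namespace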
@@ -3981,7 +4006,15 @@ PyObject* torch_c_dynamo_guards_init() { DictSubclassGuardManager, DictGuardManager, std::unique_ptr>( - py_m, "DictSubclassGuardManager"); // NOLINT + py_m, "DictSubclassGuardManager") // NOLINT + .def( + "add_no_hasattr_guard", + [](DictSubclassGuardManager& self, + py::object attr_name, + py::object verbose_code_parts) -> void { + self.add_permitted_leaf_guard(std::make_shared( + std::move(attr_name), std::move(verbose_code_parts))); + }); py_m.def("install_tensor_aliasing_guard", install_tensor_aliasing_guard); py_m.def( diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 3a79a7bc6372..2e5cb3bfab02 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -186,6 +186,7 @@ struct CacheNode { next.clear(); key_storage.clear(); expected_sizes.clear(); + runtime_wrapper = nullptr; compiled_fn = nullptr; } @@ -193,10 +194,12 @@ struct CacheNode { return next.empty() && !compiled_fn; } - CacheNode() : compiled_fn(nullptr) {} + CacheNode() : runtime_wrapper(nullptr), compiled_fn(nullptr) {} ~CacheNode() { if (!Py_IsInitialized()) { - compiled_fn.release(); // leak on shutdown + // leak on shutdown + runtime_wrapper.release(); + compiled_fn.release(); } } CacheNode(CacheNode&&) = delete; @@ -250,6 +253,7 @@ struct CacheNode { if (!cache_hit) { // we missed cache because static size inputs didn't match; force // recompilation with the varying size input as dynamic + runtime_wrapper = nullptr; compiled_fn = nullptr; } return cache_hit; @@ -298,6 +302,7 @@ struct CacheNode { std::vector key_storage; std::vector expected_sizes; + THPObjectPtr runtime_wrapper; THPObjectPtr compiled_fn; }; @@ -591,12 +596,22 @@ CacheNode* _compiled_autograd_impl( } } - cache->compiled_fn = check(call_end_capture(py_compiler, state.outputs)); + PyObject* res = check(call_end_capture(py_compiler, state.outputs)); + TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); + TORCH_CHECK( + PyTuple_Size(res) == 2, + "Expected end_capture to return tuple of size 2"); + cache->runtime_wrapper = Py_NewRef(PyTuple_GetItem(res, 0)); + TORCH_CHECK( + PyCallable_Check(cache->runtime_wrapper), + "Expected end_capture to return runtime_wrapper"); + cache->compiled_fn = Py_NewRef(PyTuple_GetItem(res, 1)); + TORCH_CHECK( + PyCallable_Check(cache->compiled_fn), + "Expected end_capture to return compiled_fn"); state.debug_asserts(); } // End cache miss region - // TODO(jansel): we should release all the variables and then use a - // boxed calling convention so activation memory can be freed // TODO(jansel): clear grads we will overwrite below if (!graph_task.keep_graph_) { for (auto& call : calls) { @@ -615,9 +630,6 @@ variable_list compiled_autograd( GraphTask& graph_task, bool accumulate_grad, const edge_list& output_edges) { - TORCH_CHECK( - output_edges.empty() || !accumulate_grad, - "specifying inputs= with .backward() not yet implemented for compiled autograd") TORCH_CHECK( c10::impl::TorchDispatchModeTLS::stack_len() == 0, "TorchDispatchMode not yet implemented for compiled autograd") @@ -639,7 +651,12 @@ variable_list compiled_autograd( &hooks); THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs( - cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL))); + cache->runtime_wrapper.get(), + cache->compiled_fn.get(), + inputs.get(), + sizes.get(), + hooks.get(), + NULL))); variable_list outputs = THPVariable_UnpackList(pyresult); 
TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size()); return outputs; diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp new file mode 100644 index 000000000000..dc96737abdab --- /dev/null +++ b/torch/csrc/fx/node.cpp @@ -0,0 +1,257 @@ +#include + +#include +#include + +//////////////////////////////// +// NodeBase +/////////////////////////////// + +struct NodeBase { + PyObject_HEAD bool _erased; + NodeBase* _prev; + NodeBase* _next; +}; + +static PyObject* NodeBase_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + if (!self) + return nullptr; + return self; +} + +static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { + self->_erased = false; + Py_INCREF(self); + self->_prev = self; + Py_INCREF(self); + self->_next = self; + return 0; +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +static struct PyMemberDef NodeBase_members[] = { + {"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr}, + {"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr}, + {"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr}, + {nullptr} /* Sentinel */ +}; + +static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { + Py_VISIT(self->_prev); + Py_VISIT(self->_next); + return 0; +} + +static int NodeBase_clear(NodeBase* self) { + Py_CLEAR(self->_prev); + Py_CLEAR(self->_next); + return 0; +} + +static void NodeBase_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + (void)NodeBase_clear((NodeBase*)self); + Py_TYPE(self)->tp_free(self); +} + +static PyTypeObject NodeBaseType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ + sizeof(NodeBase), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)NodeBase_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_HAVE_GC, /* tp_flags */ + nullptr, /* tp_doc */ + (traverseproc)NodeBase_traverse, /* tp_traverse */ + (inquiry)NodeBase_clear, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + NodeBase_members, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)NodeBase_init_fn, /* tp_init */ + nullptr, /* tp_alloc */ + NodeBase_new, /* tp_new */ +}; + +bool NodeBase_init(PyObject* module) { + if (PyModule_AddType(module, &NodeBaseType) < 0) { + return false; + } + return true; +} + +//////////////////////////////// +// NodeIter +//////////////////////////////// + +struct NodeIter { + PyObject_HEAD bool _reversed; + NodeBase* _root; + NodeBase* _cur; +}; + +static PyObject* NodeIter_new( + PyTypeObject* type, + PyObject* args, + PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + if (!self) + return nullptr; + return self; +} + +static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) { + NodeBase* root = nullptr; + bool reversed = false; + // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + constexpr const char* keywords[] = {"root", "reversed", nullptr}; + if (!PyArg_ParseTupleAndKeywords( + args, + kwargs, + "Ob|", + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(keywords), + &root, + &reversed)) { + return -1; + } + self->_reversed = reversed; + Py_INCREF(root); + self->_root = root; + Py_INCREF(root); + self->_cur = root; + return 0; +} + +template +PyObject* NodeIter_iternext_helper(NodeIter* self) { + // It should be possible to relax the ref counting here + // but in practice, we do not have that many _erased Nodes, + // so probably not worth it. + if constexpr (reversed) { + NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); + Py_CLEAR(self->_cur); + self->_cur = prev; + } else { + NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); + Py_CLEAR(self->_cur); + self->_cur = next; + } + while (self->_cur != self->_root) { + if (!self->_cur->_erased) { + Py_INCREF(self->_cur); + return (PyObject*)self->_cur; + } + if constexpr (reversed) { + NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev); + Py_CLEAR(self->_cur); + self->_cur = prev; + } else { + NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next); + Py_CLEAR(self->_cur); + self->_cur = next; + } + } + PyErr_SetNone(PyExc_StopIteration); + return nullptr; +} + +PyObject* NodeIter_iternext(PyObject* _self) { + NodeIter* self = (NodeIter*)_self; + if (self->_reversed) { + return NodeIter_iternext_helper(self); + } else { + return NodeIter_iternext_helper(self); + } +} + +static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) { + Py_VISIT(self->_root); + Py_VISIT(self->_cur); + return 0; +} + +static int NodeIter_clear(NodeIter* self) { + Py_CLEAR(self->_root); + Py_CLEAR(self->_cur); + return 0; +} + +static void NodeIter_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + (void)NodeIter_clear((NodeIter*)self); + Py_TYPE(self)->tp_free(self); +} + +static PyTypeObject NodeIterType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */ + sizeof(NodeIter), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)NodeIter_dealloc, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ + nullptr, /* tp_doc */ + (traverseproc)NodeIter_traverse, /* tp_traverse */ + (inquiry)NodeIter_clear, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + NodeIter_iternext, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)NodeIter_init_fn, /* tp_init */ + nullptr, /* tp_alloc */ + NodeIter_new, /* tp_new */ +}; + +bool NodeIter_init(PyObject* module) { + if (PyModule_AddType(module, &NodeIterType) < 0) { + return false; + } + return true; +} diff --git a/torch/csrc/fx/node.h b/torch/csrc/fx/node.h new file mode 100644 index 000000000000..2ea74e839f25 --- /dev/null +++ b/torch/csrc/fx/node.h @@ -0,0 +1,6 @@ +#pragma 
once + +#include + +bool NodeBase_init(PyObject* module); +bool NodeIter_init(PyObject* module); diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index ba716e213a0f..65fbbd9fc23d 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -475,7 +475,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( AtenTensorHandle* out); AOTI_TORCH_EXPORT AOTITorchError -aoti_check_inf_and_nan(AtenTensorHandle tensor); +aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index a96cdaee5eb3..c973f69cb69d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -19,8 +19,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__addmm_activation(AtenTensorHan AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cudnn_rnn(AtenTensorHandle input, const AtenTensorHandle* weight, int64_t weight_len_, int64_t weight_stride0, AtenTensorHandle* weight_buf, AtenTensorHandle hx, AtenTensorHandle* cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, int32_t batch_first, double dropout, int32_t train, int32_t bidirectional, const int64_t* batch_sizes, int64_t batch_sizes_len_, AtenTensorHandle* dropout_state, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* causal_diagonal, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t 
max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, int32_t shared_storage_dqdkdv, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 6f93407aa467..1306c006ba94 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -770,17 +770,13 @@ AOTITorchError aoti_torch_repeat_interleave_Tensor( } // Function to check existence of inf and NaN -AOTITorchError aoti_check_inf_and_nan(AtenTensorHandle tensor) { +AOTITorchError aoti_torch_check_inf_and_nan( + const char* tensor_name, + AtenTensorHandle tensor) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* check_tensor = tensor_handle_to_tensor_pointer(tensor); - auto flattened = check_tensor->view({-1}); - for (int64_t i = 0; i < flattened.numel(); i++) { - auto value = flattened[i].item(); - if (std::isinf(value) || std::isnan(value)) { - assert(false); - } - } + assert_inf_and_nan(tensor_name, *check_tensor); }); } diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 44ca34b1c6e8..6e7bd355c57c 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -48,6 +48,21 @@ inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) { return tensor_pointer_to_tensor_handle(new_tensor); } +inline void assert_inf_and_nan( + const std::string& tensor_name, + at::Tensor& check_tensor) { + auto flattened = check_tensor.view({-1}); + + for (int64_t i = 0; i < flattened.numel(); i++) { + auto value = flattened[i].item(); + if (std::isinf(value)) { + throw std::runtime_error("At least one INF in " + tensor_name); + } else if (std::isnan(value)) { + 
throw std::runtime_error("At least one NaN in " + tensor_name); + } + } +} + // utility functions to convert a pointer to an optional value template inline std::optional pointer_to_optional(T* ptr) { diff --git a/torch/csrc/jit/api/function_impl.cpp b/torch/csrc/jit/api/function_impl.cpp index c0f0b4e486b4..5f25ce51702a 100644 --- a/torch/csrc/jit/api/function_impl.cpp +++ b/torch/csrc/jit/api/function_impl.cpp @@ -28,7 +28,7 @@ c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) { for (const auto i : c10::irange(num_inputs)) { const Value* v = g.inputs().at(i); std::string name = v->hasDebugName() ? v->debugNameBase() - : ("argument_" + c10::to_string(i)); + : ("argument_" + std::to_string(i)); args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); } for (const auto i : c10::irange(g.outputs().size())) { diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index e779542e315f..92b9c96c3a6e 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -541,9 +541,7 @@ struct slot_list_impl { size_t size() const { if (!size_) { size_ = size_t(0); - // NOLINTNEXTLINE(clang-diagnostic-unused-variable) - for (const value_type& s : *(this)) { - (void)s; // Suppress unused variable warning + for ([[maybe_unused]] const value_type& _ : *(this)) { ++*size_; } } diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 2f9217e13369..a2d26979c1e0 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -30,15 +30,15 @@ size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes}; )"); static std::string valueName(const Value* n) { - return "n" + c10::to_string(n->unique()); + return "n" + std::to_string(n->unique()); } static std::string scalarValue(const int64_t v) { - return c10::to_string(v); + return std::to_string(v); } static std::string scalarValue(const bool v) { - return c10::to_string(v); + return std::to_string(v); } // Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific @@ -274,10 +274,10 @@ static std::string encodeRHS(const Node* n) { // PyTorch converts (scalar) argument types to result before applying the // operator e.g. 
1.4-torch.tensor(3) = -2 env.s( - c10::to_string(i), + std::to_string(i), typeCastedValueName(*in->type(), *outtype, valueName(in))); // Uncasted operands only used for comparison operators - env.s(c10::to_string(i) + "_nocast", valueName(in)); + env.s(std::to_string(i) + "_nocast", valueName(in)); i++; } @@ -391,7 +391,7 @@ std::string generateKernel( 1); // + 1 because the first argument is the linearIndex std::string tensor = "t" + - c10::to_string( + std::to_string( formals.size()); // can't be unique() because Param may be an output const auto nDim = desc.nDim(); emitCheckFor(tensorChecks, tensor, nDim, desc); @@ -413,7 +413,7 @@ std::string generateKernel( 1); // + 1 because the first argument is the linearIndex std::string scalar = "s" + - c10::to_string( + std::to_string( formals.size()); // can't be unique() because Param may be an output env.d( "formal_index", diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 3c05b70e8341..b4bc3e8f4727 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -281,7 +281,7 @@ std::shared_ptr compileKernel( } const bool use_cuda = device.is_cuda(); - const std::string name = "kernel_" + c10::to_string(next_kernel_id++); + const std::string name = "kernel_" + std::to_string(next_kernel_id++); std::string code = generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda); const FusedKernelConstructor& kernel_ctor = diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index 13497c20e15c..ba86a891d31d 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -157,7 +156,7 @@ struct SchemaParser { // note: an array with a size hint can only occur at the Argument level fake_type = ListType::create(std::move(fake_type)); real_type = ListType::create(std::move(real_type)); - N = c10::stoll(L.expect(TK_NUMBER).text()); + N = std::stoll(L.expect(TK_NUMBER).text()); L.expect(']'); auto container = type_parser.parseAliasAnnotation(); if (alias_info) { @@ -244,14 +243,14 @@ struct SchemaParser { n = L.expect(TK_NUMBER).text(); if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) { - auto imag = c10::stod(n.substr(0, n.size() - 1)); + auto imag = std::stod(n.substr(0, n.size() - 1)); return c10::complex(0, imag); } else if ( kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) { - return c10::stod(n); + return std::stod(n); } else { - int64_t v = c10::stoll(n); + int64_t v = std::stoll(n); return v; } } diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 0aca3ea80062..350305b83567 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -722,7 +722,7 @@ struct to_ir { std::vector def_stack_; size_t temp_name_count_ = 0; std::string createTempName(const std::string& prefix) { - return prefix + c10::to_string(temp_name_count_++); + return prefix + std::to_string(temp_name_count_++); } void pushFrame(Block* b, bool starts_def = false) { @@ -3222,7 +3222,7 @@ struct to_ir { case TK_IN: return aten::__contains__; default: - throw std::runtime_error("unknown kind " + c10::to_string(kind)); + throw std::runtime_error("unknown kind " + std::to_string(kind)); } } @@ -3269,7 +3269,7 @@ struct to_ir { case 
TK_RSHIFT: return "__rshift__"; default: - throw std::runtime_error("unknown kind " + c10::to_string(kind)); + throw std::runtime_error("unknown kind " + std::to_string(kind)); } } diff --git a/torch/csrc/jit/frontend/name_mangler.cpp b/torch/csrc/jit/frontend/name_mangler.cpp index fbf1d24932e8..698bdd1e67b7 100644 --- a/torch/csrc/jit/frontend/name_mangler.cpp +++ b/torch/csrc/jit/frontend/name_mangler.cpp @@ -21,7 +21,7 @@ c10::QualifiedName NameMangler::mangle(const c10::QualifiedName& name) { // Append the part of the name up to the end of the prefix newAtomPrefix.append(atom, 0, pos); newAtomPrefix.append(manglePrefix); - atom = newAtomPrefix + c10::to_string(mangleIndex_++); + atom = newAtomPrefix + std::to_string(mangleIndex_++); // increment mangleIndex_ until the type is not defined return c10::QualifiedName(atoms); } @@ -29,7 +29,7 @@ c10::QualifiedName NameMangler::mangle(const c10::QualifiedName& name) { // Otherwise add a mangle namespace right before the basename TORCH_INTERNAL_ASSERT(!atoms.empty()); - atoms.insert(atoms.end() - 1, manglePrefix + c10::to_string(mangleIndex_++)); + atoms.insert(atoms.end() - 1, manglePrefix + std::to_string(mangleIndex_++)); return c10::QualifiedName(atoms); } diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index b81a6c720770..2adacb976a04 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index a1cc856a22e1..fef018dc7388 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -381,12 +381,12 @@ TORCH_API void ensureUniqueIfOutOfPlaced( template < typename T, - typename = torch::enable_if_t< - (!std::is_convertible_v, at::TensorList> && - !std::is_convertible_v, c10::List> && - !std::is_convertible_v, at::Tensor> && + typename = std::enable_if_t< + (!std::is_convertible_v, at::TensorList> && + !std::is_convertible_v, c10::List> && + !std::is_convertible_v, at::Tensor> && !std::is_convertible_v< - torch::decay_t, + std::decay_t, c10::intrusive_ptr>)>> void addOutput(Node* node, T&&) { AT_ERROR( diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index a6488c92f406..77d06bee94a9 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -1,5 +1,4 @@ #pragma once -#include #include #include #include @@ -1032,7 +1031,7 @@ struct SliceExpr : public Expr { private: Expr createInt(int64_t value) const { - return Expr(Const::create(range(), c10::to_string(value))); + return Expr(Const::create(range(), std::to_string(value))); } }; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index c39ceb7e91f9..a6b0116d7fb6 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -128,12 +128,6 @@ static std::ostream& operator<<( return printValueRefs(out, nodes); } -static std::ostream& operator<<( - std::ostream& out, - const at::ArrayRef nodes) { - return printValueRefs(out, nodes); -} - struct const_value_list_with_types { const ArrayRef values; std::string delim; diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 549f4a11001f..859da3cb3cae 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -224,7 +224,7 @@ struct Value { if (hasDebugName()) { return unique_name_; } - return c10::to_string(unique()); + return 
std::to_string(unique()); } TORCH_API std::string debugNameBase() const; Node* node() { diff --git a/torch/csrc/jit/ir/named_value.h b/torch/csrc/jit/ir/named_value.h index 277e7f269969..a594b4d045e9 100644 --- a/torch/csrc/jit/ir/named_value.h +++ b/torch/csrc/jit/ir/named_value.h @@ -30,18 +30,18 @@ struct NamedValue { template < typename T, - typename = enable_if_t< - (!std::is_same, NamedValue>::value && - !std::is_same, Value*>::value && - !std::is_same, IValue>::value)>> + typename = std::enable_if_t< + (!std::is_same_v, NamedValue> && + !std::is_same_v, Value*> && + !std::is_same_v, IValue>)>> // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) NamedValue(T&& t) : NamedValue(IValue(std::forward(t))) {} template < typename T, - typename = enable_if_t< - (!std::is_same, Value*>::value && - !std::is_same, IValue>::value)>> + typename = std::enable_if_t< + (!std::is_same_v, Value*> && + !std::is_same_v, IValue>)>> NamedValue(const std::string& name, T&& t) : NamedValue(name, IValue(std::forward(t))) {} diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 09c5df58f0be..f0dd562cc1cd 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -348,7 +348,7 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) { for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); - std::string fname = prefix + c10::to_string(i++); + std::string fname = prefix + std::to_string(i++); writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); } std::string fname = archive_name + ".pkl"; diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 239deb76d267..bca407358913 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -296,6 +296,11 @@ mobile::Module FlatbufferLoader::parseModule( "Parsing flatbuffer module: Corrupted ivalues/object_types field"); TORCH_CHECK( reinterpret_cast(ivalues) < end, "Corrupted ivalues field"); + TORCH_CHECK( + module->storage_data_size() >= 0, + "Parsing flatbuffer module: illegal storage_data_size: ", + module->storage_data_size(), + ", expected to be non negative"); all_ivalues_.resize(ivalues->size()); all_types_.resize(module->object_types()->size()); storages_.resize(module->storage_data_size()); diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp index 1f7ba264048f..98638ff62e26 100644 --- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp +++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp @@ -383,6 +383,8 @@ static std::vector> generateExampleInputs( return example_inputs; } +// TODO(mvz): temporarily disable NNC backend in mobile builds. +/* static c10::IValue preprocess( const torch::jit::Module& mod, const c10::Dict& compile_spec, @@ -440,8 +442,8 @@ static c10::IValue preprocess( } return cu.serialize(); } +*/ -// TODO(mvz): temporarily disable NNC backend in mobile builds. 
// static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess); } // namespace nnc diff --git a/torch/csrc/jit/mobile/train/export_data.cpp b/torch/csrc/jit/mobile/train/export_data.cpp index 731ffef15424..aeb9f95dad67 100644 --- a/torch/csrc/jit/mobile/train/export_data.cpp +++ b/torch/csrc/jit/mobile/train/export_data.cpp @@ -61,7 +61,7 @@ class IValuePickler final { std::string prefix = archive_name + "/"; for (const auto& td : data_pickle.tensorData()) { WriteableTensorData writable_td = getWriteableTensorData(td); - std::string fname = prefix + c10::to_string(i++); + std::string fname = prefix + std::to_string(i++); writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); } std::string fname = archive_name + ".pkl"; diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 052ba45ceb40..7fac68aec4d7 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -464,18 +464,6 @@ static void BatchMMSide(Block* block, AliasDb& alias_db) { } } -static bool hasMutableOperators(Block* block) { - for (auto n : block->nodes()) { - if (n->kind().is_aten() && n->schema().is_mutable()) - return true; - for (auto b : n->blocks()) { - if (hasMutableOperators(b)) - return true; - } - } - return false; -} - static bool hasMMOperators(std::shared_ptr& graph) { DepthFirstGraphNodeIterator it(graph); Node* n = nullptr; diff --git a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp index 6f1aa4aee308..b4c0fd053511 100644 --- a/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp +++ b/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp @@ -388,7 +388,7 @@ std::string mangleMethodName( for (size_t method_idx = 0;; method_idx++) { auto mangled = method_name; if (method_idx != 0) { - mangled += c10::to_string(method_idx); + mangled += std::to_string(method_idx); } bool found = false; for (Function* fn : mod_type->methods()) { diff --git a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp index c3db2373f2a3..5034626923b5 100644 --- a/torch/csrc/jit/passes/hoist_conv_packed_params.cpp +++ b/torch/csrc/jit/passes/hoist_conv_packed_params.cpp @@ -64,10 +64,10 @@ static void hoistConvPackedParams( } std::string newNameBase = prefix + "." + suffix + "_packed_params"; nameUniqueCounter++; - std::string newName = newNameBase + "." + c10::to_string(nameUniqueCounter); + std::string newName = newNameBase + "." + std::to_string(nameUniqueCounter); while (rootModule.hasattr(newName)) { nameUniqueCounter++; - newName = newNameBase + "." + c10::to_string(nameUniqueCounter); + newName = newNameBase + "." + std::to_string(nameUniqueCounter); } // copy the packed params diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 73c19851e569..b468e739a03f 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -710,7 +710,7 @@ static void eraseListUnpack(Node* n, int opset_version) { // onnx::SequenceAt was introduced in onnx opset version 11 throw std::runtime_error( "Unsupported: ONNX export of prim::ListUnpack in opset " + - c10::to_string(opset_version) + ". Please try opset version 11."); + std::to_string(opset_version) + ". 
Please try opset version 11."); } auto g = n->owningGraph(); diff --git a/torch/csrc/jit/passes/prepack_folding.cpp b/torch/csrc/jit/passes/prepack_folding.cpp index 1c7372e23633..d37201c5b3d5 100644 --- a/torch/csrc/jit/passes/prepack_folding.cpp +++ b/torch/csrc/jit/passes/prepack_folding.cpp @@ -30,7 +30,7 @@ void PrePackingOpsFolder( if (optional_outputs) { auto outputs = optional_outputs.value(); TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output"); - auto attr_name = attr_name_base + c10::to_string(uid++); + auto attr_name = attr_name_base + std::to_string(uid++); TORCH_CHECK( !(m.type()->findAttributeSlot(attr_name)), "Attribute name ", diff --git a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp index 65e900d3888a..2c83bcbc10e1 100644 --- a/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp +++ b/torch/csrc/jit/passes/quantization/dedup_module_uses.cpp @@ -97,9 +97,9 @@ class ModuleUseDeduper { // Original name of the child module const std::string& original_name = path[path.size() - 1]; int uid = 0; - std::string child_name = original_name + "_" + c10::to_string(uid++); + std::string child_name = original_name + "_" + std::to_string(uid++); while (parent_of_leaf.hasattr(child_name)) { - child_name = original_name + "_" + c10::to_string(uid++); + child_name = original_name + "_" + std::to_string(uid++); } parent_of_leaf.register_module(child_name, child_module.deepcopy()); return child_name; diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index de1cff1ba9d1..145448210958 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -953,9 +953,9 @@ void InsertObserversHelper::insertObserverFor( } GRAPH_DEBUG("Inserting observer for:", v->debugName()); Module observer = observer_module.deepcopy(); - std::string observer_name = "_observer_" + c10::to_string(uid_++); + std::string observer_name = "_observer_" + std::to_string(uid_++); while (module.hasattr(observer_name)) { - observer_name = "_observer_" + c10::to_string(uid_++); + observer_name = "_observer_" + std::to_string(uid_++); } module.register_module(observer_name, observer); observer_name_and_modules.emplace_back(observer_name, observer); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 02f4f1096976..92fb2fc79bcc 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1042,10 +1042,10 @@ void InsertQuantDeQuantHelper::quantizeTensors( const auto& qparam = pr.second; size_t uid = 0; auto qparam_name = - original_value->debugName() + name + "_" + c10::to_string(uid++); + original_value->debugName() + name + "_" + std::to_string(uid++); while (module.hasattr(qparam_name)) { qparam_name = - original_value->debugName() + name + "_" + c10::to_string(uid++); + original_value->debugName() + name + "_" + std::to_string(uid++); } qparam_name_map_for_node_[n][name] = qparam_name; module.register_attribute(qparam_name, qparam.type(), qparam); diff --git a/torch/csrc/jit/passes/quantization/register_packed_params.cpp b/torch/csrc/jit/passes/quantization/register_packed_params.cpp index bd93c6535e61..1d7dcfe72eea 100644 --- a/torch/csrc/jit/passes/quantization/register_packed_params.cpp +++ 
b/torch/csrc/jit/passes/quantization/register_packed_params.cpp @@ -73,13 +73,13 @@ std::unordered_set RegisterPrePackParams( WithInsertPoint ins(n->next()); Value* packed_param_value = n->output(0); TORCH_CHECK(n->outputs().size() == 1, "Prepack ops have single output"); - auto attr_name = attr_name_base + c10::to_string(uid++); + auto attr_name = attr_name_base + std::to_string(uid++); TORCH_CHECK( packed_param_value->uses().size() == 1, "Packed param must be used by exactly one op."); auto use = packed_param_value->uses()[0]; while (m.hasattr(attr_name)) { - attr_name = attr_name_base + "_" + c10::to_string(uid++); + attr_name = attr_name_base + "_" + std::to_string(uid++); } // Now register attribute for this packed param but dont set it to any // value. No value because we dont know what the value is at this point. diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 1bb82432e218..377621c04b6d 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -606,7 +606,7 @@ static std::string truncateStrWithHash(const std::string& s, size_t maxlen) { if (s.size() <= maxlen) { return s; } - std::string hash_str = c10::to_string(c10::hash{}(s)); + std::string hash_str = std::to_string(c10::hash{}(s)); // If hash-string plus '_' can fit into maxlen, then truncate the original // string correspondingly so that the final string with the hash included fits // into maxlen. If that's not possible, at least truncate the original string diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7ce337f9ac8..818f09bee7bc 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) + SYMNODE_BINARY(int_truediv) + SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(float_pow) + SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) @@ -1774,13 +1779,18 @@ void initJITBindings(PyObject* module) { [](py::handle op_overload_packet, py::args args, py::kwargs kwargs) { py::list ns_method = op_overload_packet.attr("_qualified_op_name").attr("split")("::"); - return _maybe_handle_torch_function( + auto res = _maybe_handle_torch_function( py::cast(ns_method[0]), py::cast(ns_method[1]), "", false, args, kwargs); + if (res) { + return py::make_tuple(true, *res); + } else { + return py::make_tuple(false, py::none()); + } }); m.def( diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 4cfe3309a766..a731640223c0 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -13,6 +13,7 @@ #include #include +#include namespace torch::jit { @@ -816,7 +817,7 @@ py::object invokeOperatorFromPython( return createPyObjectForStack(std::move(stack)); } -py::tuple _maybe_handle_torch_function( +std::optional _maybe_handle_torch_function( const std::string& ns, const std::string& method_name, const std::string& overload_name, @@ -861,18 +862,16 @@ py::tuple _maybe_handle_torch_function( } std::string module_name("torch.ops"); module_name.append(ns); - return py::make_tuple( - true, - pybind11::reinterpret_steal( - handle_torch_function_no_python_arg_parser( - overloaded_args, - args.ptr(), - kwargs.ptr(), - method_name.c_str(), - self_func.ptr(), - 
module_name.c_str()))); + return {pybind11::reinterpret_steal( + handle_torch_function_no_python_arg_parser( + overloaded_args, + args.ptr(), + kwargs.ptr(), + method_name.c_str(), + self_func.ptr(), + module_name.c_str()))}; } - return py::make_tuple(false, py::none()); + return std::nullopt; } py::object _get_operation_for_overload_or_packet( @@ -887,9 +886,9 @@ py::object _get_operation_for_overload_or_packet( std::string overload_name = operations[0]->schema().overload_name(); auto res = _maybe_handle_torch_function( ns, method_name, overload_name, is_overload, args, kwargs); - auto torch_function_called = py::cast(res[0]); + auto torch_function_called = res.has_value(); return torch_function_called - ? res[1] + ? *res : invokeOperatorFromPython(operations, args, kwargs, dk); } diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 242da11af7c0..23fda5b0d784 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -1257,7 +1257,7 @@ TORCH_PYTHON_API py::object invokeOperatorFromPython( const py::kwargs& kwargs, std::optional dk = c10::nullopt); -TORCH_PYTHON_API py::tuple _maybe_handle_torch_function( +TORCH_PYTHON_API std::optional _maybe_handle_torch_function( const std::string& ns, const std::string& method_name, const std::string& overload_name, diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 55cde36c0e62..bf9e516566e5 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -1,3 +1,4 @@ +#include #include #include diff --git a/torch/csrc/jit/python/python_custom_class.h b/torch/csrc/jit/python/python_custom_class.h index d7cff488f273..1033fc008f27 100644 --- a/torch/csrc/jit/python/python_custom_class.h +++ b/torch/csrc/jit/python/python_custom_class.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include diff --git a/torch/csrc/jit/runtime/logging.h b/torch/csrc/jit/runtime/logging.h index b0b67c680883..fda364e0a923 100644 --- a/torch/csrc/jit/runtime/logging.h +++ b/torch/csrc/jit/runtime/logging.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index 7335f132dfbf..a057367af81c 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -133,7 +133,7 @@ void checkDoubleInRange(double a) { a > double(std::numeric_limits::max()) || a < double(std::numeric_limits::min())) { throw c10::Error( - "Cannot convert float " + c10::to_string(a) + " to integer"); + "Cannot convert float " + std::to_string(a) + " to integer"); return; } } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 15e59acb9fe6..3386bc3e4a49 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -32,7 +32,6 @@ #include #include #include -#include namespace torch::jit { constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { diff --git a/torch/csrc/jit/runtime/static/generated_ops.cpp b/torch/csrc/jit/runtime/static/generated_ops.cpp index af61ee72a00e..4597e1298cd6 100644 --- a/torch/csrc/jit/runtime/static/generated_ops.cpp +++ b/torch/csrc/jit/runtime/static/generated_ops.cpp @@ -36,7 +36,8 @@ #include #include -namespace torch::jit { +namespace torch { +namespace jit { REGISTER_OPERATOR_FUNCTOR( aten::absolute, @@ 
-190,6 +191,29 @@ REGISTER_OPERATOR_FUNCTOR(aten::addr, aten_addr, [](Node* n) -> SROperator { return nullptr; }); +REGISTER_OPERATOR_FUNCTOR( + aten::_test_functorch_fallback, + aten__test_functorch_fallback, + [](Node* n) -> SROperator { + if (n->matches(torch::schema( + "aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto& other = p_node->Input(1).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = + at::native::_test_functorch_fallback(self, other); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::_test_functorch_fallback_out(self, other, out); + }; + } + LogAndDumpSchema(n); + return nullptr; + }); + REGISTER_OPERATOR_FUNCTOR(aten::argmax, aten_argmax, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"))) { @@ -2430,6 +2454,25 @@ REGISTER_OPERATOR_FUNCTOR(aten::addbmm, aten_addbmm, [](Node* n) -> SROperator { return nullptr; }); +REGISTER_OPERATOR_FUNCTOR(aten::diag, aten_diag, [](Node* n) -> SROperator { + if (n->matches( + torch::schema("aten::diag(Tensor self, int diagonal=0) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto diagonal = p_node->Input(1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::diag(self, diagonal); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::diag_out(self, diagonal, out); + }; + } + LogAndDumpSchema(n); + return nullptr; +}); + REGISTER_OPERATOR_FUNCTOR(aten::cross, aten_cross, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"))) { @@ -2684,6 +2727,30 @@ REGISTER_OPERATOR_FUNCTOR( return nullptr; }); +REGISTER_OPERATOR_FUNCTOR( + aten::nonzero_static, + aten_nonzero_static, + [](Node* n) -> SROperator { + if (n->matches(torch::schema( + "aten::nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor"))) { + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto size = p_node->Input(1).toInt(); + const auto fill_value = p_node->Input(2).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = + at::native::nonzero_static_cpu(self, size, fill_value); + return; + } + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + at::native::nonzero_static_out_cpu(self, size, fill_value, out); + }; + } + LogAndDumpSchema(n); + return nullptr; + }); + REGISTER_OPERATOR_FUNCTOR(aten::gather, aten_gather, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"))) { @@ -4463,132 +4530,6 @@ REGISTER_OPERATOR_FUNCTOR( return nullptr; }); -REGISTER_OPERATOR_FUNCTOR(aten::fft_fft, aten_fft_fft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? 
norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_fft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_fft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_ifft, aten_fft_ifft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_ifft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_ifft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_rfft, aten_fft_rfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_rfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_rfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_irfft, aten_fft_irfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_irfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_irfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_hfft, aten_fft_hfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? 
norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_hfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_hfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR(aten::fft_ihfft, aten_fft_ihfft, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto n = p_node->Input(1).toOptional(); - const auto dim = p_node->Input(2).toInt(); - const auto norm = p_node->Input(3).toOptional(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::fft_ihfft_symint(self, n, dim, norm); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::fft_ihfft_symint_out(self, n, dim, norm, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - REGISTER_OPERATOR_FUNCTOR( aten::linalg_cross, aten_linalg_cross, @@ -5281,4 +5222,5 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( return nullptr; }); -} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 193675672f6b..9dc31446d1e1 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -1871,8 +1870,8 @@ bool BlockRunner::check_for_memory_leak( // `BlockRunner::deallocateOutputTensors`. 
continue; } - const std::string error_msg = "Output " + c10::to_string(i) + ", %" + - val->debugName() + " of node " + c10::to_string(n) + + const std::string error_msg = "Output " + std::to_string(i) + ", %" + + val->debugName() + " of node " + std::to_string(n) + " which has kind " + pnode.node()->kind().toQualString() + " was not cleaned up"; if (output_ivalues.count(ival) == 0) { @@ -1948,8 +1947,8 @@ bool BlockRunner::checkOutputTensorMemoryLeaks() { const auto& t = ival->toTensor(); if (t.defined()) { auto* storage_impl = t.storage().unsafeGetStorageImpl(); - const std::string error_msg = "Output " + c10::to_string(i) + ", %" + - val->debugName() + " of node " + c10::to_string(n) + + const std::string error_msg = "Output " + std::to_string(i) + ", %" + + val->debugName() + " of node " + std::to_string(n) + " was not cleaned up"; TORCH_CHECK(storage_impl->data() == nullptr, error_msg); } diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 1b32600426ca..1237120cc806 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -321,7 +321,7 @@ std::string BlockCodeGen::GetUniqueFuncName(const std::string& func_prefix) { static int64_t counter = 0; ++counter; int64_t value = counter; - return func_prefix + "_" + c10::to_string(value); + return func_prefix + "_" + std::to_string(value); } void BlockCodeGen::Initialize() { diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index d0b9abaa1fa6..5666097f2dd4 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1178,7 +1178,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kIsNan: return std::isnan(v); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } @@ -1198,7 +1198,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } default: throw std::runtime_error( - "Invalid integral op_type: " + c10::to_string(op_type)); + "Invalid integral op_type: " + std::to_string(op_type)); } } @@ -1208,7 +1208,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kIsNan: return std::isnan(v); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } @@ -1224,7 +1224,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case kAtan2: return std::atan2(v1, v2); default: - throw std::runtime_error("Invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("Invalid op_type: " + std::to_string(op_type)); } } diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 9bbea1bd28a4..0959151fb734 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index cea5170afcfe..889eeafc028f 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -175,7 +175,7 @@ int Intrinsics::OpArgCount(IntrinsicsOp op_type) { case kRemainder: return 2; default: - throw std::runtime_error("invalid op_type: " + c10::to_string(op_type)); + throw std::runtime_error("invalid op_type: " + std::to_string(op_type)); } } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index f35bafb332ea..89c3f96aba6e 100644 --- 
a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -827,7 +826,7 @@ class TORCH_API Intrinsics : public ExprNode { return "isnan"; default: throw std::runtime_error( - "invalid op_type: " + c10::to_string(op_type())); + "invalid op_type: " + std::to_string(op_type())); } } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 50578a041457..d18a3d65f21e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -885,7 +884,7 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { inner1->set_gpu_thread_index(0); } else { throw std::runtime_error( - "Invalid loop-level: " + c10::to_string(loopLevels)); + "Invalid loop-level: " + std::to_string(loopLevels)); } } } @@ -953,7 +952,7 @@ std::string TensorExprKernel::getCodeGenName(BackendType backendType) { default: throw std::runtime_error( "invalid backend type: " + - c10::to_string(static_cast(backendType))); + std::to_string(static_cast(backendType))); } } @@ -1190,7 +1189,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { ToDtype(static_cast(*tt->scalarType()))); result = Compute( - "input" + c10::to_string(bufs_.size() + 1), + "input" + std::to_string(bufs_.size() + 1), size_handles, [&](const std::vector& axes) { ExprHandle idx = 0; diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 1b08286fbd9f..62a67af7fb14 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include @@ -3140,7 +3139,7 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) { for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' temp_indices[i] = - alloc(std::string("idx") + c10::to_string(i), dims[i]->dtype()); + alloc(std::string("idx") + std::to_string(i), dims[i]->dtype()); } // Prepare substitute rules for constructing the temp statement from the prod diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index 70991f6db1f4..938cab6ffd88 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -576,7 +576,7 @@ static Tensor computeCatWoConditionals( std::vector store_indices(dims.size()); for (int64_t i = 0; i < static_cast(dims.size()); ++i) { for_vars[i] = alloc( - "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), + "i" + std::to_string(inp_pos) + "_" + std::to_string(i), dims[i].dtype()); load_indices[i] = for_vars[i]; if (i == norm_concat_dim) { diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 939f82c616dc..5e57209f39e2 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -732,7 +732,7 @@ void RegisterizerReplacer::buildReplacements() { for (auto& info : infoSet_) { VarPtr v = alloc( info->buf()->name_hint() + "_" + - c10::to_string(getBufferAccessCount(info->buf())), + std::to_string(getBufferAccessCount(info->buf())), info->buf()->dtype()); info->replacement().var = v; diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp index 01065f5eff5b..1307e53577f4 100644 --- 
a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -1,6 +1,5 @@ #include -#include #include #include @@ -28,7 +27,7 @@ const std::string& UniqueNameManager::get_unique_name(VarPtr v) { int count_v = count++; std::string unique_name = name_hint; if (count_v > 0) { - unique_name += "_" + c10::to_string(count_v); + unique_name += "_" + std::to_string(count_v); } if (all_unique_names_.count(unique_name) == 0) { all_unique_names_.insert(unique_name); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 2dcb215c574b..415fda6165dd 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -121,6 +121,15 @@ static PyObject* MPSModule_driverAllocatedMemory( END_HANDLE_TH_ERRORS } +static PyObject* MPSModule_recommendedMaxMemory( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64( + at::detail::getMPSHooks().getRecommendedMaxMemory()); + END_HANDLE_TH_ERRORS +} + static PyObject* MPSModule_profilerStartTrace( PyObject* _unused, PyObject* args) { @@ -244,6 +253,10 @@ static struct PyMethodDef _MPSModule_methods[] = { MPSModule_driverAllocatedMemory, METH_NOARGS, nullptr}, + {"_mps_recommendedMaxMemory", + MPSModule_recommendedMaxMemory, + METH_NOARGS, + nullptr}, {"_mps_profilerStartTrace", MPSModule_profilerStartTrace, METH_VARARGS, diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index c2a43ce95b1d..d808555da8e4 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -47,6 +47,7 @@ const std::set kXpuTypes = { const std::set kMtiaTypes = { libkineto::ActivityType::MTIA_CCP_EVENTS, libkineto::ActivityType::MTIA_RUNTIME, + libkineto::ActivityType::MTIA_WORKLOADD, }; const std::set kPrivateUse1Types = { libkineto::ActivityType::GPU_MEMCPY, @@ -344,9 +345,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { case libkineto::ActivityType::CONCURRENT_KERNEL: case libkineto::ActivityType::CUDA_SYNC: case libkineto::ActivityType::GPU_USER_ANNOTATION: - case libkineto::ActivityType::CUDA_PROFILER_RANGE: - // TODO: T151322015 - case libkineto::ActivityType::MTIA_CCP_EVENTS: { + case libkineto::ActivityType::CUDA_PROFILER_RANGE: { // PrivateUse1 kineto backend reuse above ActivityTypes, // If PrivateUse1 backend enabled, this should return // c10::DeviceType::PrivateUse1. @@ -358,6 +357,20 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { }(); return device_type; } + // TODO: T151322015 + case libkineto::ActivityType::MTIA_CCP_EVENTS: + case libkineto::ActivityType::MTIA_WORKLOADD: { + // PrivateUse1 kineto backend reuse above ActivityTypes, + // If PrivateUse1 backend enabled, this should return + // c10::DeviceType::PrivateUse1. 
+ c10::DeviceType device_type = []() { + if (c10::get_privateuse1_backend() != "privateuseone") { + return c10::DeviceType::PrivateUse1; + } + return c10::DeviceType::MTIA; + }(); + return device_type; + } case libkineto::ActivityType::CPU_OP: case libkineto::ActivityType::USER_ANNOTATION: case libkineto::ActivityType::EXTERNAL_CORRELATION: diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h index d960b287e20f..87c1f8837239 100644 --- a/torch/csrc/utils/byte_order.h +++ b/torch/csrc/utils/byte_order.h @@ -62,8 +62,7 @@ #error Unexpected or undefined __BYTE_ORDER__ #endif -namespace torch { -namespace utils { +namespace torch::utils { enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; @@ -223,5 +222,4 @@ TORCH_API void THP_encodeComplexDoubleBuffer( THPByteOrder order, size_t len); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/cuda_enabled.h b/torch/csrc/utils/cuda_enabled.h index e27c168a8ef4..0e3c2f30a83e 100644 --- a/torch/csrc/utils/cuda_enabled.h +++ b/torch/csrc/utils/cuda_enabled.h @@ -1,9 +1,8 @@ #pragma once -namespace torch { -namespace utils { +namespace torch::utils { -static inline bool cuda_enabled() { +inline constexpr bool cuda_enabled() { #ifdef USE_CUDA return true; #else @@ -11,5 +10,4 @@ static inline bool cuda_enabled() { #endif } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index 4d736898e535..79c05f3c9ada 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -26,21 +26,21 @@ namespace torch::utils { void device_lazy_init(at::DeviceType device_type); void set_requires_device_init(at::DeviceType device_type, bool value); -static inline void maybe_initialize_device(at::Device& device) { +inline void maybe_initialize_device(at::Device& device) { // Add more devices here to enable lazy initialization. 
if (device.is_cuda() || device.is_xpu() || device.is_privateuseone()) { device_lazy_init(device.type()); } } -static inline void maybe_initialize_device(std::optional& device) { +inline void maybe_initialize_device(std::optional& device) { if (!device.has_value()) { return; } maybe_initialize_device(device.value()); } -static inline void maybe_initialize_device(const at::TensorOptions& options) { +inline void maybe_initialize_device(const at::TensorOptions& options) { auto device = options.device(); maybe_initialize_device(device); } diff --git a/torch/csrc/utils/init.h b/torch/csrc/utils/init.h index bf6dd216bbcc..31b65470c18e 100644 --- a/torch/csrc/utils/init.h +++ b/torch/csrc/utils/init.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { void initThroughputBenchmarkBindings(PyObject* module); -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark diff --git a/torch/csrc/utils/nested.h b/torch/csrc/utils/nested.h index f3a1061e4712..7683a2412418 100644 --- a/torch/csrc/utils/nested.h +++ b/torch/csrc/utils/nested.h @@ -5,13 +5,11 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { at::Tensor nested_tensor_ctor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/out_types.h b/torch/csrc/utils/out_types.h index 68bf759f3003..63d85dc8f5a9 100644 --- a/torch/csrc/utils/out_types.h +++ b/torch/csrc/utils/out_types.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { TORCH_API void check_out_type_matches( const at::Tensor& result, @@ -14,4 +13,3 @@ TORCH_API void check_out_type_matches( bool device_is_none); } -} // namespace torch diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 553738b8999b..a222feeaa22d 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -189,6 +190,35 @@ struct type_caster { } }; +template <> +struct type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); + + // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType + // cannot be default-initialized, we provide this constructor to explicitly + // initialize that field. The value doesn't matter as it will be overwritten + // after a successful call to load. 
+ type_caster() : value(at::kFloat) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPDtype_Check(obj)) { + value = reinterpret_cast(obj)->scalar_type; + return true; + } + return false; + } + + static handle cast( + const at::ScalarType& src, + return_value_policy /* policy */, + handle /* parent */) { + return Py_NewRef(torch::getTHPDtype(src)); + } +}; + template <> struct type_caster { public: @@ -206,7 +236,7 @@ struct type_caster { if (THPStream_Check(obj)) { value = c10::Stream::unpack3( ((THPStream*)obj)->stream_id, - ((THPStream*)obj)->device_index, + static_cast(((THPStream*)obj)->device_index), static_cast(((THPStream*)obj)->device_type)); return true; } @@ -225,7 +255,7 @@ template <> struct type_caster : public type_caster_base { using base = type_caster_base; - c10::DispatchKey tmp; + c10::DispatchKey tmp{}; public: bool load(handle src, bool convert) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index a7c53bfb0ad1..8966131f9825 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -77,8 +77,6 @@ #include #include #include -#include -#include #include #include @@ -224,6 +222,7 @@ struct PythonArgs { int idx; bool traceable; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const FunctionSignature& signature; PyObject** args; std::vector overloaded_args; // NOTE: borrowed references @@ -504,7 +503,7 @@ inline std::vector PythonArgs::intlist(int i) { return intlistWithDefault(i, signature.params[i].default_intlist); } -inline PyObject* toPyObject(c10::SymInt symint) { +inline PyObject* toPyObject(const c10::SymInt& symint) { if (symint.is_symbolic()) { auto r = py::cast(symint).release().ptr(); TORCH_INTERNAL_ASSERT(r); diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index e370923b398d..ec0af99842d2 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -255,7 +255,7 @@ void initDispatchBindings(PyObject* module) { .def("debug", &c10::OperatorHandle::debug) .def( "redispatch_boxed", - [](py::object self, + [](const py::object& self, c10::DispatchKeySet keyset, py::args args, const py::kwargs& kwargs) { @@ -819,7 +819,7 @@ void initDispatchBindings(PyObject* module) { auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); for (auto& op : op_names) { - std::cout << op << std::endl; + std::cout << op << '\n'; } }, py::arg("dispatch_key") = static_cast("")); diff --git a/torch/csrc/utils/python_dispatch.h b/torch/csrc/utils/python_dispatch.h index 9549b817ba6a..32d436d8347e 100644 --- a/torch/csrc/utils/python_dispatch.h +++ b/torch/csrc/utils/python_dispatch.h @@ -1,9 +1,7 @@ #include #include -namespace torch { -namespace impl { -namespace dispatch { +namespace torch::impl::dispatch { void initDispatchBindings(PyObject* module); @@ -14,6 +12,4 @@ void python_op_registration_trampoline_impl( torch::jit::Stack* stack, bool with_keyset); -} // namespace dispatch -} // namespace impl -} // namespace torch +} // namespace torch::impl::dispatch diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 2a17afdf0e18..d5b772b768e2 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -57,8 +57,7 @@ inline bool THPUtils_checkLong(PyObject* obj) { } inline int32_t THPUtils_unpackInt(PyObject* obj) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int overflow; + int overflow = 0; 
long value = PyLong_AsLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw python_error(); @@ -74,8 +73,7 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) { } inline int64_t THPUtils_unpackLong(PyObject* obj) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int overflow; + int overflow = 0; long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw python_error(); diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h index 411e558715e8..bc7b5c263e0d 100644 --- a/torch/csrc/utils/python_raii.h +++ b/torch/csrc/utils/python_raii.h @@ -2,8 +2,7 @@ #include #include -namespace torch { -namespace impl { +namespace torch::impl { template struct RAIIContextManager { @@ -37,9 +36,9 @@ void py_context_manager(const py::module& m, const char* name) { .def( "__exit__", [](ContextManagerT& guard, - py::object exc_type, - py::object exc_value, - py::object traceback) { guard.exit(); }); + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); } template @@ -77,10 +76,9 @@ void py_context_manager_DEPRECATED(const py::module& m, const char* name) { .def( "__exit__", [](ContextManagerT& guard, - py::object exc_type, - py::object exc_value, - py::object traceback) { guard.exit(); }); + const py::object& exc_type, + const py::object& exc_value, + const py::object& traceback) { guard.exit(); }); } -} // namespace impl -} // namespace torch +} // namespace torch::impl diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index 2819f56b6bab..997425ac7de2 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { template inline T unpackIntegral(PyObject* obj, const char* type) { @@ -159,5 +158,4 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index a2754ef4610b..cca161399c44 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -100,8 +100,7 @@ inline void THPUtils_internStringInPlace(PyObject** obj) { * */ -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { +inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { PyTypeObject* tp = Py_TYPE(obj); PyObject* res = (PyObject*)nullptr; diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index f8c710cf6579..15738b1a67e1 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode 
floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/csrc/utils/python_torch_function_mode.h b/torch/csrc/utils/python_torch_function_mode.h index f6652dfd9308..f0e6bb9acbe9 100644 --- a/torch/csrc/utils/python_torch_function_mode.h +++ b/torch/csrc/utils/python_torch_function_mode.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace overrides { +namespace torch::overrides { struct StashTorchFunctionModeGuard { StashTorchFunctionModeGuard() { @@ -21,5 +20,4 @@ struct StashTorchFunctionModeGuard { std::shared_ptr cur_mode_; }; -} // namespace overrides -} // namespace torch +} // namespace torch::overrides diff --git a/torch/csrc/utils/schema_info.h b/torch/csrc/utils/schema_info.h index acda1bffc153..18aaa9bc7d35 100644 --- a/torch/csrc/utils/schema_info.h +++ b/torch/csrc/utils/schema_info.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { using SchemaSpecialCasePair = std::pair>; @@ -113,5 +112,4 @@ struct TORCH_API SchemaInfo { bool has_init_; }; -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/structseq.h b/torch/csrc/utils/structseq.h index 0d91d39d34be..60e3429b50cd 100644 --- a/torch/csrc/utils/structseq.h +++ b/torch/csrc/utils/structseq.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { PyObject* returned_structseq_repr(PyStructSequence* obj); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index ffb2c5801751..906b5422b373 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -10,8 +10,7 @@ using namespace at; -namespace torch { -namespace utils { +namespace torch::utils { struct StridedData { StridedData(const Tensor& tensor) @@ -129,5 +128,4 @@ const Tensor& map2_( return self; } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_apply.h b/torch/csrc/utils/tensor_apply.h index bd06e0f3e30b..0e721542fe69 100644 --- a/torch/csrc/utils/tensor_apply.h +++ b/torch/csrc/utils/tensor_apply.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { const at::Tensor& apply_(const at::Tensor& self, PyObject* fn); const at::Tensor& map_( @@ -17,5 +16,4 @@ const at::Tensor& map2_( const at::Tensor& y_, PyObject* fn); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 200e04eaddb0..5290392d900f 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -1,14 +1,11 @@ #include #include #include -#include #include #include #include -#include -namespace torch { -namespace utils { +namespace torch::utils { std::pair getDtypeNames(at::ScalarType scalarType) { switch (scalarType) { @@ -125,5 +122,4 @@ void initializeDtypes() { } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_dtypes.h b/torch/csrc/utils/tensor_dtypes.h index 32b769971d03..9a947b380e92 100644 --- a/torch/csrc/utils/tensor_dtypes.h +++ b/torch/csrc/utils/tensor_dtypes.h @@ -1,15 +1,13 @@ #pragma once -#include +#include 
#include #include -namespace torch { -namespace utils { +namespace torch::utils { std::pair getDtypeNames(at::ScalarType scalarType); void initializeDtypes(); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_flatten.cpp b/torch/csrc/utils/tensor_flatten.cpp index 396a6e8a3a8e..fb06ad884d7e 100644 --- a/torch/csrc/utils/tensor_flatten.cpp +++ b/torch/csrc/utils/tensor_flatten.cpp @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { using namespace at; @@ -123,5 +122,4 @@ std::vector unflatten_sparse_tensors( return outputs; } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_flatten.h b/torch/csrc/utils/tensor_flatten.h index 04a55ec7960e..2b65403fb0de 100644 --- a/torch/csrc/utils/tensor_flatten.h +++ b/torch/csrc/utils/tensor_flatten.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { /// Generate an ID for a combination of tensor backend + scalar type to be used /// when ordering tensors ('like' tensors are grouped by pulling out their @@ -82,5 +81,4 @@ TORCH_API std::vector unflatten_sparse_tensors( const at::Tensor& flat_values, at::TensorList tensors); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp index b403f9130bd9..be8816c8a9ab 100644 --- a/torch/csrc/utils/tensor_layouts.cpp +++ b/torch/csrc/utils/tensor_layouts.cpp @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { #define REGISTER_LAYOUT(layout, LAYOUT) \ PyObject* layout##_layout = \ @@ -55,5 +54,4 @@ void initializeLayouts() { REGISTER_LAYOUT(jagged, Jagged); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_layouts.h b/torch/csrc/utils/tensor_layouts.h index 33e32b516b12..7ee7b848cadb 100644 --- a/torch/csrc/utils/tensor_layouts.h +++ b/torch/csrc/utils/tensor_layouts.h @@ -1,9 +1,7 @@ #pragma once -namespace torch { -namespace utils { +namespace torch::utils { void initializeLayouts(); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index c72de0b5e9e0..84f4688e0ecc 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -9,8 +9,7 @@ using namespace at; -namespace torch { -namespace utils { +namespace torch::utils { static PyObject* recursive_to_list( const char* data, @@ -66,5 +65,4 @@ PyObject* tensor_to_list(const Tensor& tensor) { tensor.numel() == 0 ? 
0 : data.dtype().itemsize()); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_list.h b/torch/csrc/utils/tensor_list.h index 8ae77df4700a..8580631921b7 100644 --- a/torch/csrc/utils/tensor_list.h +++ b/torch/csrc/utils/tensor_list.h @@ -6,10 +6,8 @@ namespace at { class Tensor; } -namespace torch { -namespace utils { +namespace torch::utils { PyObject* tensor_to_list(const at::Tensor& tensor); } -} // namespace torch diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp index 63dafaf5f5ff..28d56291bc94 100644 --- a/torch/csrc/utils/tensor_memoryformats.cpp +++ b/torch/csrc/utils/tensor_memoryformats.cpp @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { namespace { // Intentionally leaked @@ -50,5 +49,4 @@ void initializeMemoryFormats() { add_memory_format(at::MemoryFormat::ChannelsLast3d, "channels_last_3d"); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 70a4fbca0bac..088f8d1927c4 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -5,8 +5,7 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { // NOTE: [torch.tensor, lift_fresh, and device movement] // @@ -134,5 +133,4 @@ at::Tensor asarray( std::optional device, std::optional copy, bool requires_grad); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 9b07b9d32f1c..6014281061bc 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -5,8 +5,8 @@ #include #ifndef USE_NUMPY -namespace torch { -namespace utils { + +namespace torch::utils { PyObject* tensor_to_numpy(const at::Tensor&, bool) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -40,8 +40,8 @@ void validate_numpy_for_dlpack_deleter_bug() {} bool is_numpy_dlpack_deleter_bugged() { return false; } -} // namespace utils -} // namespace torch +} // namespace torch::utils + #else #include diff --git a/torch/csrc/utils/tensor_qschemes.cpp b/torch/csrc/utils/tensor_qschemes.cpp index 9e9d6dbdcfce..4c2e6f20557e 100644 --- a/torch/csrc/utils/tensor_qschemes.cpp +++ b/torch/csrc/utils/tensor_qschemes.cpp @@ -9,11 +9,10 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -static PyObject* thp_qscheme_array[at::COMPILE_TIME_NUM_QSCHEMES]; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static std::array thp_qscheme_array; void initializeQSchemes() { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); @@ -40,6 +39,4 @@ PyObject* getTHPQScheme(at::QScheme qscheme) { } return qscheme_; } - -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_qschemes.h b/torch/csrc/utils/tensor_qschemes.h index 71e65479047b..dc982efd1ff9 100644 --- a/torch/csrc/utils/tensor_qschemes.h +++ b/torch/csrc/utils/tensor_qschemes.h @@ -1,11 +1,9 @@ #pragma once #include -namespace torch { -namespace utils { +namespace torch::utils { PyObject* getTHPQScheme(at::QScheme qscheme); void initializeQSchemes(); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff 
--git a/torch/csrc/utils/tensor_types.h b/torch/csrc/utils/tensor_types.h index 601cc920a2e7..a4b905604da6 100644 --- a/torch/csrc/utils/tensor_types.h +++ b/torch/csrc/utils/tensor_types.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { std::string options_to_string(const at::TensorOptions& options); std::string type_to_string(const at::DeprecatedTypeProperties& type); @@ -18,5 +17,4 @@ std::vector> all_declared_types(); // return python module name of backend, like torch.cuda, torch.foo const char* backend_to_string(const at::Backend& backend); -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/throughput_benchmark-inl.h b/torch/csrc/utils/throughput_benchmark-inl.h index 4334a58683bb..ead63d585a05 100644 --- a/torch/csrc/utils/throughput_benchmark-inl.h +++ b/torch/csrc/utils/throughput_benchmark-inl.h @@ -12,9 +12,7 @@ #include #include -namespace torch { -namespace throughput_benchmark { -namespace detail { +namespace torch::throughput_benchmark::detail { template BenchmarkExecutionStats BenchmarkHelper::benchmark( @@ -156,6 +154,4 @@ BenchmarkExecutionStats BenchmarkHelper::benchmark( return stats; } -} // namespace detail -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark::detail diff --git a/torch/csrc/utils/throughput_benchmark.h b/torch/csrc/utils/throughput_benchmark.h index 2fca95ca16bf..5ec44e012631 100644 --- a/torch/csrc/utils/throughput_benchmark.h +++ b/torch/csrc/utils/throughput_benchmark.h @@ -14,8 +14,7 @@ namespace py = pybind11; -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { /** * The struct is used to provide results of a benchmark to the caller @@ -193,7 +192,6 @@ class C10_HIDDEN ThroughputBenchmark { detail::ScriptModuleBenchmark script_module_; detail::ModuleBenchmark module_; }; -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark #include diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index d1c1392e37d6..8ca451143573 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace torch_dispatch_mode { +namespace torch::torch_dispatch_mode { struct StashTorchDispatchModeGuard { public: @@ -54,5 +53,4 @@ struct StashTorchDispatchStackGuard { c10::impl::TorchDispatchModeTLS saved_state_; }; -} // namespace torch_dispatch_mode -} // namespace torch +} // namespace torch::torch_dispatch_mode diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h index 78ffe2997142..0f3dc992c61d 100644 --- a/torch/csrc/utils/variadic.h +++ b/torch/csrc/utils/variadic.h @@ -4,8 +4,6 @@ #include #include -#include -#include #include #include diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 7bf8abdef204..cfe7b43d19a9 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -11,24 +11,30 @@ #include #include +#ifndef WIN32 #include +#endif using namespace torch; static bool in_bad_fork = false; // True for children forked after xpu init +#ifndef WIN32 // Called in the forked child if xpu has already been initialized static void forked_child() { in_bad_fork = true; torch::utils::set_requires_device_init(at::kXPU, true); } +#endif // Should be called before the first xpu call. It is mainly called in lazy_init. 
// Note: This is distinct from initExtension because a stub xpu implementation // has some working functions (e.g. device_count) but cannot fully initialize. static void poison_fork() { +#ifndef WIN32 static c10::once_flag flag; c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); }); +#endif } // XPU management methods diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 8c19788d1055..6722114e295b 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package adds support for CUDA tensor types. @@ -420,7 +421,7 @@ def get_device_name(device: Optional[_device_t] = None) -> str: r"""Get the name of a device. Args: - device (torch.device or int, optional): device for which to return the + device (torch.device or int or str, optional): device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` (default). @@ -435,7 +436,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int] r"""Get the cuda capability of a device. Args: - device (torch.device or int, optional): device for which to return the + device (torch.device or int or str, optional): device for which to return the device capability. This function is a no-op if this argument is a negative integer. It uses the current device, given by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` @@ -629,6 +630,8 @@ def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: def _raw_device_count_amdsmi() -> int: + if not _HAS_PYNVML: # If amdsmi is not available + return -1 try: amdsmi.amdsmi_init() except amdsmi.AmdSmiException as e: @@ -659,6 +662,8 @@ def _raw_device_count_nvml() -> int: def _raw_device_uuid_amdsmi() -> Optional[List[str]]: from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer + if not _HAS_PYNVML: # If amdsmi is not available + return None try: amdsmi.amdsmi_init() except amdsmi.AmdSmiException: @@ -1043,7 +1048,7 @@ def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: handle = _get_amdsmi_handler(device) - return amdsmi.amdsmi_get_power_info(handle)["average_socket_power"] + return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: @@ -1462,7 +1467,7 @@ def addmm_kernel_impl(*args, **kwargs): _lazy_call(_register_triton_kernels) -from . import amp, jiterator, nvtx, profiler, sparse +from . 
import amp, jiterator, nvtx, profiler, sparse, tunable __all__ = [ # Typed storage and tensors @@ -1575,5 +1580,6 @@ def addmm_kernel_impl(*args, **kwargs): "stream", "streams", "synchronize", + "tunable", "utilization", ] diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index a44854d1524c..2047ec4efb28 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import pickle import sys import os @@ -382,7 +383,11 @@ def find_segment(addr): def _format_viz(data, viz_kind, device): if device is not None: - warnings.warn('device argument is deprecated, plots now contain all device', FutureWarning) + warnings.warn( + 'device argument is deprecated, plots now contain all device', + FutureWarning, + stacklevel=3, + ) buffer = pickle.dumps(data) buffer += b'\x00' * (3 - len(buffer) % 3) # Encode the buffer with base64 diff --git a/torch/cuda/_sanitizer.py b/torch/cuda/_sanitizer.py index 89766ba8c1a4..bf72f277dd8a 100644 --- a/torch/cuda/_sanitizer.py +++ b/torch/cuda/_sanitizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels ran on different streams. diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index e50206c70577..049ff41c590f 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any from typing_extensions import deprecated @@ -50,6 +51,16 @@ def __call__(self, func): return super().__call__(func) +# Preserved only for BC reasons +@deprecated( + "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. " + "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.", + category=FutureWarning, +) +def _cast(value, dtype): + return torch.amp.autocast_mode._cast(value, "cuda", dtype) + + @deprecated( "`torch.cuda.amp.custom_fwd(args...)` is deprecated. 
" "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.", diff --git a/torch/cuda/amp/common.py b/torch/cuda/amp/common.py index c4e8c1cc99b0..30ccaeede8d9 100644 --- a/torch/cuda/amp/common.py +++ b/torch/cuda/amp/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from importlib.util import find_spec import torch diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py index 367f21594f1c..c108e7f49a01 100644 --- a/torch/cuda/amp/grad_scaler.py +++ b/torch/cuda/amp/grad_scaler.py @@ -2,6 +2,9 @@ import torch +# We need to keep this unused import for BC reasons +from torch.amp.grad_scaler import OptState # noqa: F401 + __all__ = ["GradScaler"] diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 9d9df283ced6..78c572a1822d 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import typing diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 1be552555945..294670f8819e 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Callable, List diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 0f12395ac778..9634d1c0d80b 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This package adds support for device memory management implemented in CUDA.""" import collections diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index 4170e20b5318..4c28443c9e29 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import warnings from typing import Optional, Sequence, Union @@ -32,6 +33,15 @@ def is_available(tensors): def version(): + """ + Returns the version of the NCCL. + + + This function returns a tuple containing the major, minor, and patch version numbers of the NCCL. + The suffix is also included in the tuple if a version suffix exists. + Returns: + tuple: The version information of the NCCL. + """ ver = torch._C._nccl_version() major = ver >> 32 minor = (ver >> 16) & 65535 @@ -92,6 +102,7 @@ def reduce( "`nccl.reduce` with an output tensor list is deprecated. " "Please specify a single output tensor with argument 'output' instead instead.", FutureWarning, + stacklevel=2, ) _output = outputs[root] elif not isinstance(output, torch.Tensor) and isinstance( @@ -102,6 +113,7 @@ def reduce( "nccl.reduce with an output tensor list is deprecated. 
" "Please specify a single output tensor.", FutureWarning, + stacklevel=2, ) _output = output[root] else: diff --git a/torch/cuda/nvtx.py b/torch/cuda/nvtx.py index 4b902c0c6d4d..195509687905 100644 --- a/torch/cuda/nvtx.py +++ b/torch/cuda/nvtx.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling.""" from contextlib import contextmanager diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index 51c8aa46f714..7e5dc9bab8de 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import tempfile diff --git a/torch/cuda/random.py b/torch/cuda/random.py index 1cf33114d17b..b736c9d959d8 100644 --- a/torch/cuda/random.py +++ b/torch/cuda/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterable, List, Union import torch diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index d36121381586..89271b588711 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ctypes import torch diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py new file mode 100644 index 000000000000..8b387102b43d --- /dev/null +++ b/torch/cuda/tunable.py @@ -0,0 +1,242 @@ +r""" +This module exposes a TunableOp interface. + +Some operations, such as GEMMs, could be implemented using more than one library +or more than one technique. For example, a GEMM could be implemented for CUDA or +ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and +hipblaslt libraries allow the user to query for all possible algorithms and then +choose one. How does one know which implementation is the fastest and should be +chosen? That's what TunableOp provides. + +Enabling TunableOp and Tuning Separately +======================================== + +The TunableOp feature is enabled separately from enabling the tuning phase +itself. Enabling TunableOp means that PyTorch will replace any standard +operators with their Tunable implementations. Any call to a TunableOp first +checks whether it has already been tuned for the given operator inputs. If so, +it will immediately call the tuned operation; no further tuning will take place +even when the tuning setting is enabled. Instead if no tuning result is found, +and tuning is enabled, the TunableOp will benchmark every registered +implementation of that operator for the given set of inputs and select the +fastest. + +File Input and Output +===================== + +The first time any TunableOp is invoked, the internal database of tuned +operations will be prepared by attempting to read the results from the given +file. The default filename is 'tunableop_results.csv'. To support tuning when +multiple GPUs are used across multiple processes, the GPU device ordinal is +automatically inserted into the filename to avoid multiple processes overwriting +the same file. + +If tuning is enabled and new tunings are discovered during the course of your +workload, it will also write out to this same filename with all tunings, both +the ones it read in at startup as well as the new ones found at runtime. This +can be used, for example, to build up a tunings file across many workloads by +reusing the same file. The output file is automatically created when the +application terminates. This behavior can be controlled by the C++ and Python +APIs but not the environment variables. 
+ +Assuming you specified a filename, you'll end up with a CSV file with contents +like so:: + + Validator,PT_VERSION,2.2.0 + Validator,ROCM_VERSION,6.0.0.0-12969-1544e39 + Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7 + Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty + GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 + GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 + +Note the "Validator" lines. If you change a library version, the ROCm version, or the +PyTorch version, TunableOp will detect this and reject the tunings file because +the prior tunings are likely affected by other software changes. + +The remaining lines are the tuned solutions for each TunableOp encountered +during your execution. Each line consists of 4 comma-separated fields: operator +name, operator parameters, solution name, and average execution time. The +execution time is an optional field. The CSV file can be edited, but with +caution. For example, the solution name (field 3) can be changed to "Default" +and it will fall back to the original PyTorch untuned implementation. Or, in the +case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution +index you can override the solution that TunableOp selected by replacing the +value. The operator name and parameters (fields 1 and 2) are internally named +and should not be modified. In the case of GemmTunableOp, field 1 indicates the +datatype and whether the inputs are transposed (T) or not (N), and field 2 +indicates the M, N, K input shapes. + +There is an option to enable verbose output, but it is only recommended for +debugging purposes. This will produce a lot of diagnostic messages but may be +useful to see if TunableOp is being used at all. Otherwise, TunableOp is +completely silent, besides file output, unless there is a warning or error +during its use. The verbose option is only available by setting the environment +variable PYTORCH_TUNABLEOP_VERBOSE=1. + +A Note on Tuning Behavior +========================= + +Tuning an operator consists of iterating through the list of registered +implementations and profiling each one. The profile is established by running a +single implementation in a loop multiple times and taking the average execution +time. + +By default, each possible solution for a given operator will be run for either +100 iterations or as many iterations as can be run within 30ms, whichever is +smaller, and its average execution time will be calculated. The fastest solution +among all that were successfully profiled will be chosen. A profile might fail +if the given solution doesn't achieve the same accuracy as the default +implementation or if the solution returns an error code. + +Current Tunable Operators +========================= + +TunableGemm for ROCm +-------------------- + +Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of +PyTorch will function correctly when using TunableOp, but the only solution +available to CUDA builds is the 'Default' implementation, i.e. the original +cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm() +or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a +given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation across both rocblas and hipblaslt.
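As a rough sketch of how the tuning budget described above might be bounded and the recorded results inspected, using only the Python wrappers defined below (the numeric limits are arbitrary)::

    import torch.cuda.tunable as tunable

    # Cap the per-solution profiling budget; the smaller of the two limits wins.
    tunable.set_max_tuning_iterations(50)
    tunable.set_max_tuning_duration(20)  # milliseconds

    # ... run the GEMM-heavy workload to be tuned ...

    print(tunable.get_validators())  # the Validator entries written to the CSV
    print(tunable.get_results())     # tuned (operator, params, solution, time) entries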
+ +Tuning Context +============== + +The behavior of TunableOp is currently manipulated through environment +variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the +torch.cuda.tunable Python interfaces that wrap the C++ TuningContext. The +environment variables take precedence over any setting you manipulate using the +C++ or Python APIs. + +""" +from typing import Optional, Tuple + +import torch + + +__all__ = [ + "enable", + "is_enabled", + "tuning_enable", + "tuning_is_enabled", + "set_max_tuning_duration", + "get_max_tuning_duration", + "set_max_tuning_iterations", + "get_max_tuning_iterations", + "set_filename", + "get_filename", + "get_results", + "get_validators", + "write_file_on_exit", + "write_file", + "read_file", +] + + +def enable(val: bool = True) -> None: + r"""This is the big on/off switch for all TunableOp implementations.""" + torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined] + + +def is_enabled() -> bool: + r"""Returns whether the TunableOp feature is enabled.""" + return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined] + + +def tuning_enable(val: bool = True) -> None: + r"""Enable tuning of TunableOp implementations. + + When enabled, if a tuned entry isn't found, run the tuning step and record + the entry. + """ + torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined] + + +def tuning_is_enabled() -> bool: + r"""Returns whether TunableOp implementations can be tuned.""" + return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined] + + +def set_max_tuning_duration(duration: int) -> None: + r"""Set max time in milliseconds to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined] + + +def get_max_tuning_duration() -> int: + r"""Get max time to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined] + + +def set_max_tuning_iterations(iterations: int) -> None: + r"""Set max number of iterations to spend tuning a given solution. + + If both max tuning duration and iterations are set, the smaller of the two + will be honored. At minimum 1 tuning iteration will always be run. + """ + torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined] + + +def get_max_tuning_iterations() -> int: + r"""Get max iterations to spend tuning a given solution.""" + return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined] + + +def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: + r"""Set the filename to use for input/output of tuning results. + + If :attr:`insert_device_ordinal` is ``True``, then the current device ordinal + will be added to the given filename automatically. This can be used in a + 1-process-per-gpu scenario to ensure all processes write to a separate file.
+ """ + torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined] + + +def get_filename() -> str: + r"""Get the results filename.""" + return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined] + + +def get_results() -> Tuple[str, str, str, float]: + r"""Return all TunableOp results.""" + return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined] + + +def get_validators() -> Tuple[str, str]: + r"""Return the TunableOp validators.""" + return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined] + + +def write_file_on_exit(val: bool) -> None: + r"""During Tuning Context destruction, write file to disk. + + This is useful as a final flush of your results to disk if your application + terminates as result of normal operation or an error. Manual flushing of + your results can be achieved by manually calling ``write_file()``.""" + torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined] + + +def write_file(filename: Optional[str] = None) -> bool: + r"""Write results to a CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined] + + +def read_file(filename: Optional[str] = None) -> bool: + r"""Read results from a TunableOp CSV file. + + If :attr:`filename` is not given, ``get_filename()`` is called. + """ + if filename is None: + filename = get_filename() + return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined] diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 3e7dce97b54c..b8e911c8738c 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import sys from enum import Enum diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index 8accef6afc34..6716f43a74a0 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager, nullcontext from typing import Any, Tuple diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index 2a6983023f76..6693fa9608df 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import uuid from collections import OrderedDict from functools import wraps diff --git a/torch/distributed/_composable/fsdp/_fsdp_api.py b/torch/distributed/_composable/fsdp/_fsdp_api.py index 2bf0278ed488..aa6b5e803b80 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_api.py +++ b/torch/distributed/_composable/fsdp/_fsdp_api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import Optional diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index b7264cb34d6d..1423cfd600fc 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -1,7 +1,9 @@ from typing import List, NamedTuple, Optional, Tuple, Union import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist +from torch.distributed._tensor import DTensor from 
torch.distributed.distributed_c10d import ReduceOp from ._fsdp_common import ( _get_dim0_padded_size, @@ -24,6 +26,98 @@ class AllGatherResult(NamedTuple): all_gather_input_split_sizes: List[int] +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device + ) -> (Tensor, Tensor) + """ +) + + +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@torch.library.impl(lib, "all_gather_copy_in", "CUDA") +def all_gather_copy_in_cuda( + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + with torch.no_grad(): + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@torch.library.impl(lib, "split_with_sizes_copy", "Meta") +@torch.library.impl(lib, "split_with_sizes_copy", "CUDA") +def split_with_sizes_copy( + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: List[int], + dim: int, + out: List[torch.Tensor], +) -> None: + torch.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) 
out) -> ()" +) + + +@torch.library.impl(lib, "chunk_cat", "Meta") +@torch.library.impl(lib, "chunk_cat", "CUDA") +def chunk_cat( + tensors: List[torch.Tensor], + dim: int, + num_chunks: int, + out: torch.Tensor, +) -> None: + torch._chunk_cat(tensors, dim, num_chunks, out=out) + + @torch.no_grad() def foreach_all_gather( fsdp_params: List[FSDPParam], @@ -51,14 +145,15 @@ def foreach_all_gather( all_gather_inputs = [t for ts in param_all_gather_inputs for t in ts] inp_split_sizes = [t.numel() for t in all_gather_inputs] all_gather_input_numel = sum(inp_split_sizes) - all_gather_output = torch.empty( - (all_gather_input_numel * world_size,), dtype=dtype, device=device - ) - all_gather_input = all_gather_output.narrow( - 0, all_gather_input_numel * rank, all_gather_input_numel + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, ) - foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) - torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with torch.cuda.stream(all_gather_stream): @@ -101,17 +196,28 @@ def foreach_all_gather_copy_out( for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params ): - fsdp_param.init_all_gather_outputs( - all_gather_input_numels, all_gather_input_dtypes, world_size, device - ) # no-op after 1st call - fsdp_param.alloc_all_gather_outputs() + if ca.compiled_autograd_enabled: + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. 
+ force_recreate=True, + ) + else: + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, all_gather_input_dtypes, world_size, device + ) # no-op after 1st call + fsdp_param.alloc_all_gather_outputs() all_gather_output = all_gather_output.view(world_size, -1) gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs) if all_gather_output.dtype == torch.uint8: out = [t.view(world_size, -1).view(torch.uint8) for t in gen] else: out = [t.view(world_size, -1) for t in gen] - torch.split_with_sizes_copy( + torch.ops.fsdp.split_with_sizes_copy( all_gather_output, all_gather_input_split_sizes, dim=1, out=out ) @@ -222,10 +328,13 @@ def foreach_reduce( # Record an event on which to block the CPU thread to # ensure that the D2H copy finishes before the optimizer fsdp_param.grad_offload_event = reduce_scatter_stream.record_event() - new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(new_sharded_grad) if to_accumulate_grad: - fsdp_param.sharded_param.grad += new_sharded_dtensor_grad + assert isinstance(fsdp_param.sharded_param.grad, DTensor) + fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad else: + new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor( + new_sharded_grad + ) fsdp_param.sharded_param.grad = new_sharded_dtensor_grad padded_sharded_numel = padded_unsharded_size.numel() // world_size flat_grad_offset += padded_sharded_numel @@ -243,7 +352,7 @@ def foreach_reduce_scatter_copy_in( world_size: int, ) -> None: reduce_scatter_input = reduce_scatter_input.view(world_size, -1) - torch._chunk_cat( + torch.ops.fsdp.chunk_cat( unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input ) diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 1395e3487847..594ec483bd3b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -1,15 +1,18 @@ +# mypy: allow-untyped-defs import math import traceback from dataclasses import dataclass from enum import auto, Enum -from typing import Any, cast, List, Optional, Tuple +from typing import Any, cast, List, Optional import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry -from torch.distributed._tensor import DeviceMesh, DTensor, Placement +from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed._tensor.placement_types import DTensorSpec @dataclass @@ -32,9 +35,7 @@ def __post_init__(self): if self.shard_mesh_dim is None: raise AssertionError("Expects non-None shard_mesh_dim") self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim) - self.shard_process_group = cast( - dist.ProcessGroup, self.mesh.get_group(self.shard_mesh_dim) - ) + self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim) self.shard_mesh_rank: int = self.shard_process_group.rank() @@ -45,9 +46,7 @@ def __post_init__(self): if self.replicate_mesh_dim is None: raise AssertionError("Expects non-None replicate_mesh_dim") self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim) - self.replicate_process_group = cast( - dist.ProcessGroup, self.mesh.get_group(self.replicate_mesh_dim) - ) + self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim) self.replicate_mesh_rank: int = self.replicate_process_group.rank() @@ -111,34 +110,28 @@ def _get_dim0_chunked_size( def _from_local_no_grad( local_tensor: 
torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], - global_size: torch.Size, - global_stride: Tuple[int, ...], + sharding_spec: DTensorSpec, ) -> DTensor: """ This method is similar to ``DTensor.from_local()`` except that in eager mode it avoids some CPU overhead by avoiding default args and not being differentiable. """ - if not torch._dynamo.compiled_autograd.compiled_autograd_enabled: + + if not ca.compiled_autograd_enabled: return DTensor( # Use the local tensor directly instead of constructing a new tensor # variable, e.g. with `view_as()`, since this is not differentiable local_tensor, - device_mesh, - placements, - shape=global_size, - dtype=local_tensor.dtype, + sharding_spec, requires_grad=local_tensor.requires_grad, - stride=global_stride, ) else: return DTensor.from_local( local_tensor, - device_mesh, - placements, - shape=global_size, - stride=global_stride, + sharding_spec.mesh, + sharding_spec.placements, + shape=sharding_spec.shape, + stride=sharding_spec.stride, ) diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index f0d64aa3e8f1..c56dc79e266b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -1,16 +1,18 @@ +# mypy: allow-untyped-defs import itertools from dataclasses import dataclass, field from enum import auto, Enum from typing import Any, cast, List, Optional, Sequence, Tuple import torch +import torch._dynamo.compiled_autograd as ca import torch.nn as nn from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor, Placement, Replicate, Shard +from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed._tensor.device_mesh import _mesh_resources -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( _chunk_with_empty, @@ -127,12 +129,10 @@ class FSDPParam: _sharded_post_forward_param: Optional[nn.Parameter] # ND _unsharded_param: nn.Parameter # ND unsharded_accumulated_grad: Optional[torch.Tensor] # ND - _global_placements: Tuple[Placement, ...] - _global_size: torch.Size - _global_stride: Tuple[int, ...] 
- all_gather_outputs: List[torch.Tensor] # 1D + _sharding_spec: DTensorSpec # DTensor attributes (only defined for DTensor `param`): _tp_spec: DTensorSpec + all_gather_outputs: List[torch.Tensor] # 1D # All-gather extension attributes _extensions_data: ExtensionsData _unsharded_inner_tensors: List[torch.Tensor] @@ -199,41 +199,48 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" ) - self._global_mesh = dp_global_mesh + + name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" + assert dp_mesh.mesh_dim_names is not None, name_dims_error + assert tp_mesh.mesh_dim_names is not None, name_dims_error + + submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names + self._spmd_mesh = dp_global_mesh[submesh_names] if len(self._tp_spec.placements) != 1: raise NotImplementedError( f"FSDP only supports 1D TP, not {self._tp_spec.placements}" ) - global_placements: List[Placement] = [Replicate(), Replicate()] - global_dp_mesh_dim = _mesh_resources.get_parent_mesh_dim(dp_mesh) - global_tp_mesh_dim = _mesh_resources.get_parent_mesh_dim(tp_mesh) - assert global_dp_mesh_dim is not None # mypy - assert global_tp_mesh_dim is not None # mypy - # for PP, DP, TP case, dp mesh dim would be 1, tp mesh dim would be 2 - # DP/TP would only live in the inner most 2-3 dims (HSDP + TP would be 3) - dp_tp_mesh_ndim = dp_mesh.ndim + tp_mesh.ndim - outer_mesh_ndim = self._global_mesh.ndim - dp_tp_mesh_ndim - if self._global_mesh.ndim > dp_tp_mesh_ndim: - global_dp_mesh_dim = global_dp_mesh_dim - outer_mesh_ndim - global_tp_mesh_dim = global_tp_mesh_dim - outer_mesh_ndim # TODO: Hard code FSDP + TP; need to support HSDP + TP - global_placements[global_dp_mesh_dim] = Shard(0) - global_placements[global_tp_mesh_dim] = self._tp_spec.placements[0] - self._global_placements = tuple(global_placements) - self._global_size = param.size() - self._global_stride = param.stride() + self._spmd_placements: Tuple[Placement, ...] 
= ( + Shard(0), + self._tp_spec.placements[0], + ) + + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._tp_spec.tensor_meta, + ) param_data = cast(DTensor, param)._local_tensor else: - self._global_mesh = self.mesh_info.mesh + self._spmd_mesh = self.mesh_info.mesh if isinstance(self.mesh_info, HSDPMeshInfo): - self._global_placements = (Replicate(), Shard(0)) + self._spmd_placements = (Replicate(), Shard(0)) else: - self._global_placements = (Shard(0),) - self._global_size = param.size() - self._global_stride = param.stride() + self._spmd_placements = (Shard(0),) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta( + param.size(), + param.stride(), + param.dtype, + ), + ) param_data = param self._orig_size = param_data.size() + self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) shard_rank = self.mesh_info.shard_mesh_rank shard_world_size = self.mesh_info.shard_mesh_size chunks = _chunk_with_empty(param_data, shard_world_size, dim=0) @@ -305,8 +312,9 @@ def init_all_gather_outputs( all_gather_input_dtypes: List[torch.dtype], world_size: int, device: torch.device, + force_recreate: bool = False, ): - if self.all_gather_outputs: + if not force_recreate and len(self.all_gather_outputs) > 0: return # already initialized self.all_gather_outputs = [ torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) @@ -314,7 +322,24 @@ def init_all_gather_outputs( ] def init_unsharded_param(self): - if hasattr(self, "_unsharded_param"): # after the 1st all-gather + """ + [Note: Invariants for torch.compile Traceable FSDP2] + 1. Under compile, we always re-populate the content of `self._unsharded_param` + per AllGather using the slow path. + 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather. + This is to ensure the buffer creation is internal to the graph and + avoid `self.all_gather_outputs` being captured as a graph input. + 3. Under compile, at the end of `free_unsharded_param()`, we always clean up + `self.all_gather_outputs` and `self._unsharded_inner_tensors`, + to avoid them being captured as graph output. 
+ + With these invariants, only these tensors will be inputs to the graph: + - Sharded parameters + - Placeholders for the `self._unsharded_param` nn.Parameter + """ + if not ca.compiled_autograd_enabled and hasattr( + self, "_unsharded_param" + ): # after the 1st all-gather inner_tensor = self._sharded_local_tensor if not hasattr(inner_tensor, "fsdp_post_all_gather"): return # already initialized @@ -330,7 +355,9 @@ def init_unsharded_param(self): self._extensions_data.clear() return inner_tensor = self._sharded_local_tensor - if hasattr(inner_tensor, "fsdp_post_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + inner_tensor, "fsdp_post_all_gather" + ): all_gather_outputs = self._unflatten_all_gather_outputs() ( unsharded_tensor, @@ -349,19 +376,20 @@ def init_unsharded_param(self): unsharded_param = torch.as_strided( unsharded_tensor, self._orig_size, - make_contiguous_strides_for(self._orig_size), + self._contiguous_orig_stride, storage_offset=0, ) if self.is_dtensor: - unsharded_param = _from_local_no_grad( - unsharded_param, - self._tp_spec.mesh, - self._tp_spec.placements, - self._global_size, - self._global_stride, + unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if hasattr(self, "_unsharded_param"): + assert ca.compiled_autograd_enabled + with torch.no_grad(): + alloc_storage(self._unsharded_param) + self._unsharded_param.copy_(unsharded_param) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad ) - self._unsharded_param = nn.Parameter(unsharded_param) - self._unsharded_param.requires_grad_(self.sharded_param.requires_grad) def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]: return tuple( @@ -443,10 +471,7 @@ def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: ) return _from_local_no_grad( tensor, - self._global_mesh, - self._global_placements, - self._global_size, - self._global_stride, + self._sharding_spec, ) def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: @@ -457,13 +482,12 @@ def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo) # TODO: Prefer this DTensor to be read-only and generalize the # placement once we support TP. 
- return _from_local_no_grad( - tensor, + post_forward_sharding_spec = DTensorSpec( self.post_forward_mesh_info.mesh, (Replicate(), Shard(0)), - self._global_size, - self._global_stride, + tensor_meta=self._sharding_spec.tensor_meta, ) + return _from_local_no_grad(tensor, post_forward_sharding_spec) def to_accumulated_grad_if_needed(self) -> None: # Access `_unsharded_param` to bypass the sharded state check since we @@ -495,12 +519,17 @@ def free_unsharded_param(self) -> None: self.all_gather_outputs, self._unsharded_inner_tensors ): free_storage(tensor) + if ca.compiled_autograd_enabled: + self.all_gather_outputs = [] + self._unsharded_inner_tensors = [] @property def all_gather_inputs(self) -> List[torch.Tensor]: # 1D self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) if self.sharded_state == ShardedState.SHARDED: - if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): sharded_local_tensor = self._sharded_local_tensor if self.offload_to_cpu: sharded_local_tensor = sharded_local_tensor.to( @@ -521,7 +550,9 @@ def all_gather_inputs(self) -> List[torch.Tensor]: # 1D ) return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: - if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"): + if not ca.compiled_autograd_enabled and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): raise NotImplementedError all_gather_input = _to_dtype_if_needed( cast(torch.Tensor, self._sharded_post_forward_param_data), diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index ea2307222ce1..2361b7ba7c7e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,8 +1,10 @@ +# mypy: allow-untyped-defs import contextlib from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch +import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates @@ -402,6 +404,9 @@ def use_training_state(self, training_state: TrainingState): def _register_post_backward_hook( self, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + # Compile relies on `root_post_backward_callback` to call each `FSDPParamGroup.post_backward` + if ca.compiled_autograd_enabled: + return args, kwargs if not torch.is_grad_enabled(): return args, kwargs args_list, args_spec = tree_flatten(args) diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 15a00e83f086..f080e7550338 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 3efb8f7afd85..018333a65886 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Any, cast, NoReturn, Optional, Union @@ -128,10 +129,10 @@ def 
fully_shard( offload_policy, ) - # for dynamo - for module in managed_modules: - module._is_fsdp_managed_module = True # type: ignore[assignment] - module._fsdp_use_orig_params = True # type: ignore[assignment] + # For Dynamo + for managed_module in managed_modules: + managed_module._is_fsdp_managed_module = True # type: ignore[assignment] + managed_module._fsdp_use_orig_params = True # type: ignore[assignment] # Place FSDP leftmost for highest priority in the method resolution order cls = module.__class__ @@ -181,6 +182,8 @@ def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]: ``False``, then returns ``None`` and waits on the handle inside this function. + .. warning:: This method is experimental and subject to change. + .. note:: If ``async_op=True``, then the user does not have to call :meth:`wait` on the returned handle if waiting on the unshard op in the module's pre-forward is tolerable. FSDP will wait on the diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 45e1b9d8ab7f..0cb4ea79bc7d 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index f91d1f29d98b..c77902c0d3a7 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,13 +1,17 @@ +# mypy: allow-untyped-defs +from collections import defaultdict from contextlib import contextmanager from functools import partial -from typing import Callable, cast, List, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d -from torch._C._distributed_c10d import _DistributedBackendOptions, Backend + +if TYPE_CHECKING: + from torch._C._distributed_c10d import _DistributedBackendOptions, Backend """ @@ -90,6 +94,8 @@ def _create_cuda_p2p_group( def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool: + if _test_with_non_cuda_p2p_group: + return True if not c10d.is_nccl_available(): return False try: @@ -121,3 +127,365 @@ def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int: extended_api=True, devices=["cuda"], ) + + +_test_with_non_cuda_p2p_group: bool = False + + +@contextmanager +def test_with_non_cuda_p2p_group(): + """ + Force ops in this file to work with non-cuda_p2p groups for testing + purposes. Not thread safe. + """ + global _test_with_non_cuda_p2p_group + prev = _test_with_non_cuda_p2p_group + try: + _test_with_non_cuda_p2p_group = True + yield + finally: + _test_with_non_cuda_p2p_group = prev + + +_current_p2p_usage_counter: Optional[Dict[str, int]] = None + + +@contextmanager +def p2p_usage_counter(): + """ + Record the number of ops that utilized p2p capability for testing purposes. + Fallbacks are excluded. 
+ """ + global _current_p2p_usage_counter + prev = _current_p2p_usage_counter + try: + _current_p2p_usage_counter = defaultdict(int) + yield _current_p2p_usage_counter + finally: + _current_p2p_usage_counter = prev + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group: c10d.ProcessGroup, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + tensor = all_gather_tensor(shard, gather_dim=1, group=group) + chunks = tensor.chunk(group.size()) + for src_rank, chunk in enumerate(chunks): + shard_consumer(chunk, src_rank) + + NOTE: + - The shard passed to shard consumer will always be contiguous. + """ + p2p_buf_sz_req = shard.numel() * shard.element_size() + if get_p2p_buffer_size(group) < p2p_buf_sz_req: + # We preferred the caller to handle fallback so that the computation + # doesn't need to be decomposed. + raise RuntimeError( + f"_pipelined_all_gather_and_consume on input with shape={shard.shape} " + f"and dtype={shard.dtype} requires {p2p_buf_sz_req} bytes of p2p buffers " + f"(got {get_p2p_buffer_size(group)} bytes)." + ) + + backend = get_cuda_p2p_backend(group) + group_size = group.size() + rank = group.rank() + + backend.stream().wait_stream(torch.cuda.current_stream()) + local_p2p_buf = backend.get_p2p_buffer(rank, shard.shape, shard.dtype) + + chunks = ag_out.chunk(group.size()) + + # While consuming local shard, copy it to the local p2p buffer + # in another stream. + shard_consumer(shard, rank) + chunks[rank].copy_(shard) + + with torch.cuda.stream(backend.stream()): + local_p2p_buf.copy_(shard) + work = backend.intra_node_barrier() + work.wait() + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + for i in range(1, group_size): + if i % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend.stream() + remote_rank = (i + rank) % group_size + remote_p2p_buf = backend.get_p2p_buffer(remote_rank, shard.shape, shard.dtype) + with torch.cuda.stream(stream): + chunks[remote_rank].copy_(remote_p2p_buf) + shard_consumer(chunks[remote_rank], remote_rank) + + torch.cuda.current_stream().wait_stream(backend.stream()) + + with torch.cuda.stream(backend.stream()): + work = backend.intra_node_barrier() + work.wait() + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group: c10d.ProcessGroup, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group.size()): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + group_size = group.size() + rank = group.rank() + + out_chunks = output.chunk(group_size) + p2p_buf_sz_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + if get_p2p_buffer_size(group) < p2p_buf_sz_req: + # We preferred the caller to handle fallback so that the computation + # doesn't need to be decomposed. + raise RuntimeError( + f"_pipelined_produce_and_all2all on output with shape={output.shape} " + f"and dtype={output.dtype} requires {p2p_buf_sz_req} bytes of p2p buffers " + f"(got {get_p2p_buffer_size(group)} bytes)." 
+ ) + + backend = get_cuda_p2p_backend(group) + backend.stream().wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return backend.get_p2p_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + of step [i] in one p2p buffer while the local rank can compute the + result of step [i+1] and write it directly to the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + # Directly write the local result to the destination. + # No need to go through the p2p buffers. + chunk_producer(rank, out_chunks[rank]) + + with torch.cuda.stream(backend.stream()): + chunk_producer((rank + 1) % group_size, local_p2p_buf_0) + backend.intra_node_barrier() + remote_p2p_buf = get_p2p_buf((rank - 1) % group_size, 0) + out_chunks[(rank - 1) % group_size].copy_(remote_p2p_buf) + + for step in range(2, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend.stream() + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with torch.cuda.stream(stream): + chunk_producer((rank + step) % group_size, p2p_buf) + backend.intra_node_barrier() + out_chunks[remote_rank].copy_(remote_p2p_buf) + + torch.cuda.current_stream().wait_stream(backend.stream()) + backend.intra_node_barrier() + + lib = torch.library.Library("cuda_p2p", "DEF") # noqa: TOR901 + lib.define( + "fused_all_gather_matmul(Tensor A, Tensor[] Bs, int gather_dim, str group_name) -> (Tensor, Tensor[])" + ) + lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor" + ) + + @torch.library.impl(lib, "fused_all_gather_matmul", "Meta") + def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: List[torch.Tensor], + gather_dim: int, + group_name: str, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + return A.movedim(0, gather_dim), [ + torch.matmul(A, B).movedim(0, gather_dim) for B in Bs + ] + + @torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") + def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: List[torch.Tensor], + gather_dim: int, + group_name: str, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + """ + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + p2p_buf_sz_req = A_shard.numel() * A_shard.element_size() + if ( + _test_with_non_cuda_p2p_group + or get_p2p_buffer_size(group) < p2p_buf_sz_req + # Pipelining a matmul with split-k is not supported + or gather_dim == len(A_shard.shape) - 1 + ): + return _fused_all_gather_matmul_fallback(A_shard, Bs,
gather_dim, group_name) + + if _current_p2p_usage_counter is not None: + _current_p2p_usage_counter["fused_all_gather_matmul"] += 1 + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + x = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + x = x.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t): + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + ag_out = x.new_empty( + x.shape[0] * group.size(), + x.shape[1], + ) + outputs = [ + x.new_empty( + x.shape[0] * group.size(), + B.shape[1], + ) + for B in Bs + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + # Computing block-wise matmul along the first dim of A + def shard_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, B in enumerate(Bs): + torch.mm(shard, B, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + x, + shard_consumer, + ag_out, + group, + ) + return unflatten(ag_out), [unflatten(output) for output in outputs] + + @torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") + def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, + ) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + @torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") + def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, + ) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + NOTE: + - The K dim across ranks is currently accumulated in bf16, which results + in accuracy loss.
+ """ + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + p2p_buf_sz_req = torch.Size(out_shape).numel() * A.element_size() * 2 + if _test_with_non_cuda_p2p_group or get_p2p_buffer_size(group) < p2p_buf_sz_req: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + if _current_p2p_usage_counter is not None: + _current_p2p_usage_counter["fused_matmul_reduce_scatter"] += 1 + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + torch.matmul(shards[rank], B, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1]) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group, + ) + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0e58f0a2b3a1..9ac89166b25f 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import warnings from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union @@ -768,6 +769,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: "identifier has been deprecated. 
Please switch to " "using ProcessGroup, DeviceMesh, or group name instead.", FutureWarning, + stacklevel=3, ) return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag) else: diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index 7abd33e42afa..c39cb4a9d50d 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 9afa7d9e793a..441bb421b195 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import Optional import torch diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index c426503161c7..7506f17b046d 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils import _pytree as pytree from typing import Optional diff --git a/torch/distributed/_shard/metadata.py b/torch/distributed/_shard/metadata.py index b7bae9e6664a..850b065e4dab 100644 --- a/torch/distributed/_shard/metadata.py +++ b/torch/distributed/_shard/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import List, Union, Optional from functools import reduce diff --git a/torch/distributed/_shard/op_registry_utils.py b/torch/distributed/_shard/op_registry_utils.py index 4febe841186a..033dc7c58e0a 100644 --- a/torch/distributed/_shard/op_registry_utils.py +++ b/torch/distributed/_shard/op_registry_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from inspect import signature from .common_op_utils import _basic_validation diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index 54d8a94ad3fe..e1acf7dc17a8 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Union, Mapping, Dict, Any import torch.optim as optim diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 602f75163782..1b846a8dabb4 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import List, TYPE_CHECKING diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index e672c54927db..4d35d24ecafc 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index 0a7999a4c263..034f91498161 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as 
dist import torch.distributed.distributed_c10d as distributed_c10d diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py index dfb661653e71..736190d491e1 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed._shard.sharded_tensor as sharded_tensor from torch.distributed._shard.sharded_tensor import ( diff --git a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py index 0e0911bb1d18..82737f82de53 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index f96eded95f31..7de78bf61f3f 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import torch from torch.distributed._shard.sharded_tensor import ( diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 65da388d0f4f..bf5db21b9a16 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations # type: ignore[attr-defined] from dataclasses import dataclass from typing import ( @@ -397,7 +398,11 @@ def shard_size(shard_md): return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] if enforce_dtype: - warnings.warn("`enforce_dtype` is deprecated. Please use `dtype` instead.", FutureWarning) + warnings.warn( + "`enforce_dtype` is deprecated. 
Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) rank = dist.get_rank(self._process_group) full_size = self.metadata().size diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index cb112da5686b..8b3257240e38 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from enum import Enum from typing import List diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index de7a44bb8200..549dde38cdf8 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import List, Tuple diff --git a/torch/distributed/_shard/sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py index d448cc6321b1..ac1e881370e8 100644 --- a/torch/distributed/_shard/sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from typing import List diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index d904137ba6f0..782def0e4d4c 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections.abc import copy from typing import Optional, List, Sequence, TYPE_CHECKING diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index e8275063e038..07d3c2e19bc0 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple from torch.distributed._shard.metadata import ShardMetadata diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index 1824b66a8194..7493eccdf015 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from dataclasses import dataclass import functools diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index 2775dbd9dd8d..bd2c960f7f60 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass import torch import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py index c869b71d69e7..83d3371c7f90 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py 
b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py index c9cfcba1fe1a..117aed79520d 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index 2f954398f988..01a148b5a9a9 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index 2848060bf28d..ce9984efac6e 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod from contextlib import contextmanager, nullcontext from copy import copy diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index 6d36b2e38118..d3c39295c0e6 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Dict, List, Set import torch diff --git a/torch/distributed/_spmd/comm_tensor.py b/torch/distributed/_spmd/comm_tensor.py index 292f5b250861..a54ed2f46d21 100644 --- a/torch/distributed/_spmd/comm_tensor.py +++ b/torch/distributed/_spmd/comm_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from functools import partial from typing import Any, List, Optional, Tuple diff --git a/torch/distributed/_spmd/config.py b/torch/distributed/_spmd/config.py index 54f0cc4dc5c8..73ee19e803dc 100644 --- a/torch/distributed/_spmd/config.py +++ b/torch/distributed/_spmd/config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import sys from types import ModuleType diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py index 5e376d9f0c4a..8b18c6c86763 100644 --- a/torch/distributed/_spmd/data_parallel.py +++ b/torch/distributed/_spmd/data_parallel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from contextlib import contextmanager from enum import Enum diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py index 0ed2bcabb907..5fb5ff766799 100644 --- a/torch/distributed/_spmd/distribute.py +++ b/torch/distributed/_spmd/distribute.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator from dataclasses import dataclass diff --git a/torch/distributed/_spmd/experimental_ops.py b/torch/distributed/_spmd/experimental_ops.py index e108061e5d74..94a0da822449 100644 --- a/torch/distributed/_spmd/experimental_ops.py +++ b/torch/distributed/_spmd/experimental_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import cast, List, Optional, Sequence, Tuple diff --git a/torch/distributed/_spmd/graph_optimization.py b/torch/distributed/_spmd/graph_optimization.py index 10423fb55cd4..4a5cad7917d8 100644 --- a/torch/distributed/_spmd/graph_optimization.py +++ b/torch/distributed/_spmd/graph_optimization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import collections import itertools diff --git a/torch/distributed/_spmd/iter_graph_module.py b/torch/distributed/_spmd/iter_graph_module.py index f1e8e960f361..cd5f934c5c7f 100644 --- a/torch/distributed/_spmd/iter_graph_module.py +++ b/torch/distributed/_spmd/iter_graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import inspect import logging diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 48d1a6bfb9c2..4d7a7b086509 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import io import math @@ -331,6 +332,7 @@ def _copy_state_dict( state_dict: Dict[str, Any], copy_state_dict: Dict[str, Any], non_blocking: bool = False, + type_check: bool = True, ) -> Dict[str, Any]: """ Copies all tensors in a given state dict into a different state_dict with the @@ -352,6 +354,9 @@ def _copy_state_dict( The state dict we are copying into. This state_dict must have exactly the same structure as the source `state_dict`. non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. Returns: State Dict copy @@ -367,7 +372,7 @@ def _copy_state_dict( cpu_offload=False, ranks_only=tuple(), companion_obj=copy_state_dict, - type_check=True, + type_check=type_check, non_blocking=non_blocking, ) @@ -509,7 +514,11 @@ def _broadcast_tensors( if pg is None: pg = dist.distributed_c10d._get_default_group() - dist._broadcast_coalesced(pg, tensors, 500, 0) + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) for key in keys: _local_state = local_state_dict.get(key, None) @@ -528,9 +537,11 @@ def _broadcast_state_dict( local_state_dict: Dict[str, Any], device: torch.device, pg: Optional[dist.ProcessGroup] = None, + strict: bool = False, ) -> None: - # Gather the full state dict keys, non tensor values, scalar tensor values, - # and tensor information. + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. ret = {} if dist.get_rank() == 0: for key, value in full_state_dict.items(): @@ -547,7 +558,10 @@ def _broadcast_state_dict( # Gather values keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() for key, value in ret.items(): + global_keys.add(key) if not isinstance(value, _TensorInfo): if key in local_state_dict: local_state_dict[key] = value @@ -557,11 +571,16 @@ def _broadcast_state_dict( ret[key] = full_state_dict[key] keys.append(key) - # Broadcast every 10 tensors, just hardcode the number for now - if len(keys) >= 10: + # Broadcast every tensor to avoid OOM for now. 
+ if len(keys) >= 1: _broadcast_tensors(ret, local_state_dict, keys, device, pg) keys.clear() + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + if keys: _broadcast_tensors(ret, local_state_dict, keys, device, pg) diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 6ab35e10a69f..85de716d8439 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Optional, Sequence @@ -52,6 +53,8 @@ def _dtensor_init_helper( placements=None, **kwargs, ) -> DTensor: + from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + # if device_mesh is None, use the one from mesh resources device_mesh = device_mesh or _mesh_resources.get_current_mesh() kwargs["device"] = device_mesh.device_type @@ -77,8 +80,6 @@ def _dtensor_init_helper( # this tensor meta is not used except `shape` dtype = kwargs.get("dtype", torch.get_default_dtype()) - from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - tensor_meta = TensorMeta(size, (0,), dtype) spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) @@ -91,13 +92,19 @@ def _dtensor_init_helper( else: local_tensor = init_op(local_shape, **kwargs) + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + return DTensor( - local_tensor=local_tensor, - device_mesh=device_mesh, - placements=tuple(placements), - shape=size, - dtype=local_tensor.dtype, - stride=torch_stride, + local_tensor, + spec, requires_grad=kwargs["requires_grad"], ) diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py index 93052d6ddd62..4c1d18403666 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import math from dataclasses import dataclass diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 17f565b6d776..1739243a5d3b 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -395,16 +395,7 @@ def wrap(res: object, spec: OutputSpecType) -> object: assert isinstance( spec, DTensorSpec ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." - assert spec.tensor_meta is not None - return dtensor.DTensor( - res, - spec.mesh, - spec.placements, - shape=spec.tensor_meta.shape, - dtype=spec.tensor_meta.dtype, - requires_grad=res.requires_grad, - stride=spec.tensor_meta.stride, - ) + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) else: # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor assert res.ndim == 0, "output tensor should be scalar!" diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 85c14746ce13..071c2ac4748f 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass from functools import cached_property from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -238,15 +239,24 @@ def args_spec(self) -> Tuple[DTensorSpec, ...]: with NO non-DTensor positional arguments (i.e. 
int/float/tuple, etc) mainly used by sharding propagation to propagate the output spec """ - # filter out non-relevant values from args schema to get a clean spec list - # this would mainly be used by sharding propagation rules - if self.schema_info is not None and self.schema_info.needs_pytree: - return tuple( - item - for item in tree_leaves(self.args_schema) - if isinstance(item, DTensorSpec) - ) - return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec)) + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, DTensorSpec)) + + @property + def args_strategy(self) -> Tuple[OpStrategy, ...]: + # filter out non-relevant values from args schema to get a clean OpStrategy list + # separate with args_spec for the ease of type annotation + # TODO: see if we should merge this with args_spec + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, OpStrategy)) def __repr__(self) -> str: args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/_tensor/_redistribute.py index 5cef7dbb047c..2653423a257f 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/_tensor/_redistribute.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from functools import lru_cache from typing import cast, Dict, List, NamedTuple, Tuple @@ -7,11 +8,12 @@ import torch.distributed._tensor.api as dtensor from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, + TensorMeta, ) @@ -177,7 +179,7 @@ def redistribute_local_tensor( if target.is_replicate(): # Case 1: target is Replicate if current.is_partial(): - partial_spec = cast(_Partial, current) + partial_spec = cast(Partial, current) new_local_tensor = partial_spec._reduce_value( local_tensor, device_mesh, i ) @@ -195,7 +197,7 @@ def redistribute_local_tensor( target_placement = cast(Shard, target) target_dim = target_placement.dim if current.is_partial(): - partial_spec = cast(_Partial, current) + partial_spec = cast(Partial, current) new_local_tensor = partial_spec._reduce_shard_value( local_tensor, device_mesh, i, target_placement ) @@ -219,7 +221,7 @@ def redistribute_local_tensor( ) elif target.is_partial(): if current.is_replicate(): - partial_spec = cast(_Partial, target) + partial_spec = cast(Partial, target) # skip the replicate to partial transformation when we are in backward pass # In this case we keep the grad as replicate, this is because we don't # want to convert the replicated gradients back to partial, although @@ -283,15 +285,12 @@ def forward( # type: ignore[override] else: # use the same local tensor if placements are the same. 
output = input._local_tensor + target_spec = current_spec return dtensor.DTensor( output, - device_mesh, - placements, - shape=input.shape, - dtype=input.dtype, + target_spec, requires_grad=input.requires_grad, - stride=input.stride(), ) @staticmethod @@ -316,14 +315,20 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] normalized_placements.append(Replicate()) else: normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=grad_output.dtype, + ), + ) output_dtensor = dtensor.DTensor( output, - previous_spec.mesh, - tuple(normalized_placements), - shape=grad_output.shape, - dtype=grad_output.dtype, + spec, requires_grad=grad_output.requires_grad, - stride=grad_output.stride(), ) return ( diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 314ef87193eb..449cf6c23775 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union @@ -88,7 +89,7 @@ def register_op_strategy( if schema_info is not None: self.op_to_schema_info[op_overload] = schema_info - @lru_cache + @lru_cache # noqa: B019 def _propagate_tensor_meta( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/_tensor/_tp_conv.py index ebcc981d2c93..d480e9d7f79e 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/_tensor/_tp_conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates # implement matrix related ops for distributed tensor from typing import cast, Dict, List, Tuple diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/_tensor/_utils.py index 08c381dd3d1d..a3cc8ee5a602 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/_tensor/_utils.py @@ -4,8 +4,8 @@ import torch.distributed._tensor.api as dtensor from torch._prims_common import ShapeType from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -178,7 +178,7 @@ def compute_global_tensor_info( if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]: # rescale the stride by the shard size tensor_stride[i] = tensor_stride[i] * mesh_dim_size - elif not isinstance(placement, (Replicate, _Partial)): + elif not isinstance(placement, (Replicate, Partial)): raise RuntimeError(f"placement type {type(placement)} not supported!") return tensor_shape, tensor_stride diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index c0c0e1470df5..7da5f4e3dfcb 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates import inspect import warnings @@ -15,8 +16,8 @@ ) from torch.distributed._tensor._utils import compute_global_tensor_info from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -86,16 +87,21 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) tensor_stride = tuple(tensor_stride) grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) return ( DTensor( grad_output, - mesh, - grad_placements, - shape=dtensor_meta.shape, - dtype=dtensor_meta.dtype, + grad_spec, requires_grad=grad_output.requires_grad, - stride=tensor_stride, ), None, ) @@ -146,17 +152,23 @@ def forward( # type: ignore[override] input = input.contiguous() mesh_broadcast(input, device_mesh, mesh_dim=idx) + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + # We want a fresh Tensor object that shares memory with the input tensor dist_tensor = DTensor( input.view_as(input), - device_mesh, - placements, - shape=tensor_shape, - dtype=input.dtype, + dist_spec, # requires_grad of the dist tensor depends on if input # requires_grad or not requires_grad=input.requires_grad, - stride=tensor_stride, ) return dist_tensor @@ -202,13 +214,9 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ def __new__( cls, local_tensor: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], + spec: DTensorSpec, *, - shape: torch.Size, - dtype: torch.dtype, requires_grad: bool, - stride: Tuple[int, ...], ) -> "DTensor": """ Construct a DTensor from a local tensor, device mesh, and placement and @@ -228,19 +236,18 @@ def __new__( # new method instruct wrapper tensor from local_tensor and add # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, - shape, - strides=stride, - dtype=dtype, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, device=local_tensor.device, layout=local_tensor.layout, requires_grad=requires_grad, ) - tensor_meta = TensorMeta(shape, stride, dtype) - # deepcopy and set spec - r._spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta) + r._spec = spec r._local_tensor = local_tensor return r @@ -264,21 +271,27 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ), "Expecting spec to be not None from `__tensor_flatten__` return value!" 
local_tensor = inner_tensors["_local_tensor"] spec, requires_grad = flatten_spec - return DTensor( - local_tensor, - spec.mesh, - spec.placements, + unflatten_tensor_meta = TensorMeta( shape=outer_size, + stride=outer_stride, dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, requires_grad=requires_grad, - stride=outer_stride, ) def __coerce_tangent_metadata__(self): - if not any(isinstance(p, _Partial) for p in self.placements): + if not any(isinstance(p, Partial) for p in self.placements): return self placements = [ - Replicate() if isinstance(p, _Partial) else p for p in self.placements + Replicate() if isinstance(p, Partial) else p for p in self.placements ] return self.redistribute(device_mesh=self.device_mesh, placements=placements) @@ -406,6 +419,9 @@ def to_local( .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned will depend on if the `DTensor` requires_grad or not. """ + if not torch.is_grad_enabled(): + return self._local_tensor + if grad_placements is not None and not isinstance(grad_placements, tuple): grad_placements = tuple(grad_placements) return _ToTorchTensor.apply( @@ -456,7 +472,7 @@ def redistribute( for i, placement in enumerate(placements): if placement.is_partial(): raise RuntimeError( - "Can not redistribute to _Partial, _Partial is for internal use only!" + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" ) elif isinstance(placement, Shard) and placement.dim < 0: # normalize shard dim to be positive @@ -638,14 +654,19 @@ def distribute_tensor( assert local_tensor is not None, "distributing a tensor should not be None" # detach the local tensor passed to DTensor since after the construction # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) return DTensor( local_tensor.requires_grad_(tensor.requires_grad), - device_mesh, - placements, - shape=tensor.size(), - dtype=tensor.dtype, + spec, requires_grad=tensor.requires_grad, - stride=tensor.stride(), ) @@ -747,6 +768,7 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: "Deprecating input_fn that takes two arguments (inputs, device_mesh), " "please use input_fn that takes in (module, inputs, device_mesh) instead!", FutureWarning, + stacklevel=2, ) module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] elif num_args == 3: @@ -767,6 +789,7 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: "Deprecating output_fn that takes two arguments (inputs, device_mesh), " "please use output_fn that takes in (module, inputs, device_mesh) instead!", FutureWarning, + stacklevel=2, ) module.register_forward_hook( lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index 2cd388cf93e4..b7bde685fd1e 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.debug.comm_mode import CommDebugMode diff --git 
a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/_tensor/debug/_op_coverage.py index a722136e2baf..4f5424633235 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/_tensor/debug/_op_coverage.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from operator import itemgetter from typing import List diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 62e10a160384..cc28498d766c 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -1,10 +1,19 @@ +# mypy: allow-untyped-defs from collections import defaultdict from typing import Any, Dict import torch +from torch.autograd.graph import register_multi_grad_hook from torch.distributed._tensor.api import DTensor + +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten +from torch.utils.module_tracker import ModuleTracker funcol_native = torch.ops._c10d_functional funcol_py = torch.ops.c10d_functional @@ -25,17 +34,71 @@ } c10d_collective_ops = { - c10d_ops.allreduce_, c10d_ops._allgather_base_, c10d_ops._reduce_scatter_base_, + c10d_ops.allgather_, + c10d_ops.allgather_coalesced_, + c10d_ops.allgather_into_tensor_coalesced_, + c10d_ops.allreduce_, c10d_ops.allreduce_coalesced_, + c10d_ops.alltoall_, + c10d_ops.alltoall_base_, c10d_ops.broadcast_, c10d_ops.gather_, c10d_ops.scatter_, c10d_ops.reduce_, + c10d_ops.reduce_scatter_, + c10d_ops.reduce_scatter_tensor_coalesced_, } +class ModuleParamaterShardingTracker(ModuleTracker): + """ + Inherits ModuleTracker and expands on its functionality to track the + parameters and sharding information of a model at a module-level + """ + + def __init__(self): + super().__init__() + self.module_parameters_dict = {} + self.sharding_dict = {} + + def _fw_pre_hook(self, mod, input): + name = super()._get_mod_name(mod) + super()._get_append_fn(name, False)() + + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if tensors: + register_multi_grad_hook(tensors, super()._get_pop_fn(name, True)) + + for param_name, param in mod.named_parameters(recurse=False): + if name not in self.module_parameters_dict: + self.module_parameters_dict[name] = {} + + self.module_parameters_dict[name][param_name] = param.data + + if isinstance(param.data, DTensor): + key_name = name + "." 
+ param_name + self.sharding_dict[key_name] = param.data.placements + + def __enter__(self): + self.module_parameters_dict.clear() + self.sharding_dict.clear() + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook(super()._fw_post_hook) + + def __exit__(self, *args): + super().__exit__(*args) + + def print_paramater_info(self): + print(self.module_parameters_dict) + + def print_sharding_info(self): + for key, value in self.sharding_dict.items(): + print(key + ": " + str(value)) + + class CommDebugMode(TorchDispatchMode): """ ``CommDebugMode`` is a context manager that counts the number of @@ -64,6 +127,7 @@ def __init__(self): self.comm_registry.add(py_op) self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) + self.advanced_module_tracker = ModuleParamaterShardingTracker() def get_total_counts(self) -> int: return sum(self.comm_counts.values()) @@ -76,14 +140,28 @@ def get_comm_counts(self) -> Dict[Any, int]: """ return self.comm_counts + def get_parameter_info(self) -> Dict[str, Dict[str, Any]]: + return self.advanced_module_tracker.module_parameters_dict + + def get_sharding_info(self) -> Dict[str, Dict[str, Any]]: + return self.advanced_module_tracker.sharding_dict + def __enter__(self): self.comm_counts.clear() super().__enter__() + self.advanced_module_tracker.__enter__() return self def __exit__(self, *args): + self.advanced_module_tracker.__exit__() super().__exit__(*args) + def print_paramater_info(self): + self.advanced_module_tracker.print_paramater_info() + + def print_sharding_info(self): + self.advanced_module_tracker.print_sharding_info() + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # When running this mode with DTensor, ordinarily all modes will # run **before** subclasses get a chance to run. 
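For orientation, a short hedged usage sketch of the CommDebugMode additions above; `model` and `inp` are placeholders for any (optionally DTensor-parallelized) module and its input, and only the accessors introduced in this hunk plus the pre-existing counters are assumed.

import torch
from torch.distributed._tensor.debug import CommDebugMode

def report_comms_and_sharding(model: torch.nn.Module, inp: torch.Tensor) -> None:
    # Count collectives and collect per-module parameter/sharding info in one pass.
    comm_mode = CommDebugMode()
    with comm_mode:
        model(inp).sum().backward()
    print(comm_mode.get_total_counts())    # total number of collectives observed
    print(comm_mode.get_comm_counts())     # per-op breakdown, e.g. {all_reduce: 2, ...}
    print(comm_mode.get_parameter_info())  # module FQN -> {param name: local data}
    comm_mode.print_sharding_info()        # "module.param" -> placements, for DTensor params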
diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 91bc9c2a382c..76cd8f3e9208 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Sequence, Tuple import numpy as np diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index 9bccc07d9625..1cb292f12c41 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example contains a simple MLP model that uses different DTensor layouts, and use the checkpointing API to diff --git a/torch/distributed/_tensor/examples/convnext_example.py b/torch/distributed/_tensor/examples/convnext_example.py index df6b7d3d71fd..61f8d0234938 100644 --- a/torch/distributed/_tensor/examples/convnext_example.py +++ b/torch/distributed/_tensor/examples/convnext_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example demonstrates how to train a ConvNeXt model with intermediate activations sharded across mutliple GPUs via DTensor diff --git a/torch/distributed/_tensor/examples/display_sharding_example.py b/torch/distributed/_tensor/examples/display_sharding_example.py new file mode 100644 index 000000000000..4a0eb113e9c3 --- /dev/null +++ b/torch/distributed/_tensor/examples/display_sharding_example.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict + +import torch + +from torch.distributed._tensor import DeviceMesh, Shard +from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed._tensor.debug.comm_mode import ModuleParamaterShardingTracker + +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + MLPModule, + MLPStacked, + NUM_DEVICES, +) + + +def get_device_type(): + return ( + "cuda" + if torch.cuda.is_available() and torch.cuda.device_count() >= 4 + else "cpu" + ) + + +c10d_functional = torch.ops.c10d_functional + +aten = torch.ops.aten +supported_ops = [aten.view.default, aten._to_copy.default] + + +class DisplayShardingExample: + """ + Checks if the set of keys in ground truth dictionary and the set + produced in advanced_module_tracker are in the same order + """ + + def __init__(self, world_size, rank): + self.world_size = world_size + self.rank = rank + self.device_type = get_device_type() + + def same_set_of_keys(self, dict1, dict2): + dict1_keys = [] + dict2_keys = [] + + for key in dict1: + for nested_key in dict1[key]: + dict1_keys.append((key, nested_key)) + + for key in dict2: + for nested_key in dict2[key]: + dict2_keys.append((key, nested_key)) + + if len(dict1_keys) != len(dict2_keys): + return False + + for i in range(len(dict1_keys)): + if dict1_keys[i] != dict2_keys[i]: + return False + + return True + + def ground_truth(self, model): + module_parameters_dict: Dict[str, Any] = {} + + for name, parameters in model.named_parameters(): + module_name = model.__class__.__name__ + "." 
+ name.rsplit(".", 1)[0] + parameter_name = name.rsplit(".", 1)[1] + + if module_name not in module_parameters_dict: + module_parameters_dict[module_name] = {} + + module_parameters_dict[module_name][parameter_name] = parameters.data + + return module_parameters_dict + + def test_display_parameters_MLP(self): + """Example of obtaining all module's FQN and parameters for a given model""" + + inp_size = [8, 10] + + rng_seed = 0 + torch.manual_seed(rng_seed) + inp = torch.rand(*inp_size) + model = MLPModule(None) + + LR = 0.25 + + comm_mode = CommDebugMode() + module_tracker = ModuleParamaterShardingTracker() + + with comm_mode, module_tracker: + output = model(inp) + output.sum().backward() + + print( + self.same_set_of_keys( + self.ground_truth(model), module_tracker.module_parameters_dict + ) + ) + + model2 = MLPStacked(None) + with comm_mode, module_tracker: + output = model2(inp) + + print( + self.same_set_of_keys( + self.ground_truth(model2), module_tracker.module_parameters_dict + ) + ) + + def test_display_parameters_MLP_distributed( + self, is_seq_parallel=False, recompute_activation=False + ): + "Example of obtaining all module's FQN and parameters for a given distributed model and printing the sharding info" + device_mesh = DeviceMesh( + self.device_type, + torch.arange(0, NUM_DEVICES), + ) + inp_size = [8, 10] + rng_seed = self.rank if is_seq_parallel else 0 + torch.manual_seed(rng_seed) + inp = torch.rand(*inp_size, device=self.device_type) + model = MLPModule(self.device_type) + + LR = 0.25 + + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(0)) + if is_seq_parallel + else ColwiseParallel(), + "net2": RowwiseParallel(output_layouts=Shard(0)) + if is_seq_parallel + else RowwiseParallel(), + } + + model = parallelize_module(model, device_mesh, parallelize_plan) + + comm_mode = CommDebugMode() + + with comm_mode: + output_tp = model(inp) + output_tp.sum().backward() + + print( + self.same_set_of_keys( + self.ground_truth(model), comm_mode.get_parameter_info() + ) + ) + + comm_mode.print_sharding_info() + + +def run_example(world_size, rank): + # set manual seed + torch.manual_seed(0) + + # run the example + instantiated_test = DisplayShardingExample(world_size, rank) + instantiated_test.test_display_parameters_MLP_distributed() + + +if __name__ == "__main__": + # this script is launched via torchrun which automatically manages ProcessGroup + import os + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + assert world_size == 4 # our example uses 4 worker ranks + + run_example(world_size, rank) diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/_tensor/examples/torchrec_sharding_example.py index 8edbad13301f..3e6c63dd18eb 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/_tensor/examples/torchrec_sharding_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ The following example demonstrates how to represent torchrec's embedding sharding with the DTensor API. diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py index 587eef3011ba..2dd21605ffcc 100644 --- a/torch/distributed/_tensor/experimental/__init__.py +++ b/torch/distributed/_tensor/experimental/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from contextlib import contextmanager diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 002ff5542a11..0fc6ce96e6e0 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -1,7 +1,9 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Callable, Optional, Sequence, Tuple, Union import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement @@ -12,7 +14,7 @@ PlacementType = Optional[Sequence[Placement]] -InputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] +InputPlacements = Optional[Tuple[PlacementType, ...]] OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] @@ -32,24 +34,36 @@ def local_map( func (Callable): the function to be applied on each local shard of :class:`DTensor`s. out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): - the desired placements of the output :class:`DTensor`s. If the `output` of - `func` is a Python collection, the `out_placements` will be a Tuple of - `PlacementType` values 1:1 mapping to the flattened `output`. For - :class:`Tensor` output, the corresponding `PlacementType` will be its + the desired placements of the :class:`DTensor`s in `func`'s flattened output. + If the flattened `output` is a single value, the `out_placements` should be + of type `PlacementType`. Otherwise if the flattened `output` has multiple + values, the `out_placements` should be a tuple of `PlacementType` values 1:1 + mapping to the flattened `output`. + Besides, for :class:`Tensor` output, we use `PlacementType` as its placements (a `Tuple[Placement]` value). For non-:class:`Tensor` output, - the `PlacementType` will be `None`. - in_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]], optional): - the required placements of the input :class:`DTensor`s. If not specified, - the input :class:`DTensor` will not be redistributed before passing its local - tensor to `func`. Similarly to `out_placements`, `in_placements` should keep - a 1:1 mapping to the flattened input of `func`. If a redistribution is - required according to `in_placements` and `redistribute_inputs` is `False`, - an exception will be raised. + the `PlacementType` should be `None`. + Note that the only exception is when no :class:`DTensor` argument is passed + in. In this case, even if `out_placements` is not `None`, the result function + should ignore the desired placements because the application is not on + :class:`DTensors`. + in_placements (Tuple[`PlacementType`, ...], optional): + the required placements of the :class:`DTensor`s in `func`'s flattened input. + If `in_placements` is specified, `local_map` would examine whether the + placements of each :class:`DTensor` argument is the same as the required + placements or not. If the placements are not the same and + `redistribute_inputs` is `False`, an exception will be raised. Otherwise if + `redistribute_inputs` is `True`, the argument will be first redistributed to + the required sharding placements before passing its local tensor to `func`. + The only exception is when required placements are not `None` and the + argument is a :class:`torch.Tensor`. In this case, the placements examination + will be skipped and the argument will be directly passed to `func`. 
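As a hedged illustration of the `in_placements` contract described here: the function, placements, and DTensor arguments below are invented for the sketch, and only the 1:1 mapping between `in_placements` entries and the flattened inputs is taken from this docstring.

# Illustrative only; `w_dt` / `x_dt` would be DTensors produced elsewhere.
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental.local_map import local_map

def local_matmul(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # runs on the plain local shards, never sees a DTensor
    return torch.mm(w, x)

dist_matmul = local_map(
    local_matmul,
    out_placements=[Shard(0)],                    # single flattened output
    in_placements=((Shard(0),), (Replicate(),)),  # one PlacementType per flattened input
    redistribute_inputs=True,                     # reshard args whose placements mismatch
)
# dist_matmul(w_dt, x_dt): each DTensor argument is checked (and, if needed, redistributed)
# against its in_placements entry, its local tensor is passed to local_matmul, and the
# tensor output is re-wrapped as a DTensor with out_placements.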
+ If `in_placements` is `None`, no placements examination will be performed. + Default: `None` device_mesh (:class:`DeviceMesh`, optional): the device mesh that all the :class:`DTensor`s are placed on. If not specified, this will be inferred from the input :class:`DTensor`s' device mesh. `local_map` requires every :class:`DTensor`s to be placed on the same - device mesh. + device mesh. Default: `None`. redistribute_inputs (bool, optional): the bool value indicating whether to reshard the input :class:`DTensor`s when their placements are different from the required input placements. If this @@ -93,9 +107,9 @@ def local_map( >>> device_mesh=device_mesh, >>> ) >>> - >>> W_dt = distribute_tensor(W, device_mesh, col_wise) # col-wisely sharded W tensor - >>> X_dt = distribute_tensor(X, device_mesh, row_wise) # row-wisely sharded X tensor - >>> Y_dt = local_mm_allreduce_forward(W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors + >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors NOTE: This API is currently experimental and subject to change """ @@ -103,10 +117,16 @@ def local_map( def wrapped(*args, **kwargs): # process input args flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" + ) # we assume every DTensor object is placed on the same device mesh flat_local_args = [] nonlocal device_mesh # access var device_mesh from the outer scope + seen_dtensor_arg = False for idx, arg in enumerate(flat_args): if isinstance(arg, DTensor): # TODO: the current code doesn't consider the uneven sharding case @@ -115,17 +135,16 @@ def wrapped(*args, **kwargs): if device_mesh is None: # infer device mesh from the DTensor arg device_mesh = arg.device_mesh + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + assert arg.device_mesh == device_mesh, ( - f"arg {arg} in local_map has a mismatched device mesh:" - f"{arg} has device mesh {arg.device_mesh} while" + f"arg {arg} in local_map has a mismatched device mesh: " + f"{arg} has device mesh {arg.device_mesh} while " f"the expected device mesh is {device_mesh}!" ) if in_placements is not None: - spec = ( - in_placements[idx] - if isinstance(in_placements, tuple) - else in_placements - ) + spec = in_placements[idx] assert ( spec is not None ), f"DTensor input {arg} expects placements but received {spec}!" @@ -139,44 +158,62 @@ def wrapped(*args, **kwargs): arg = arg.redistribute(device_mesh, spec) else: raise ValueError( - f"arg {arg} in local_map has a mismatched placements:" - f"arg placements is {arg.placements} but the input" - f"placements is {spec}!" - "If redistribute_inputs is wanted, set redistribute_inputs=True to local_map." + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." 
) - flat_local_args.append(arg.to_local()) + local_arg = arg.to_local() + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + flat_local_args.append(arg) local_args = pytree.tree_unflatten(flat_local_args, args_spec) - out = func(device_mesh, *local_args, **kwargs) + out = func(*local_args, **kwargs) - # process output - flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out = [] - for idx, out in enumerate(flat_out): - spec = ( - out_placements[idx] - if isinstance(out_placements, tuple) - else out_placements - ) - if isinstance(out, torch.Tensor): - assert not isinstance( - out, DTensor - ), f"torch.Tensor output expected but received {type(out)}: {out}" + if seen_dtensor_arg: + # process output + flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out.append( - DTensor.from_local(out, device_mesh, spec, run_check=False) + flat_dist_out = [] + for idx, out in enumerate(flat_out): + spec = ( + out_placements[idx] + if isinstance(out_placements, tuple) + else out_placements ) - else: - assert ( - spec is None - ), f"Non-tensor output {out} expects None placements but received {spec}!" - flat_dist_out.append(out) + if isinstance(out, torch.Tensor): + assert not isinstance( + out, DTensor + ), f"torch.Tensor output expected but received {type(out)}: {out}" + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert ( + spec is None + ), f"Non-tensor output {out} expects None placements but received {spec}!" 
+ + flat_dist_out.append(out) - return pytree.tree_unflatten(flat_dist_out, out_spec) + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out return wrapped diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/_tensor/experimental/tp_transform.py index b36f3d87e3d8..4a18d36bbc64 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/_tensor/experimental/tp_transform.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import operator from typing import Any, cast, Dict, List, Optional, Sequence, Tuple diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py index 6274be44cd67..cc28cc19d370 100644 --- a/torch/distributed/_tensor/ops/basic_strategy.py +++ b/torch/distributed/_tensor/ops/basic_strategy.py @@ -5,8 +5,8 @@ from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -126,7 +126,7 @@ def gen_einsum_strategies( # split contracting dim for contracting_dim in edims.contracting_dims: - placement_list = [_Partial()] + placement_list = [Partial()] for input_dim in input_dims: input_contracting_dim = input_dim.index(contracting_dim) placement_list.append(Shard(input_contracting_dim)) @@ -157,9 +157,9 @@ def gen_einsum_strategies( # linearity strategy if linearity: - linearity_placement_list: List[Placement] = [_Partial()] + linearity_placement_list: List[Placement] = [Partial()] for input_dim in input_dims: - linearity_placement_list.append(_Partial()) + linearity_placement_list.append(Partial()) mesh_dim_strategies.append(linearity_placement_list) all_mesh_dim_strategies.append(mesh_dim_strategies) diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index e79bdd13cd8c..6f8cc8c67851 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -1,26 +1,19 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates # implement matrix related ops for distributed tensor -import itertools from dataclasses import dataclass, field from typing import cast, List, Optional import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) +from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, StrategyType from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, - is_tensor_shardable, + expand_to_full_mesh_op_strategy, register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - _Partial, - DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -42,7 +35,7 @@ def materialize_mask(self, mask): def release_mask(self): # TODO: evaluate if we need to release the mask buffer or the buffer - # can just have the same lifetime as the _Partial placement + # can just have the same lifetime as the Partial placement if self.data is None: raise RuntimeError("MaskBuffer has not been materialized") self.data = None @@ -62,7 +55,7 @@ def apply_mask(self, tensor): @dataclass(frozen=True) -class _MaskPartial(_Partial): +class _MaskPartial(Partial): """ A partial mask placement devised for rowwise sharded embedding op, where we need to mask and adjust the indices to the local embedding shard, embedding masking @@ -182,64 +175,35 @@ def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: indices_shape = indices_strategy.shape output_emd_dim = len(indices_shape) - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate - colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial - embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) - - # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates - # from the input indices and use it for output reduction - rowwise_sharding = [ - embedding_partial_placement, - Shard(0), - embedding_partial_placement, - ] - single_mesh_dim_strategies.append(rowwise_sharding) - - # batch dim sharding, weight replicated, input can shard on any dim, output follows input - for input_dim in range(len(indices_shape)): - batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(weight_shape, spec_list[1]) and is_tensor_shardable( - indices_shape, spec_list[2] - ): - # only add to the strategy list when both weight and indices are shardable - weight_spec, indices_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(weight_strategy, weight_spec), - 
generate_redistribute_costs(indices_strategy, indices_spec), - ] - strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0]) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) @register_op_strategy(aten.embedding_dense_backward.default) @@ -257,55 +221,26 @@ def embedding_dense_backward_strategy( indices_shape = indices_strategy.shape grad_out_ndim = len(grad_out_shape) - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding backward, grad_out shard on last dim, input replicate, - # weight grad shard colwise - colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # batch dim sharding, weight replicated, grad_out/input have same sharding - # that can shard on any dim, weight grad partial - for input_dim in range(len(indices_shape)): - batch_sharding = [_Partial(), Shard(input_dim), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) - - # grad_out partial, input replicate, weight grad keep partial - partial_sharding = [_Partial(), _Partial(), Replicate()] - single_mesh_dim_strategies.append(partial_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(grad_out_shape, spec_list[1]) and is_tensor_shardable( - indices_shape, spec_list[2] - ): - # only add to the strategy list when both grad_out and indices are shardable - grad_out_spec, indices_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(grad_out_strategy, grad_out_spec), - generate_redistribute_costs(indices_strategy, indices_spec), - ] - 
strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 9a02f798f8ac..377c50dffa13 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -1,5 +1,5 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates -import itertools import math from dataclasses import dataclass from enum import Enum @@ -16,17 +16,17 @@ ) from torch.distributed._tensor.ops.utils import ( as_list, + expand_to_full_mesh_op_strategy, generate_redistribute_costs, is_tensor_evenly_shardable, - is_tensor_shardable, normalize_dim, normalize_dims, normalize_to_torch_size, register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -52,7 +52,7 @@ class NormReduction: @dataclass(frozen=True) -class _NormPartial(_Partial): +class _NormPartial(Partial): """ This placement is used for partial vector norm. 
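Since these files move from the private `_Partial` to the public `Partial` placement, here is a hedged sketch of its semantics; the 1-D CPU mesh and tensor values are assumptions for illustration (run under torchrun with an initialized process group), not part of the patch.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate

# Partial() marks a DTensor whose per-rank local tensors are unreduced addends;
# redistributing to Replicate() performs the pending all-reduce (contract #1 below).
mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))
local_addend = torch.ones(4)  # every rank contributes a tensor of ones
partial_dt = DTensor.from_local(local_addend, mesh, [Partial()], run_check=False)

replicated = partial_dt.redistribute(mesh, [Replicate()])  # all_reduce(sum) happens here
# replicated.to_local() == torch.full((4,), float(dist.get_world_size())) on every rank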
@@ -229,7 +229,7 @@ def map_placements_after_reduction( """ new_placements: List[Placement] = [] for placement in placements: - if isinstance(placement, (Replicate, _Partial)): + if isinstance(placement, (Replicate, Partial)): new_placements.append(placement) else: assert isinstance(placement, Shard) @@ -247,7 +247,7 @@ def map_placements_after_reduction( def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: if isinstance(reduction_op, NormReduction): return _NormPartial(norm_type=reduction_op.norm_type) - return _Partial(reduction_op) + return Partial(reduction_op) def common_reduction_strategy( @@ -1021,44 +1021,20 @@ def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: ) topk_dim = normalize_dim(topk_dim, input_strategy.ndim) - all_mesh_dim_strategies = [] + single_mesh_dim_strategies = [] - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] + # two outputs (values, indices), 1 input + # replicate always works + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) - # two outputs (values, indices), 1 input - # replicate always works - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) + # every dim except topk dim should work + for dim in range(input_strategy.ndim): + if dim != topk_dim: + dim_shardings: List[Placement] = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + # TODO: topk on sharded dim requries non-trival reduction, address it later - # every dim except topk dim should work - for dim in range(input_strategy.ndim): - if dim != topk_dim: - dim_shardings: List[Placement] = [Shard(dim)] * 3 - single_mesh_dim_strategies.append(dim_shardings) - - # TODO: topk on sharded dim requries non-trival reduction, address it later - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - input_spec = spec_list[2] - if is_tensor_shardable(input_shape, input_spec): - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - ] - strategy = PlacementStrategy( - output_specs=tuple(spec_list[:2]), - input_specs=(input_spec,), - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strategy) - - return OpStrategy(all_strategies) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index 4a6fc0458119..ab80f783cf5b 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -22,8 +22,8 @@ register_op_strategy, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -460,7 +460,7 @@ def common_pointwise_strategy( common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim out_placements.append(Shard(new_shard_dim)) - elif isinstance(placement, _Partial) and not linearity: + elif isinstance(placement, Partial) and not linearity: # clear the partial placemnet if op does not support linearity # by default we just replicate the partial, need to see if this # is optimal for all cases diff --git 
a/torch/distributed/_tensor/ops/random_ops.py b/torch/distributed/_tensor/ops/random_ops.py index 3f33d16cc152..390dc419ecd7 100644 --- a/torch/distributed/_tensor/ops/random_ops.py +++ b/torch/distributed/_tensor/ops/random_ops.py @@ -24,7 +24,7 @@ def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: arg_spec = arg_strategy.output_spec if is_tensor_partial(arg_spec): # TODO: figure out how inplace random op should behave when it's partial - raise RuntimeError(f"{op_schema.op} with _Partial is not supported yet!") + raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec)) return random_strategy diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index 54a607d58c55..d2feb19ba2f9 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -1,10 +1,11 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates -import itertools from typing import cast, List, Optional, Sequence, Tuple import torch from torch.distributed._tensor._op_schema import ( + _is_inplace_op, OpSchema, OpStrategy, OutputSharding, @@ -16,18 +17,17 @@ from torch.distributed._tensor.ops.common_rules import pointwise_rule from torch.distributed._tensor.ops.embedding_ops import _MaskPartial from torch.distributed._tensor.ops.utils import ( - generate_redistribute_costs, + expand_to_full_mesh_op_strategy, is_tensor_dim_sharded, is_tensor_evenly_shardable, is_tensor_partial, - is_tensor_shardable, normalize_dim, register_op_strategy, register_prop_rule, ) from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -103,7 +103,7 @@ def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: output_spec = DTensorSpec( mesh=arg_spec.mesh, placements=tuple( - Replicate() if isinstance(p, _Partial) else p + Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements ), ) @@ -154,7 +154,7 @@ def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: output_spec = DTensorSpec( mesh=arg_spec.mesh, placements=tuple( - Replicate() if isinstance(p, _Partial) else p + Replicate() if isinstance(p, Partial) else p for p in arg_spec.placements ), ) @@ -361,6 +361,32 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType return OpStrategy([PlacementStrategy(replicate_spec)]) +@register_op_strategy( + [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], + schema_info=RuntimeSchemaInfo(1), +) +def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output + if len(op_schema.args_strategy) < 3: + # scatter_.src/scatter.src with src be float number instead of tensor + all_replicate: List[Placement] = [Replicate()] * 3 + else: + all_replicate = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + # TODO: see if we can support input sharding pattern + inplace_op = _is_inplace_op(op_schema.op) + + op_strategy = expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op + ) + return op_strategy + + @register_op_strategy(aten.gather.default) def 
gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_strategy = cast(OpStrategy, op_schema.args_schema[0]) @@ -370,59 +396,33 @@ def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: input_shape = input_strategy.shape index_shape = index_strategy.shape - all_mesh_dim_strategies = [] - - for mesh_dim in range(mesh.ndim): - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index] - # first we always have replicate all for inputs and output - all_replicate: List[Placement] = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # input sharding, input sharded, index accepts mask partial, output follows index - # this only works when the input is sharded on the gather dimension, and - # index has size 1 on the gather dimension - if index_shape[dim] == 1: - index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) - input_sharding = [ - index_partial_placement, - Shard(dim), - index_partial_placement, - ] - single_mesh_dim_strategies.append(input_sharding) - - # index sharding, input replicated, index sharded, output follows index - # this only works when the sharding dimension is the gather dimension - index_sharding = [Shard(dim), Replicate(), Shard(dim)] - single_mesh_dim_strategies.append(index_sharding) - - all_mesh_dim_strategies.append(single_mesh_dim_strategies) - - strategy_combs = itertools.product(*all_mesh_dim_strategies) - - all_strategies = [] - for strategy_comb in strategy_combs: - spec_list = [] - for specs in zip(*strategy_comb): - spec_list.append(DTensorSpec(mesh, tuple(specs))) - - if is_tensor_shardable(input_shape, spec_list[1]) and is_tensor_shardable( - index_shape, spec_list[2] - ): - input_spec, index_spec = spec_list[1:] - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec), - generate_redistribute_costs(index_strategy, index_spec), - ] - strat = PlacementStrategy( - output_specs=spec_list[0], - input_specs=spec_list[1:], - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strat) - - return OpStrategy(all_strategies) + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: List[Placement] = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim]) + input_sharding = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding = [Shard(dim), Replicate(), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) def _derive_follow_placements_from_tuple_strategy( @@ -613,7 +613,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding: # 2. Other dimensions of values_spec can remain sharded if they are so. # For indices: # Indices can be either sharded or replicated. 
All index tensors need to be sharded - # in a compatible way, following the pointwise rule (including resolving _Partial + # in a compatible way, following the pointwise rule (including resolving Partial # into either sharded or replicated) values_spec, multi_indices_spec = op_schema.args_schema @@ -683,7 +683,7 @@ def place(vp: Placement, ip: Placement) -> Placement: ) if isinstance(ip, Shard): return Shard(ip.dim + insert_dim) - # _Partial or Replicated + # Partial or Replicated return vp value_placements = tuple( @@ -737,13 +737,13 @@ def split_rule(op_schema: OpSchema) -> OutputSharding: dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 dim = normalize_dim(dim, ndim) - # TODO: tensor to split cannot have _Partial + # TODO: tensor to split cannot have Partial # in its placements for now. Will need to # support in future. if input_spec.sums: raise NotImplementedError( f"splitting distributed tensor with " - f"_Partial placement is not implemented!\n" + f"Partial placement is not implemented!\n" f"DTensorSpec={input_spec}" ) diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index 149e690cedc4..ecc3c5d06bee 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -1,15 +1,23 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import functools +import itertools import operator from typing import cast, Iterable, List, Sequence, Tuple, Union import torch from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import OpStrategy, RuntimeSchemaInfo +from torch.distributed._tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementStrategy, + RuntimeSchemaInfo, +) from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.device_mesh import DeviceMesh from torch.distributed._tensor.placement_types import ( - _Partial, DTensorSpec, + Partial, Placement, Replicate, Shard, @@ -193,7 +201,7 @@ def map_placements_after_broadcast( """Map each placement based on the output shape after broadcast.""" new_placements: List[Placement] = [] for placement in placements: - if isinstance(placement, (Replicate, _Partial)): + if isinstance(placement, (Replicate, Partial)): new_placements.append(placement) else: assert isinstance(placement, Shard) @@ -224,3 +232,55 @@ def generate_redistribute_costs( redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec)) return redistribute_costs + + +def expand_to_full_mesh_op_strategy( + mesh: DeviceMesh, + op_schema: OpSchema, + single_mesh_dim_strategies: List[List[Placement]], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. 
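# Hedged illustration (not in the patch): with mesh.ndim == 2 and, say,
#   single_mesh_dim_strategies = [[Replicate(), Replicate(), Replicate()],
#                                 [Shard(0),    Shard(0),    Shard(0)]]
# the itertools.product over mesh.ndim copies yields 2 * 2 = 4 combinations, and the
# zip(*strategy_comb) below regroups each combination per operand, e.g. producing
# DTensorSpec(mesh, (Replicate(), Shard(0))) for an operand that is replicated along
# mesh dim 0 and sharded on tensor dim 0 along mesh dim 1.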
+ all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [] + for specs in zip(*strategy_comb): + spec_list.append(DTensorSpec(mesh, tuple(specs))) + + input_specs = spec_list[input_index:] + input_args_strategy = op_schema.args_strategy + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the placement strategy to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + strategy = PlacementStrategy( + output_specs=tuple(spec_list[:input_index]) + if input_index > 1 + else spec_list[0], + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 449526f13a43..7161988adf25 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass from typing import ( @@ -439,7 +440,7 @@ def dim_reduction( ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool ) -> DimMap: """ - General fallback for reduction ops where _Partial() does not apply. + General fallback for reduction ops where Partial() does not apply. This will cause incoming tensor to be replicated on the reducing dimensions. """ diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index d90bcb6c258a..31e280c2f5b8 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from dataclasses import dataclass @@ -32,7 +33,7 @@ def is_replicate(self) -> bool: return isinstance(self, Replicate) def is_partial(self) -> bool: - return isinstance(self, _Partial) + return isinstance(self, Partial) @dataclass(frozen=True) @@ -412,7 +413,7 @@ class Partial(Placement): def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # _Partial placement contract #1: + # Partial placement contract #1: # _reduce_value: reduce the value of the tensor on the mesh dimension return funcol.all_reduce( tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) @@ -425,7 +426,7 @@ def _reduce_shard_value( mesh_dim: int, shard_spec: Placement, ) -> torch.Tensor: - # _Partial placement contract #2: + # Partial placement contract #2: # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension shard_spec = cast(Shard, shard_spec) return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) @@ -433,7 +434,7 @@ def _reduce_shard_value( def _partition_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # _Partial placement contract #3: + # Partial placement contract #3: # _partition_value: partition the value of a replicated tensor on the mesh dimension # _partition_value is the conjugate operation of _reduce_value @@ -446,7 +447,7 @@ def _partition_value( return tensor / num_chunks def __eq__(self, other: object) -> bool: - if not isinstance(other, _Partial): + if not isinstance(other, Partial): return False return self.reduce_op == other.reduce_op @@ -457,7 +458,7 @@ def __repr__(self) -> str: """ machine readable representation of the Partial placement """ - return f"_Partial({self.reduce_op})" + return f"Partial({self.reduce_op})" def __str__(self) -> str: """ @@ -668,7 +669,7 @@ def from_dim_map( # find all mesh dims that need pending reductions for s in sums: - placements[s] = _Partial() + placements[s] = Partial() for i, m in enumerate(dim_map): if m >= 0: diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/_tensor/random.py index f2eff6bb5ec3..ed331736c5ce 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/_tensor/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates import contextlib import warnings diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index d8b6765230a1..10f70c9ce18e 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict from itertools import chain diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 0d55b37c0044..86ab1de003db 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from enum import auto, Enum from functools import partial @@ -234,7 +235,7 @@ def checkpoint_wrapper( f"{CheckpointImpl.REENTRANT} will soon be removed as " "the default and eventually deprecated.", FutureWarning, - stacklevel=1, + stacklevel=2, ) return CheckpointWrapper( module, diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 53c8eb7e163f..d370fabafc37 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch import torch.distributed as dist diff --git a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py index 8044557e71dc..1afbb8d7967f 100644 --- a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py +++ b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from abc import ABC, abstractmethod import inspect from typing import Dict, Type diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index 911cc8255ee5..c421076bde3e 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch import torch.distributed as dist diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 570aa34cf02e..2366a9d28c13 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from functools import partial diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 52f9b419ab14..8ab58cb58442 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Any, Callable, List, Optional diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 791061e34f90..621e46fc1989 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, cast, Tuple import 
torch @@ -85,7 +86,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) @@ -134,7 +135,7 @@ def decompress(fut): decompressed_tensor.copy_(value) return decompressed_tensor - if torch.compiler.is_compiling(): + if torch._utils.is_compiling(): grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", group_to_use ) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index dc7e5ee2fdc5..76d4cd6de2bd 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Callable, List, no_type_check import torch diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py index 218ee08dbd46..3528f3987479 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import torch diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index 9d2d5649f745..fbc3b9e8739e 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import defaultdict import logging import math diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py index 9d5cd573eed6..cbc1290e76e4 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist from torch import nn diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 7c1aa3cac5ac..2936747a1c6e 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod from types import TracebackType diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index e1f8c0800c50..178efd1dbad9 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod from typing import Union, Iterable, Dict diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index 637ae144b379..02802466ab62 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright 2022 Cruise LLC import logging import warnings diff --git 
a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index eaa1cd2e968d..de1977959d21 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # flake8: noqa C101 import itertools from typing import Union, Iterable, Dict, Iterator diff --git a/torch/distributed/argparse_util.py b/torch/distributed/argparse_util.py index a214dadd312a..c475eebf2127 100644 --- a/torch/distributed/argparse_util.py +++ b/torch/distributed/argparse_util.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index e94ab1bb9d63..6546c38a37b9 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import torch diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 7294fce61ff3..60f71e12213b 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import io import os diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index 5d2aa9b62991..c1cc67b40681 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py index 2160c7dc366d..16d46e73baff 100644 --- a/torch/distributed/checkpoint/_dedup_save_plans.py +++ b/torch/distributed/checkpoint/_dedup_save_plans.py @@ -11,7 +11,10 @@ __all__ = ["dedup_save_plans"] -def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]: +def dedup_save_plans( + all_plans: List[SavePlan], + save_to_lowest_rank: bool = False, +) -> List[SavePlan]: """ Removes duplicate entries from appearing on multiple SavePlans. For each duplicate across a set of SavePlans, only the smallest SavePlan in terms of planned storage keeps the entry. 
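A hedged usage sketch of the new deduplication knob (it is threaded through DefaultSavePlanner as `dedup_save_to_lowest_rank` further below); the state_dict contents and checkpoint path are placeholders.

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

def save_with_lowest_rank_dedup(model: torch.nn.Module) -> None:
    # With the flag set, a duplicated entry is kept by min(plan_indices), i.e. the
    # lowest participating rank, instead of the plan with the smallest planned storage.
    planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
    dcp.save(
        state_dict={"model": model.state_dict()},
        checkpoint_id="/tmp/dcp_ckpt",  # assumed path for the sketch
        planner=planner,
    )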
@@ -29,7 +32,12 @@ def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]: to_remove: List[Set] = [set() for _ in range(len(all_plans))] plan_to_size = [0] * len(all_plans) for write_item_idx, plan_indices in write_item_to_plan_indices.items(): - select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_to_size[plan_idx]) + if save_to_lowest_rank: + select_plan_idx = min(plan_indices) + else: + select_plan_idx = min( + plan_indices, key=lambda plan_idx: plan_to_size[plan_idx] + ) write_item = write_item_idx_to_write_item[write_item_idx] # essentially ignores the storage size of anything that is not a tensor, since diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 828685103261..660196bc28de 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import traceback as tb from typing import Any, Dict, Tuple diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index c9590c38d3e6..83b76718a6b7 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import dataclasses @@ -67,11 +68,12 @@ def __init__( flatten_state_dict: bool = True, flatten_sharded_tensors: bool = True, dedup_replicated_tensors: Optional[bool] = None, + dedup_save_to_lowest_rank: bool = False, ) -> None: self.flatten_state_dict = flatten_state_dict self.flatten_sharded_tensors = flatten_sharded_tensors self.mappings = {} - + self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank if dedup_replicated_tensors is not None: logger.warning( "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " @@ -103,7 +105,7 @@ def create_local_plan(self) -> SavePlan: def create_global_plan( self, all_plans: List[SavePlan] ) -> Tuple[List[SavePlan], Metadata]: - all_plans = dedup_save_plans(all_plans) + all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank) global_plan, metadata = create_default_global_save_plan(all_plans) diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index d4e2b5268de7..5eaba9a67227 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import os diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py index 9e2438c47bb8..38c637d3a4fd 100644 --- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py +++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates """ diff --git a/torch/distributed/checkpoint/examples/stateful_example.py b/torch/distributed/checkpoint/examples/stateful_example.py index 6c23dc3e298f..6c76ec436364 100644 --- a/torch/distributed/checkpoint/examples/stateful_example.py +++ b/torch/distributed/checkpoint/examples/stateful_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index aa25d1fb5369..4d512891f122 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import io diff --git a/torch/distributed/checkpoint/format_utils.py b/torch/distributed/checkpoint/format_utils.py index 41ebaf8be61b..e82284704565 100644 --- a/torch/distributed/checkpoint/format_utils.py +++ b/torch/distributed/checkpoint/format_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import os from enum import Enum diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 08e2bee2a78b..270240490c99 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import time from typing import Any, Callable, Dict, List, TypeVar diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index bbcfcbc01e17..b3bc7a580dad 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os from dataclasses import dataclass, field from enum import Enum diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index ad2466a50ee8..5eec8bf75466 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -426,3 +426,39 @@ def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: The contents of tensor will follow its device synchronization model. """ pass + + +class _Checkpointable: + """ + Interface for checkpointable objects. + This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. + """ + + @abc.abstractmethod + def _create_write_items(self, fqn: str, object: Any) -> List[WriteItem]: + """ + Return a list of WriteItems based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_write_items is not implemented" + ) + + @abc.abstractmethod + def _create_chunk_list(self, tensor: torch.Tensor) -> List[ChunkStorageMetadata]: + """ + Return a list of `ChunkStorageMetadata` based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_chunk_list is not implemented" + ) + + @abc.abstractmethod + def _get_tensor_shard( + self, tensor: torch.Tensor, index: MetadataIndex + ) -> torch.Tensor: + """ + Return a 'torch.Tensor' shard based on 'MetadataIndex'. 
+ """ + raise NotImplementedError( + "_Checkpointable._get_tensor_shard is not implemented" + ) diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 5829ab6111e2..4bbe26876c88 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, cast, List import torch @@ -8,6 +9,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.checkpoint.planner import _Checkpointable from torch.utils._pytree import tree_map_only @@ -217,7 +219,12 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: - if isinstance(object, DTensor): + if isinstance(object, _Checkpointable): + return object._create_write_items(fqn, object) + elif isinstance(object, DTensor): + # DTensor can contain a local tensor that is a tensor subclass + if isinstance(object.to_local(), _Checkpointable): + return object.to_local()._create_write_items(fqn, object) # type: ignore[arg-type] return [_create_write_items_for_dtensor(fqn, object)] elif isinstance(object, ShardedTensor): return [ @@ -242,7 +249,12 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]: - if isinstance(tensor, DTensor): + if isinstance(tensor, _Checkpointable): + local_chunks = tensor._create_chunk_list(tensor) + elif isinstance(tensor, DTensor): + # DTensor can contain a local tensor that is a tensor subclass + if isinstance(tensor.to_local(), _Checkpointable): + return tensor.to_local()._create_chunk_list(tensor) # type: ignore[arg-type] local_chunks = [_create_chunk_from_dtensor(tensor)] elif isinstance(tensor, ShardedTensor): local_chunks = [ diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py index 1ebb0ba57d73..a1bf112f1795 100644 --- a/torch/distributed/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple from torch.distributed.checkpoint.metadata import ChunkStorageMetadata diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index e7072d623012..cc55b1a5b42c 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1,6 +1,8 @@ +# mypy: allow-untyped-defs import contextlib import functools import gc +import warnings from dataclasses import asdict, dataclass, field from itertools import chain from typing import ( @@ -52,19 +54,12 @@ from torch.utils._pytree import tree_map_only __all__ = [ - "FLAT_PARAM", - "PG", - "PG_PREFIX", - "STATE", - "STATE_PREFIX", - "PARAMS", "FQNS_T", "PrimitiveType", "ValueType", "DictValueType", "ListDictValueType", "OptimizerStateType", - "gc_context", "StateDictOptions", "get_model_state_dict", "get_optimizer_state_dict", @@ -74,17 +69,13 @@ "set_state_dict", ] -FLAT_PARAM = "_flat_param" -PG = "param_groups" -PG_PREFIX = f"{PG}." -STATE = "state" -STATE_PREFIX = f"{STATE}." 
-PARAMS = "params" -FQNS_T = Set[str] - -_patched_state_dict: Set[Callable] = set() +_FLAT_PARAM = "_flat_param" +_PG = "param_groups" +_PARAMS = "params" +_STATE = "state" +FQNS_T = Set[str] PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] ValueType = Union[ PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"] @@ -94,14 +85,16 @@ OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]] +_patched_state_dict: Set[Callable] = set() + + @contextlib.contextmanager -def gc_context(): +def _gc_context(): is_enabled = gc.isenabled() gc.disable() try: yield finally: - # TODO: add logging for the gc details/time if is_enabled: gc.enable() @@ -123,7 +116,7 @@ class StateDictOptions: won't contain any frozen parameters -- the ``requires_grad`` is False. The default value is False. - - ``keep_submodule_prefixes``: when ``submodules`` is not None, this option + - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option indicates whether to keep the submodule prefixes from the state_dict keys. or example, if the submodule is ``module.pretrain`` and the full FQN of the parameter is ``pretrain.layer1.weight`` of the param. When this option @@ -150,6 +143,7 @@ class StateDictOptions: keep_submodule_prefixes: bool = True strict: bool = True broadcast_from_rank0: bool = False + flatten_optimizer_state_dict: bool = False @dataclass @@ -157,7 +151,6 @@ class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) - all_fqns: Set[str] = field(default_factory=set) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True handle_optim: bool = True @@ -202,9 +195,9 @@ def _get_fqns( if not skip_ddp_prefix: fqn_obj_names.append(curr_obj_name) elif isinstance(curr_obj, FSDP): - if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM: + if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM: prefix = ".".join(fqn_obj_names) - flat_param = getattr(curr_obj, FLAT_PARAM) + flat_param = getattr(curr_obj, _FLAT_PARAM) if prefix: prefix = f"{prefix}." return {f"{prefix}{fqn}" for fqn in flat_param._fqns} @@ -274,6 +267,13 @@ def _verify_options( """ Verify the model and options passed by the user and generates _StateDictInfo. """ + if submodules: + warnings.warn( + "Getting submodules only model/optim state_dict is deprecated and " + "will be removed in 2.5. This feature can be achieved by manually " + "filtering out the state_dict returned from get_state_dict.", + FutureWarning, + ) if optim_only and not optims: raise RuntimeError( "Optimizers are not passed in but optim_only is set to True." 
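Aside: a minimal usage sketch of the new ``dedup_save_to_lowest_rank`` option on ``DefaultSavePlanner`` introduced above (it forwards to ``dedup_save_plans(..., save_to_lowest_rank=True)``). The checkpoint path and ``model`` are placeholders, and this assumes an initialized process group and the standard ``torch.distributed.checkpoint.save`` entry point:

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

    # Keep duplicated entries on the lowest rank instead of the plan with the
    # smallest planned storage size.
    planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)

    state_dict = {"model": model.state_dict()}  # `model` is a placeholder module
    dcp.save(state_dict, checkpoint_id="/tmp/ckpt", planner=planner)  # hypothetical path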
@@ -284,7 +284,6 @@ def _verify_options( fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[Set[str], torch.Tensor] ] = {} - all_fqns = set() for name, param in _iterate_valid_model_state(model): fqns = _get_fqns(model, name) if not isinstance(param, _EXTRA_STATE): @@ -292,7 +291,6 @@ def _verify_options( for fqn in fqns: if not isinstance(param, _EXTRA_STATE): fqn_param_mapping[fqn] = param - all_fqns.add(fqn) submodule_prefixes: Set[str] = set() if submodules: @@ -332,8 +330,24 @@ def _verify_options( ) state_dict_type = StateDictType.SHARDED_STATE_DICT + @contextlib.contextmanager + def fsdp_state_dict_type_without_warning( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ): + with warnings.catch_warnings(): + with FSDP.state_dict_type( + module=module, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config, + ): + yield + fsdp_context = functools.partial( - FSDP.state_dict_type, + fsdp_state_dict_type_without_warning, module=model, state_dict_type=state_dict_type, state_dict_config=state_dict_config, @@ -345,7 +359,6 @@ def _verify_options( return _StateDictInfo( **asdict(options), fqn_param_mapping=fqn_param_mapping, - all_fqns=all_fqns, submodule_prefixes=submodule_prefixes, fsdp_context=fsdp_context, fsdp_modules=cast(List[nn.Module], fsdp_modules), @@ -382,7 +395,7 @@ def _verify_state_dict( if info.handle_optim: if ( - not (optim_state_dict and optim_state_dict[STATE]) + not optim_state_dict and not (info.cpu_offload and info.full_state_dict) and (not info.broadcast_from_rank0) ): @@ -392,9 +405,9 @@ def _verify_state_dict( ) for key in model_state_dict.keys(): - if FLAT_PARAM in key: + if _FLAT_PARAM in key: raise RuntimeError( - f"{key} contains {FLAT_PARAM}. This can happen if the model " + f"{key} contains {_FLAT_PARAM}. This can happen if the model " "is not the root module." 
) @@ -406,6 +419,24 @@ def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Ca return call +def _maybe_full_or_cpu_state_dict( + state_dict: Dict[str, Any], info: _StateDictInfo +) -> Dict[str, Any]: + if info.full_state_dict: + ranks_only = ( + tuple() + if (not info.cpu_offload or not torch.distributed.is_initialized()) + else (0,) + ) + return _gather_state_dict( + state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + ) + elif info.cpu_offload: + return _offload_state_dict_to_cpu(state_dict) + else: + return state_dict + + def _get_model_state_dict( model: nn.Module, info: _StateDictInfo ) -> Dict[str, ValueType]: @@ -470,15 +501,7 @@ def verify(key, fqn) -> bool: if torch.is_tensor(p) and p.is_meta: state_dict.pop(key) - if info.full_state_dict: - ranks_only = tuple() if not info.cpu_offload else (0,) - return _gather_state_dict( - state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only - ) - elif info.cpu_offload: - return _offload_state_dict_to_cpu(state_dict) - else: - return state_dict + return _maybe_full_or_cpu_state_dict(state_dict, info) def _load_model_state_dict( @@ -512,7 +535,9 @@ def _load_model_state_dict( else: assert device == value.device assert device is not None - _broadcast_state_dict(state_dict, local_state_dict, device=device) + _broadcast_state_dict( + state_dict, local_state_dict, device=device, strict=info.strict + ) for fqn, local_state in local_state_dict.items(): state_dict[fqn] = local_state @@ -534,7 +559,7 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: return for param_group in optim.param_groups: - for param in param_group[PARAMS]: + for param in param_group[_PARAMS]: if param.grad is not None: raise RuntimeError( "state_dict can only be used if the optimizer " @@ -563,6 +588,115 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: optim.zero_grad(set_to_none=True) +def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]: + """ + This API flattens the optimizer state_dict to support optimizer resharding for + MPMD, e.g., pipeline parallelism. + + Without the API, the original optimizer state_dict looks like: + { + "state": { + "layer1.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + "layer2.weight": { + "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor + }, + }, + "param_groups": [ + { + "lr": 0.1, + "betas": (0.9, 0.95), ..., + "params": ["layer1.weight", "layer2.weight"] + } + ] + } + + With this API, the optimizer state_dict looks like: + { + "state.layer1.weight.step": 10, + "state.layer2.weight.step": 10, + "state.layer1.weight.exp_avg": SomeTensor, + "state.layer2.weight.exp_avg": SomeTensor, + "state.layer1.weight.exp_avg_sq": SomeTensor, + "state.layer2.weight.exp_avg_sq": SomeTensor, + "param_groups.layer1.weight.lr" : 0.1, + "param_groups.layer2.weight.lr" : 0.1, + "param_groups.layer1.weight.betas" : (0.9, 0.95), + "param_groups.layer2.weight.betas" : (0.9, 0.95), + } + + Note that if any of the values is a container, like the betas in the example, + this API won't flatten it. + """ + + def _raise_if_type_not_supported(v): + if not isinstance(v, (torch.Tensor, int, float)): + raise NotImplementedError( + "Flattening optimizer state_dict only supports " + "tensor, int, float states now. " + f"Type is {type(v)}."
+ ) + + ret: Dict[str, ValueType] = {} + for fqn, state in cast(DictValueType, state_dict[_STATE]).items(): + for k, v in cast(DictValueType, state).items(): + _raise_if_type_not_supported(v) + ret[f"{_STATE}.{fqn}.{k}"] = v + + for param_group in cast(ListDictValueType, state_dict[_PG]): + fqns = param_group.pop(_PARAMS) + for fqn in cast(List[str], fqns): + for k, v in param_group.items(): + ret[f"{_PG}.{fqn}.{k}"] = v + return ret + + +def _unflatten_optim_state_dict( + optim: torch.optim.Optimizer, + state_dict: Dict[str, ValueType], + info: _StateDictInfo, +) -> OptimizerStateType: + """ + This API unflattens the state_dict generated by _flatten_optim_state_dict(). + See the docstring of _flatten_optim_state_dict() for more detail. + """ + state: DictValueType = {} + pg_state: ListDictValueType = [] + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} + + for param_group in optim.param_groups: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: + for fqn in info.fqn_param_mapping[param]: + params = pg_state[-1][_PARAMS] + assert isinstance(params, list) # typing + params.append(fqn) + if not param.requires_grad: + continue + state[fqn] = {} + for state_name in optim.state[param].keys(): + cast(DictValueType, state[fqn])[state_name] = state_dict[ + f"{_STATE}.{fqn}.{state_name}" + ] + + first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0] + for k in param_group.keys(): + if k == _PARAMS: + continue + value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] + if k not in pg_state[-1]: + pg_state[-1][k] = value + elif pg_state[-1][k] != value: + raise RuntimeError( + "All the parameters in the same parameter group should have " + f"the same saved param_group value. But {first_param_fqn}.{k} " + f"is {value} while other(s) is {pg_state[-1][k]}." + ) + + return return_osd + + def _get_optim_state_dict( model: nn.Module, optimizers: Tuple[torch.optim.Optimizer, ...], @@ -571,7 +705,7 @@ def _get_optim_state_dict( if not info.handle_optim: return {} - optim_state_dict: OptimizerStateType = {STATE: {}, PG: []} + optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []} for optim in optimizers: _init_optim_state(optim) osd = _state_dict_fn(optim, "state_dict")() @@ -585,14 +719,14 @@ def _get_optim_state_dict( # We can only use a string replacment without correctness check. 
if not osd: continue - for k in list(osd[STATE].keys()): + for k in list(osd[_STATE].keys()): if "_orig_mod" in k: - osd[STATE][k.replace("_orig_mod.", "")] = osd[STATE].pop(k) - for g in osd[PG]: - params = [k.replace("_orig_mod.", "") for k in g[PARAMS]] - g[PARAMS] = params + osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k) + for g in osd[_PG]: + params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]] + g[_PARAMS] = params else: - params = list(chain.from_iterable(g[PARAMS] for g in optim.param_groups)) + params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups)) param_pid_mapping = dict(zip(params, range(len(params)))) fqn_pid_mapping = {} for key, param in model.named_parameters(): @@ -605,28 +739,25 @@ def _get_optim_state_dict( fqn_pid_mapping[fqn] = pid fqn_pid_mapping[pid] = fqn - for key in list(osd[STATE].keys()): + for key in list(osd[_STATE].keys()): fqn = fqn_pid_mapping[key] - osd[STATE][fqn] = osd[STATE].pop(key) + osd[_STATE][fqn] = osd[_STATE].pop(key) - for group in osd[PG]: - group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]] + for group in osd[_PG]: + group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]] if not osd: continue - cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE]) - cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG]) + cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE]) + cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG]) - if info.full_state_dict: - ranks_only = tuple() if not info.cpu_offload else (0,) - return _gather_state_dict( - optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only + if info.flatten_optimizer_state_dict: + optim_state_dict = cast( + OptimizerStateType, _flatten_optim_state_dict(optim_state_dict) ) - elif info.cpu_offload: - return _offload_state_dict_to_cpu(optim_state_dict) - else: - return optim_state_dict + + return _maybe_full_or_cpu_state_dict(optim_state_dict, info) def _split_optim_state_dict( @@ -652,30 +783,37 @@ def _split_optim_state_dict( state: DictValueType = {} pg_state: ListDictValueType = [] - return_osd: OptimizerStateType = {STATE: state, PG: pg_state} + return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} pg_mapping: Dict[int, int] = {} + if all( + isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys() + ): + return optim_state_dict + for param_group in optim.param_groups: - pg_state.append({PARAMS: []}) - for param in param_group[PARAMS]: + pg_state.append({_PARAMS: []}) + for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: - params = pg_state[-1][PARAMS] + params = pg_state[-1][_PARAMS] assert isinstance(params, list) params.append(fqn) if param.requires_grad: - state[fqn] = cast(DictValueType, optim_state_dict[STATE])[fqn] - for loaded_param_group in cast(ListDictValueType, optim_state_dict[PG]): - params = loaded_param_group[PARAMS] + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + params = loaded_param_group[_PARAMS] assert isinstance(params, list) if fqn in params: - pg_mapping[id(loaded_param_group)] = len(return_osd[PG]) - 1 + pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 - for param_group in cast(ListDictValueType, optim_state_dict[PG]): + for param_group in cast(ListDictValueType, optim_state_dict[_PG]): idx = pg_mapping.get(id(param_group), -1) if idx == -1: continue for key, value in param_group.items(): - 
if key == PARAMS: + if key == _PARAMS: continue # TODO: check if value is the same if exists. pg_state[idx][key] = value @@ -695,7 +833,14 @@ def _load_optim_state_dict( for optim in optimizers: _init_optim_state(optim) if state_dict: - optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info) + if _STATE in state_dict: + optim_state_dict = _split_optim_state_dict( + model, optim, state_dict, info + ) + else: + optim_state_dict = _unflatten_optim_state_dict( + optim, cast(Dict[str, ValueType], state_dict), info + ) else: optim_state_dict = {} if info.fsdp_modules: @@ -712,13 +857,13 @@ def _load_optim_state_dict( assert len(fqns) == 1 fqn = fqns.pop() fqn_with_compiler = fqns_with_compiler.pop() - for g in optim_state_dict[PG]: + for g in optim_state_dict[_PG]: val = cast(Dict[str, Any], g) params = [ - key.replace(fqn, fqn_with_compiler) for key in val[PARAMS] + key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS] ] - val[PARAMS] = params - osd_state = cast(DictValueType, optim_state_dict[STATE]) + val[_PARAMS] = params + osd_state = cast(DictValueType, optim_state_dict[_STATE]) for k in list(osd_state.keys()): if fqn in k: osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) @@ -747,13 +892,22 @@ def _device(t): flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) + # The modifications listed seek to address the problem where optim might possess + # dissimilar parameters in comparison to optim_state_dict. This is achieved by + # incorporating differential parameters within local, which may result in optim + # having additional parameters ultimately. + for optim_key in flatten_osd.keys(): + if optim_key not in flatten_local_osd: + assert optim_key in osd_mapping + flatten_local_osd[optim_key] = flatten_osd[optim_key] + local_osd_mapping[optim_key] = osd_mapping[optim_key] optim_state_dict = _unflatten_state_dict( flatten_local_osd, local_osd_mapping ) # Note that we do not have to convert the FQN back to param id here if - # order in optim.param_groups[idx][PARAMS] is the same as the one in - # optim_state_dict[PG][idx][PARAMS]. + # order in optim.param_groups[idx][_PARAMS] is the same as the one in + # optim_state_dict[_PG][idx][_PARAMS]. _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) @@ -770,7 +924,7 @@ def get_model_state_dict( Args: model (nn.Module): the nn.Module to the model. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. See @@ -781,7 +935,7 @@ def get_model_state_dict( :rtype: typing.Dict[str, ValueType] """ - with gc_context(): + with _gc_context(): info = _verify_options( model, tuple(), @@ -810,7 +964,7 @@ def get_optimizer_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[None, Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. 
See @@ -821,7 +975,7 @@ def get_optimizer_state_dict( :rtype: OptimizerStateType """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -897,7 +1051,7 @@ def get_state_dict( model (nn.Module): the nn.Module to the model. optimizers (Union[None, Optimizer, Iterable[Optimizer]]): The optimizers that are used to optimize ``model``. - submodules: Optional[Set[nn.Module]]: only return the model parameters + submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters that belong to the submodules. options (StateDictOptions): the options to control how model state_dict and optimizer state_dict should be returned. See @@ -909,7 +1063,7 @@ def get_state_dict( :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -936,6 +1090,13 @@ def _unflatten_model_state_dict( return {} if isinstance(next(iter(state_dict.keys())), nn.Module): + warnings.warn( + "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``" + " is deprecated and will be removed in 2.5. If you need this " + "feature, please preprocess the model_state_dict to achieve the " + "same functionality.", + FutureWarning, + ) cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict) new_state_dict: Dict[str, ValueType] = {} for submodule, sub_state_dict in cast_state_dict.items(): @@ -986,7 +1147,7 @@ def set_model_state_dict( model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( model, model_state_dict ) - with gc_context(): + with _gc_context(): info = _verify_options(model, tuple(), optim_only=False, options=options) _verify_state_dict(model_state_dict, {}, info) @@ -996,8 +1157,8 @@ def set_model_state_dict( def set_optimizer_state_dict( model: nn.Module, optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], - *, optim_state_dict: OptimizerStateType, + *, options: Optional[StateDictOptions] = None, ) -> None: """Load the optimizers state_dict.
@@ -1020,7 +1181,7 @@ def set_optimizer_state_dict( :type optim_state_dict: typing.OptimizerStateType """ - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) @@ -1077,7 +1238,7 @@ def set_state_dict( model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( model, model_state_dict ) - with gc_context(): + with _gc_context(): optimizers = ( (optimizers,) if isinstance(optimizers, torch.optim.Optimizer) diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index 6c1546e1cc0f..f443f73f02d6 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import warnings from typing import Any, cast, Dict, Optional, Set, Union @@ -177,6 +178,7 @@ def load( elem = state_dict[key] if isinstance(elem, Stateful): elem.load_state_dict(statetful_sd[key]) + state_dict[key] = statetful_sd[key] def _load_state_dict( diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index b715fcdd9ae5..6d04044391ab 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import os import warnings @@ -274,7 +275,7 @@ def _save_state_dict( planner = DefaultSavePlanner() assert planner is not None - global_metatadata = None + global_metadata = None ckpt_kwargs = {} if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None: @@ -305,10 +306,10 @@ def local_step(): @_dcp_method_logger(**ckpt_kwargs) def global_step(all_local_plans): - nonlocal global_metatadata + nonlocal global_metadata assert planner is not None - all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + all_local_plans, global_metadata = planner.create_global_plan(all_local_plans) all_local_plans = storage_writer.prepare_global_plan(all_local_plans) return all_local_plans @@ -325,8 +326,8 @@ def write_data(): @_dcp_method_logger(**ckpt_kwargs) def finish_checkpoint(all_results): - assert global_metatadata is not None - storage_writer.finish(metadata=global_metatadata, results=all_results) - return global_metatadata + assert global_metadata is not None + storage_writer.finish(metadata=global_metadata, results=all_results) + return global_metadata return distW.all_reduce("write", write_data, finish_checkpoint) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index d781d9839bea..0efba34a551b 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import cProfile import inspect import io @@ -14,6 +15,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._tensor import DTensor +from torch.distributed.checkpoint.planner import _Checkpointable from .api import ( _is_wrapped_exception, @@ -301,7 +303,12 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: - if isinstance(tensor, DTensor): + if isinstance(tensor, _Checkpointable): + return tensor._get_tensor_shard(tensor, index) + elif isinstance(tensor, DTensor): + # DTensor can contain a local tensor that is a tensor 
subclass + if isinstance(tensor.to_local(), _Checkpointable): + return tensor.to_local()._get_tensor_shard(tensor, index) # type: ignore[arg-type] return tensor.to_local() if isinstance(tensor, ShardedTensor): return _find_shard(tensor, index).tensor diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 57b8fa1cf564..e46356a36894 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging import math @@ -62,6 +63,9 @@ class _MeshEnv(threading.local): def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: Dict[ + int, Tuple[str, Optional[ProcessGroup.Options]] + ] = {} def get_current_mesh(self) -> "DeviceMesh": if len(self.mesh_stack) == 0: @@ -69,31 +73,46 @@ def get_current_mesh(self) -> "DeviceMesh": return self.mesh_stack[-1] def create_child_mesh( - self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str + self, parent_mesh: "DeviceMesh", submesh_dim_names: Tuple[str, ...] ) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - cur_rank = device_mesh.get_rank() - pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( - -1, device_mesh.mesh.size(mesh_dim) - ) + # submesh_dims are the mesh dimension of the submesh in the parent mesh. + submesh_dims = [ + not_none(parent_mesh.mesh_dim_names).index(mesh_dim_name) + for mesh_dim_name in submesh_dim_names + ] + submesh_dim_sizes = [ + parent_mesh.mesh.size(mesh_dim) for mesh_dim in submesh_dims + ] - for mesh_1d in pg_ranks_by_dim: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_1d, - mesh_dim_names=(mesh_dim_name,), + mesh_dims_remained = list(range(parent_mesh.mesh.ndim)) + for submesh_dim in submesh_dims: + mesh_dims_remained.remove(submesh_dim) + + # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *sub_mesh_dims] + # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with + # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank. + pg_ranks_by_dim = parent_mesh.mesh.permute( + *mesh_dims_remained, *submesh_dims + ).reshape(-1, *submesh_dim_sizes) + + cur_rank = parent_mesh.get_rank() + for mesh_nd in pg_ranks_by_dim: + submesh = DeviceMesh( + parent_mesh.device_type, + mesh_nd, + mesh_dim_names=submesh_dim_names, _init_backend=False, ) - if cur_rank in mesh_1d: - res_sub_mesh = sub_mesh + if cur_rank in mesh_nd: + res_submesh = submesh + + res_submesh._parent_mesh = parent_mesh # type: ignore[possibly-undefined] + res_submesh._dim_group_infos = [ + parent_mesh._dim_group_infos[mesh_dim] for mesh_dim in submesh_dims # type: ignore[possibly-undefined] + ] + self.child_to_parent_mapping[res_submesh] = parent_mesh - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined] - res_sub_mesh._parent_mesh = device_mesh - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - # We need to update the mappings after the child mesh hash update. 
- self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh + return res_submesh def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: return self.child_to_parent_mapping.get(device_mesh, None) @@ -140,6 +159,14 @@ def get_mesh_dim_by_name( ) return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) + def _set_mesh_dim_group_options( + self, + dim: int, + backend: str, + pg_options: Optional[ProcessGroup.Options] = None, + ) -> None: + self.mesh_dim_group_options[dim] = (backend, pg_options) + _mesh_resources: _MeshEnv = _MeshEnv() def _get_device_handle(device_type: str = "cuda"): @@ -297,10 +324,24 @@ def _init_process_groups(self): for dim_mesh in pg_ranks_by_dim: subgroup_ranks = dim_mesh.tolist() + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + # We temporarily revert the re-use subgroup, since it breaks two internal tests. # Temporarily reverting to resolve test timeout while root-causing. # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. - dim_group = new_group(ranks=subgroup_ranks) + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + ) # only add to dim_groups if the current rank in the subgroup if self.get_rank() in subgroup_ranks: @@ -367,14 +408,16 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) - def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": + def __getitem__( + self, mesh_dim_names: Union[str, Tuple[str, ...]] + ) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_name given to create a child DeviceMesh. Args: - mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh - to create a child DeviceMesh for. + mesh_dim_name (Union[str, Tuple[str]]): the name or the tuple of names of the + mesh dimension of the parent DeviceMesh to create the child DeviceMesh for. Returns: A :class:`DeviceMesh` object @@ -395,60 +438,83 @@ def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": >>> # of cross-host(dim 0), and within-host (dim 1). >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) """ - if self.mesh.ndim == 1: - if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]: - return self - else: - raise RuntimeError( - f"Invalid mesh_dim_name {mesh_dim_name} specified." - ) + if not self.mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) - mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name) - submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) + error_msg = ( + f"Invalid mesh_dim_name {mesh_dim_names} specified. " + f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}." + ) + + if mesh_dim_names == self.mesh_dim_names: + return self + elif len(mesh_dim_names) > len(self.mesh_dim_names) or not all( + mesh_dim_name in self.mesh_dim_names for mesh_dim_name in mesh_dim_names + ): + raise KeyError(error_msg) + # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names + # of the current DeviceMesh. 
+ else: + outermost_dim_name = mesh_dim_names[0] + outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) + for i, j in zip( + mesh_dim_names, + self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], + ): + if i != j: + raise KeyError(error_msg) + + submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) return submesh - def get_group( - self, mesh_dim: Optional[Union[int, str]] = None - ) -> Union[ProcessGroup, List[ProcessGroup]]: + def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: """ - Returns a list of ProcessGroups corresponding to the mesh dimensions, or - returns a single ProcessGroup if mesh_dim is specified or the given mesh has - only one mesh dimension. + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. Args: mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index of the mesh dimension. Default is None. Returns: - A list of :class:`ProcessGroup` object when `mesh_dim` is not specified for - a DeviceMesh with more than 1 dimension; otherwise, returns a single - :class:`ProcessGroup` object. + A :class:`ProcessGroup` object. """ if not hasattr(self, "_dim_group_infos"): raise RuntimeError("DeviceMesh process groups not initialized!") - if self.mesh.ndim == 1: - return not_none( - _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2]) + if self.mesh.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", ) - if mesh_dim is not None: - if isinstance(mesh_dim, str): - mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) - return not_none( - _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) - ) + if self.mesh.ndim == 1 and mesh_dim is None: + mesh_dim = 0 else: - dim_groups = [] - for ith_dim in range(self.mesh.ndim): - dim_groups.append( - not_none( - _find_pg_by_ranks_and_tag( - *self._dim_group_infos[ith_dim][:2] - ) - ) - ) - return dim_groups + mesh_dim = ( + _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + + return not_none( + _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] + ) + + def get_all_groups(self) -> List[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. + """ + return [self.get_group(i) for i in range(self.mesh.ndim)] @staticmethod def from_group( diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index decf309cfec1..bd81fd61b02f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Distributed Collective Communication (c10d).""" import itertools @@ -679,6 +680,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device "This usage is deprecated since PyTorch 2.0. 
Please use a public API " "of PyTorch Distributed instead.", FutureWarning, + stacklevel=3, ) # Most users create Gloo with private API for object collectives _world.pg_default_device[group] = torch.device("cpu") diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 95369ecb61e1..232f28234e65 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py new file mode 100644 index 000000000000..160383637865 --- /dev/null +++ b/torch/distributed/elastic/control_plane.py @@ -0,0 +1,51 @@ +import os +from contextlib import contextmanager, ExitStack +from typing import Generator + +from torch.distributed.elastic.multiprocessing.errors import record + +__all__ = [ + "worker_main", +] + +TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" + + +@contextmanager +def _worker_server(socket_path: str) -> Generator[None, None, None]: + from torch._C._distributed_c10d import _WorkerServer + + server = _WorkerServer(socket_path) + try: + yield + finally: + server.shutdown() + + +@contextmanager +@record +def worker_main() -> Generator[None, None, None]: + """ + This is a context manager that wraps your main entry function. This combines + the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that + exposes handlers via a unix socket specified by + ``Torch_WORKER_SERVER_SOCKET``. + + Example + + :: + + @worker_main() + def main(): + pass + + if __name__=="__main__": + main() + + """ + with ExitStack() as stack: + socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) + if socket_path is not None: + stack.enter_context(_worker_server(socket_path)) + + yield diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index db6cb639ef1c..9f6e1733518a 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -86,6 +86,40 @@ def construct_and_record_rdzv_event( local_id: Optional[int] = None, rank: Optional[int] = None, ) -> None: + """ + Initialize rendezvous event object and record its operations. + + Args: + run_id (str): The run id of the rendezvous. + message (str): The message describing the event. + node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). + name (str): Event name. (E.g. Current action being performed). + hostname (str): Hostname of the node. + pid (Optional[int]): The process id of the node. + master_endpoint (str): The master endpoint for the rendezvous store, if known. + local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py + rank (Optional[int]): The rank of the node, if known. + Returns: + None + Example: + >>> # See DynamicRendezvousHandler class + >>> def _record( + ... self, + ... message: str, + ... node_state: NodeState = NodeState.RUNNING, + ... rank: Optional[int] = None, + ... ) -> None: + ... construct_and_record_rdzv_event( + ... name=f"{self.__class__.__name__}.{get_method_name()}", + ... run_id=self._settings.run_id, + ... message=message, + ... node_state=node_state, + ... hostname=self._this_node.addr, + ... pid=self._this_node.pid, + ... local_id=self._this_node.local_id, + ... 
rank=rank, + ... ) + """ # We don't want to perform an extra computation if not needed. if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): return diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 62f5d7500922..082499b3af63 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index 767abcc1d60b..d8bea0b3c079 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env/python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 11a3930acf70..7b6d8295ef05 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index eb0b110f25ee..5d294a7d0802 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 95d6a6192245..d63c283b4c35 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 903731a6a2ab..34d6229dda3b 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/errors/handlers.py b/torch/distributed/elastic/multiprocessing/errors/handlers.py index 3071aef17117..09b2aca55f16 100644 --- a/torch/distributed/elastic/multiprocessing/errors/handlers.py +++ b/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/redirects.py b/torch/distributed/elastic/multiprocessing/redirects.py index e63255819383..8ad3e2edf1c1 100644 --- a/torch/distributed/elastic/multiprocessing/redirects.py +++ b/torch/distributed/elastic/multiprocessing/redirects.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # !/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. 
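To illustrate the ``DeviceMesh`` changes in ``torch/distributed/device_mesh.py`` above (n-D slicing via a tuple of dim names, the stricter ``get_group``, and the new ``get_all_groups``), here is a rough sketch; it assumes an 8-rank job launched with torchrun, and the mesh dim names are illustrative:

    from torch.distributed.device_mesh import init_device_mesh

    # 2 x 2 x 2 mesh over 8 ranks; dim names are examples.
    mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))

    # Slicing with a contiguous tuple of dim names now returns a 2-D submesh.
    dp_cp_mesh = mesh_3d["dp", "cp"]

    # get_group() on a mesh with ndim > 1 now requires a mesh_dim argument;
    # get_all_groups() returns one ProcessGroup per mesh dimension.
    tp_group = mesh_3d.get_group("tp")
    all_groups = mesh_3d.get_all_groups()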
diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index 8d4477452a20..e122f89a94f7 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 17b0d216e954..804e2e5a6323 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 09b19be479dc..7ddcd7c70b9a 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 62413df02aae..7fb894bd2247 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index a80fa9e97894..0bc92d845d19 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index b642d6201200..1a371b74275a 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index cacb888590f8..c9d60abdc236 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index a28f7cc31839..891858534c56 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. 
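Following the ``torch/distributed/elastic/control_plane.py`` addition above, a small usage sketch of ``worker_main``; the socket path is an arbitrary example, and the worker server is only started when ``TORCH_WORKER_SERVER_SOCKET`` is set:

    import os

    from torch.distributed.elastic.control_plane import worker_main

    # Hypothetical socket path; without this env var, only the error-recording
    # behavior of worker_main applies.
    os.environ.setdefault("TORCH_WORKER_SERVER_SOCKET", "/tmp/worker_server.sock")

    @worker_main()
    def main() -> None:
        ...  # training entry point

    if __name__ == "__main__":
        main()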
diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index 7690439237ad..605596475686 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index 2e53034a9d6e..ace82d0a2226 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index 326bc604a914..8419051d29f8 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 0121c98d56d1..77fcaaceed4f 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py index 2ac2dc5318be..55a1a9e9bcdf 100644 --- a/torch/distributed/elastic/timer/debug_info_logging.py +++ b/torch/distributed/elastic/timer/debug_info_logging.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index f2ded8ba84dd..fce46f053a7e 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index 7c87413aef19..b6a54896fc5e 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # diff --git a/torch/distributed/elastic/utils/data/cycling_iterator.py b/torch/distributed/elastic/utils/data/cycling_iterator.py index 60a5861f7bef..b5dadb96bda4 100644 --- a/torch/distributed/elastic/utils/data/cycling_iterator.py +++ b/torch/distributed/elastic/utils/data/cycling_iterator.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. 
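Circling back to the ``flatten_optimizer_state_dict`` flag added to ``StateDictOptions`` in ``torch/distributed/checkpoint/state_dict.py`` earlier in this diff, a brief sketch of how it could be used; ``model`` and ``optim`` are placeholders and the key shapes in the comment follow the ``_flatten_optim_state_dict`` docstring:

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )

    options = StateDictOptions(flatten_optimizer_state_dict=True)

    # Keys come back flattened, e.g. "state.<fqn>.exp_avg" and
    # "param_groups.<fqn>.lr", which simplifies optimizer resharding for MPMD.
    osd = get_optimizer_state_dict(model, optim, options=options)

    # The flattened form is accepted back and unflattened internally on load.
    set_optimizer_state_dict(model, optim, osd, options=options)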
diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py index a66803fa8c09..8e378c6a1be1 100644 --- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py +++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 1dc4680abc16..04ff2fe680f1 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. @@ -113,6 +114,24 @@ def _check_full_rank(store, world_size, timeout): def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening a socket, binding + it to a port, and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + .. note:: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ sock = get_socket_with_port() with closing(sock): return sock.getsockname()[1] diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py index e305d16400cb..d87504d255d6 100644 --- a/torch/distributed/elastic/utils/logging.py +++ b/torch/distributed/elastic/utils/logging.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 080e92eae91e..6d2e1f046502 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index d4946513098c..cb2ba03777d8 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torchvision diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index c1d77bf410b5..aae2405d0bb5 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file includes private common utilities for FSDP.
""" diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index a41a817724e5..523330e5580d 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import time from collections import defaultdict diff --git a/torch/distributed/fsdp/_dynamo_utils.py b/torch/distributed/fsdp/_dynamo_utils.py index 3a6c63dc5af8..e58c91a5807b 100644 --- a/torch/distributed/fsdp/_dynamo_utils.py +++ b/torch/distributed/fsdp/_dynamo_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Set import torch.nn as nn diff --git a/torch/distributed/fsdp/_exec_order_utils.py b/torch/distributed/fsdp/_exec_order_utils.py index 3ba2a43c0596..ad5fdc1fde5f 100644 --- a/torch/distributed/fsdp/_exec_order_utils.py +++ b/torch/distributed/fsdp/_exec_order_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import warnings from enum import auto, Enum diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index ed141465155c..816b91433063 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import functools import logging @@ -1140,7 +1141,7 @@ def shard_metadata( tuple(fqns_list), tuple(shapes_list), tuple(numels_list), - shard_param_offsets, + tuple(shard_param_offsets), ) @no_type_check diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 5b811f50a032..c8b58091bf89 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import itertools import os @@ -166,8 +167,7 @@ def _init_process_group_state_for_hybrid_shard( state.process_group = device_mesh.get_group(mesh_dim=1) else: raise ValueError( - "Expected device_mesh to have ndim=2 " - f"but got {len(device_mesh.get_group())}" + f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}" ) elif process_group is None: default_group = _get_default_group() @@ -1099,25 +1099,6 @@ def _sync_module_params_and_buffers( ) -def _sync_module_states( - params: List[nn.Parameter], - buffers: List[torch.Tensor], - process_group: dist.ProcessGroup, -) -> None: - # Assumes that each call to this method passes in disjoint `params` and - # and `buffers` across calls, so there is no chance of re-synchronizing - params_and_buffers = [param.detach() for param in params] + [ - buffer.detach() for buffer in buffers - ] - _check_module_states_for_sync_module_states(params_and_buffers) - _sync_params_and_buffers( - process_group, - params_and_buffers, - PARAM_BROADCAST_BUCKET_SIZE, - src=0, - ) - - def _check_module_states_for_sync_module_states( module_states: List[torch.Tensor], ) -> None: diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index b066f930ebaf..d4aa344c1114 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import logging diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index f1e579adae00..833c1d45697a 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging from enum import auto, Enum diff --git 
a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index 8af94b78209b..da243e6aa130 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import itertools import math diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 9489994a3bb4..797a0116587b 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import logging import math diff --git a/torch/distributed/fsdp/_trace_utils.py b/torch/distributed/fsdp/_trace_utils.py index c768b73b8f95..49039e337ea2 100644 --- a/torch/distributed/fsdp/_trace_utils.py +++ b/torch/distributed/fsdp/_trace_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from contextlib import contextmanager from dataclasses import dataclass, field diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 7700d631d73e..435193a88703 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings from typing import cast, Generator diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 16f521f65b8d..84cdf250d8ae 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import inspect diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index fdb72ce0b219..9edd057a8f37 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -686,6 +686,15 @@ def set_state_dict_type( A StateDictSettings that include the previous state_dict type and configuration for the module. """ + warnings.warn( + "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " + "deprecated. Please use the get_state_dict() and set_state_dict() APIs, " + "which support different parallelisms (FSDP1, FSDP2, DDP). " + "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" + "#torch.distributed.checkpoint.state_dict.get_state_dict. " + "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html.", + FutureWarning, + ) _state_dict_type_to_config = { StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, @@ -1198,12 +1207,13 @@ def clip_grad_norm_( return total_norm.to(total_norm_dtype) @staticmethod - def _warn_optim_input(optim_input): + def _warn_optim_input(optim_input, *, stacklevel: int = 1): if optim_input is not None: warnings.warn( "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " "You may remove it from your code without changing its functionality.", FutureWarning, + stacklevel=stacklevel + 1, ) @staticmethod @@ -1218,12 +1228,13 @@ def _is_using_optim_input(optim_input, optim) -> bool: return False @staticmethod - def _warn_legacy_optim_state_dict(curr: str, new: str): + def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1): warnings.warn( f"``FullyShardedDataParallel.{curr}`` is being deprecated and is " f"replaced by ``FullyShardedDataParallel.{new}``. 
" f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", FutureWarning, + stacklevel=stacklevel + 1, ) @staticmethod @@ -1241,6 +1252,8 @@ def _optim_state_dict_impl( full_state_dict: bool = True, group: Optional[dist.ProcessGroup] = None, cpu_offload: bool = True, + *, + _stacklevel: int = 1, ) -> Dict[str, Any]: """Transform the state-dict of an optimizer corresponding to a sharded model. @@ -1249,7 +1262,9 @@ def _optim_state_dict_impl( FSDP internal information and internal sharding from the optim_state_dict. """ if full_state_dict: - FullyShardedDataParallel._warn_optim_input(optim_input) + FullyShardedDataParallel._warn_optim_input( + optim_input, stacklevel=_stacklevel + 1 + ) using_optim_input = FullyShardedDataParallel._is_using_optim_input( optim_input, optim, @@ -1400,7 +1415,9 @@ def full_optim_state_dict( then nonzero ranks return an empty :class:`dict`. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "full_optim_state_dict", "optim_state_dict" + "full_optim_state_dict", + "optim_state_dict", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_impl( model=model, @@ -1410,6 +1427,7 @@ def full_optim_state_dict( rank0_only=rank0_only, group=group, full_state_dict=True, + _stacklevel=2, ) @staticmethod @@ -1431,7 +1449,9 @@ def sharded_optim_state_dict( cannot be directly used by the regular ``optim.load_state_dict``. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "sharded_optim_state_dict", "optim_state_dict" + "sharded_optim_state_dict", + "optim_state_dict", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_impl( model=model, @@ -1441,6 +1461,7 @@ def sharded_optim_state_dict( rank0_only=False, full_state_dict=False, group=group, + _stacklevel=2, ) @staticmethod @@ -1509,7 +1530,9 @@ def shard_full_optim_state_dict( restricted to only include this rank's part of the optimizer state. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "shard_full_optim_state_dict", "optim_state_dict_to_load" + "shard_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=full_optim_state_dict, @@ -1546,7 +1569,9 @@ def flatten_sharded_optim_state_dict( Refer to :meth:`shard_full_optim_state_dict`. """ FullyShardedDataParallel._warn_legacy_optim_state_dict( - "flatten_sharded_optim_state_dict", "optim_state_dict_to_load" + "flatten_sharded_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=sharded_optim_state_dict, @@ -1626,7 +1651,9 @@ def scatter_full_optim_state_dict( restricted to only include this rank's part of the optimizer state. 
""" FullyShardedDataParallel._warn_legacy_optim_state_dict( - "scatter_full_optim_state_dict", "optim_state_dict_to_load" + "scatter_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, ) return FullyShardedDataParallel._optim_state_dict_to_load_impl( optim_state_dict=full_optim_state_dict, @@ -1857,6 +1884,7 @@ def optim_state_dict( cpu_offload=getattr( state_dict_settings.optim_state_dict_config, "offload_to_cpu", True ), + _stacklevel=2, ) @staticmethod diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 47bfe041cdc2..3487e01263c7 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from collections import abc, defaultdict from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index 90796269de46..acb5a6f1f642 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the BSD license found in the diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 3efb0c3cf31d..a9e35c36db7f 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" Module ``torch.distributed.launch``. diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 20de0a032713..937647f77828 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. 
diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 16e38b32712d..de8a15dd65da 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import collections import io import sys diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index 857d090dedbe..e90a78a69324 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist from torch.autograd import Function diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 24f53c4f1a60..d529fc740945 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import importlib import logging import os diff --git a/torch/distributed/nn/jit/templates/remote_module_template.py b/torch/distributed/nn/jit/templates/remote_module_template.py index ac731b434243..07b055774b36 100644 --- a/torch/distributed/nn/jit/templates/remote_module_template.py +++ b/torch/distributed/nn/jit/templates/remote_module_template.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool): diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index e3e44d4667ae..bc5f7c63dd17 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index dfd50db17591..93a1fe2b2240 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 5335df17e089..34868d23d8a5 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index f3acd4d271ef..32bce65dfe1f 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 40aabafb0ca7..43addd050822 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index fc4d7750973c..851119c8600c 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git 
a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 6018ce943b40..60742bc68896 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Tuple import torch diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 4a807a605571..3a8176e87705 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional import torch diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 28edbe39d80e..9e1e5377873d 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import warnings diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index 8246c667509d..f2eca606c026 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from collections import defaultdict diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index f1717685966a..db65856e32ad 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index 5fb197e2d1dd..af2220ca5574 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Type from torch import optim diff --git a/torch/distributed/optim/zero_redundancy_optimizer.pyi b/torch/distributed/optim/zero_redundancy_optimizer.pyi index c341e00e3ee3..21f3cc5e3fc2 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.pyi +++ b/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import enum from typing import Any, Callable, Dict, List, Optional, overload, Set, Type diff --git a/torch/distributed/pipeline/__init__.py b/torch/distributed/pipeline/__init__.py deleted file mode 100644 index eacd2bc99d04..000000000000 --- a/torch/distributed/pipeline/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import warnings - - -with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "`torch.distributed.pipeline` is deprecated. For up-to-date pipeline parallel " - "implementation, please refer to the PiPPy library under the PyTorch " - "organization (Pipeline Parallelism for PyTorch): " - "https://github.com/pytorch/PiPPy", - DeprecationWarning, - stacklevel=2, - ) diff --git a/torch/distributed/pipeline/sync/LICENSE b/torch/distributed/pipeline/sync/LICENSE deleted file mode 100644 index e52be240fdc9..000000000000 --- a/torch/distributed/pipeline/sync/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright 2019-2020 Kakao Brain - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. 
Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/pipeline/sync/__init__.py b/torch/distributed/pipeline/sync/__init__.py deleted file mode 100644 index 75a80c5db0f9..000000000000 --- a/torch/distributed/pipeline/sync/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A Pipe implementation in PyTorch.""" -from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe, WithDevice -from .microbatch import NoChunk - -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/pipeline/sync/_balance/__init__.py b/torch/distributed/pipeline/sync/_balance/__init__.py deleted file mode 100644 index 8ffc657896d8..000000000000 --- a/torch/distributed/pipeline/sync/_balance/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""A helper to roughly balance a sequential module. - -Usage:: - - import torch - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - - pipe = Pipe(model, balance, chunks=8) - -""" -from typing import Any, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from . import blockpartition -from .profile import profile_sizes, profile_times - -__all__ = ["balance_by_time", "balance_by_size"] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def balance_cost(cost: List[int], partitions: int) -> List[int]: - partitioned = blockpartition.solve(cost, partitions) - return [len(p) for p in partitioned] - - -def balance_by_time( - partitions: int, - module: nn.Sequential, - sample: Union[List[Any], Tensor], - *, - timeout: float = 1.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by elapsed time per layer. 
- :: - - sample = torch.empty(128, 3, 224, 224) - balance = balance_by_time(torch.cuda.device_count(), model, sample) - pipe = Pipe(model, balance, chunks=8) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - sample (torch.Tensor): - example input with arbitrary batch size - - Keyword Args: - timeout (float): - profiling iterates again if the timeout (in second) is not exceeded - (default: ``1.0``) - device ('cpu' or 'cuda' device): - CPU or CUDA device where each layer is profiled (default: the - current CUDA device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `sample` must be placed on the same device. - - """ - times = profile_times(module, sample, timeout, torch.device(device)) - return balance_cost(times, partitions) - - -def balance_by_size( - partitions: int, - module: nn.Sequential, - input: Union[List[Any], Tensor], - *, - chunks: int = 1, - param_scale: float = 2.0, - device: Device = torch.device("cuda"), -) -> List[int]: - """Naive automatic balancing by CUDA memory usage per layer. - - During training, required memory for parameters depends on which optimizer - is used. Optimizers may use buffers for each parameter to track - optimization statistics internally, such as momentum buffer in SGD. - - To get more reliable size based balance, you should specify `param_scale` - with regard to your optimizer. The default `param_scale` is 2 instead of 1 - due to gradient accumulation which is necessary for every optimizer. - - Follow this guide to choose correct `param_scale` for typical optimizers: - - ========= ============= ========================================= - Optimizer `param_scale` Internal State - ========= ============= ========================================= - SGD 2--3 (momentum_buffer) - Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) - Adadelta 4 square_avg, acc_delta - Adagrad 3 sum - RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) - ========= ============= ========================================= - - Here's a simple example with the Adam optimizer:: - - balance = balance_by_size( - torch.cuda.device_count(), - model, - - # Same size with mini-batch to train - torch.empty(1024, 3, 224, 224), - - # Number of micro-batches to train with Pipe - chunks=8, - - # 4 for Adam - param_scale=4.0, - ) - - pipe = Pipe(model, balance, chunks=8) - adam = Adam(pipe.parameters()) - - Args: - partitions (int): - intended number of partitions - module (torch.nn.Sequential): - sequential module to be partitioned - input (torch.Tensor): - example mini-batch with the same size to train - - Keyword Args: - chunks (int): - number of micro-batches will be used to train (default: ``1``) - param_scale (float): - how many copies of parameters would be allocated for training. It - depends on optimizer. See the above guide. (default: ``2.0``) - device ('cuda' device): - CUDA device where each layer is profiled (default: the current CUDA - device) - - Returns: - A list of number of layers in each partition. Use it for the `balance` - parameter of :class:`~torchpipe.Pipe`. - - .. note:: - `module` and `input` must be placed on the same CUDA device. 
- - """ - sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) - return balance_cost(sizes, partitions) diff --git a/torch/distributed/pipeline/sync/_balance/blockpartition.py b/torch/distributed/pipeline/sync/_balance/blockpartition.py deleted file mode 100644 index ccdf5fe4df99..000000000000 --- a/torch/distributed/pipeline/sync/_balance/blockpartition.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Implements "Block Partitions of Sequences" by Imre B\u00e1r\u00e1ny et al. - -Paper: https://arxiv.org/pdf/1308.2452.pdf - -""" -from typing import Iterator, List, Tuple - -__all__ = ["solve"] - - -def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: - """Splits a sequence into several partitions to minimize variance for each - partition. - - The result might not be optimal. However, it can be done only in O(kn\u00b3), - where k is the number of partitions and n is the length of the sequence. - - """ - if partitions < 1: - raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") - - n = len(sequence) - if n < partitions: - raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") - - # Normalize the sequence in [0, 1]. - minimum = min(sequence) - maximum = max(sequence) - minimum - - normal_sequence: List[float] - if maximum == 0: - normal_sequence = [0 for _ in sequence] - else: - normal_sequence = [(x - minimum) / maximum for x in sequence] - - splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] - - def block_size(i: int) -> float: - start = splits[i - 1] if i > 0 else 0 - stop = splits[i] - return sum(normal_sequence[start:stop]) - - def leaderboard() -> Iterator[Tuple[float, int]]: - return ((block_size(i), i) for i in range(partitions)) - - while True: - """ - (1) Fix p element-of [k] with M(P) = bp. So Bp is a maximal block of P. - """ - # max_size: M(P) - max_size, p = max(leaderboard()) - - while True: - """ - (2) If M(P) <= m(P) + 1, then stop. - """ - # min_size: m(P) - min_size, q = min(leaderboard()) - - if max_size <= min_size + 1: - return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] - - """ - (3) If M(P) > m(P) + 1, then let m(P) = bq for the q element-of [k] which is - closest to p (ties broken arbitrarily). Thus Bq is a minimal block - of P. Let Bh be the block next to Bq between Bp and Bq. (Note that - Bh is a non-empty block: if it were, then m(P) = 0 and we should - have chosen Bh instead of Bq.) - """ - if p < q: - """ - So either p < q and then h = q-1 and we define P * by moving - the last element from Bh = Bq-1 to Bq, - """ - h = q - 1 - splits[h] -= 1 - else: - """ - or q < p, and then h = q + 1 and P * is obtained by moving the - first element of Bh = Bq+1 to Bq. - """ - h = q + 1 - splits[q] += 1 - - """ - Set P = P * . If p = h, then go to (1), else go to (2). - """ - if p == h: - break diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py deleted file mode 100644 index fa1a0c06a8e3..000000000000 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Per-layer profilers.""" -import copy -import time -from typing import Any, Generator, List, Union, Sequence - -import torch -from torch import Tensor -import torch.nn as nn - -from ..microbatch import Batch - -__all__: List[str] = [] - - -Device = Union[torch.device, int, str] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - - -def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: - """Copies layers for ease to profile. It doesn't modify the given - module. - """ - for layer in module: - layer_copy = copy.deepcopy(layer) - layer_copy.to(device) - layer_copy.train() - yield layer_copy - - -def detach(batch: Batch) -> None: - """Detaches from autograd graph.""" - for i, x in enumerate(batch): - batch[i] = x.detach().requires_grad_(x.requires_grad) - - -def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]: - """Profiles elapsed times per layer.""" - if any(p.grad is not None for p in module.parameters()): - raise ValueError("some parameter already has gradient") - - _batch = Batch(sample) - for i, x in enumerate(_batch): - _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) - - time_bufs: List[List[float]] = [[] for _ in module] - begun_at = time.time() - - while time.time() - begun_at < timeout: - batch = _batch - - for i, layer in enumerate(layerwise_sandbox(module, device)): - detach(batch) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tick = time.time() - - # Forward - batch = batch.call(layer) - - # Backward - backward_tensors = tuple(y for y in batch if y.requires_grad) - if backward_tensors: - torch.autograd.backward(backward_tensors, backward_tensors) - - if device.type == "cuda": - torch.cuda.synchronize(device) - tock = time.time() - - time_bufs[i].append(tock - tick) - - us = 1_000_000 - return [sum(int(t * us) for t in buf) for buf in time_bufs] - - -def profile_sizes( - module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device, -) -> List[int]: - """Profiles CUDA memory usage per layer.""" - if device.type != "cuda": - raise ValueError("size profiler supports only CUDA device") - - batch = Batch(input) - sizes: List[int] = [] - - latent_scale = batch[0].size(0) / chunks - for i, x in enumerate(batch): - batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) - - for layer in layerwise_sandbox(module, device): - detach(batch) - - # Detect memory usage at forward. - torch._C._cuda_clearCublasWorkspaces() - memory_before = torch.cuda.memory_allocated(device) - batch = batch.call(layer) - torch._C._cuda_clearCublasWorkspaces() - memory_after = torch.cuda.memory_allocated(device) - latent_size = memory_after - memory_before - - # Analyze size of parameters. - param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) - - # Combine size of parameters and activations with normalize scales. 
- size = latent_size * latent_scale + param_size * param_scale - sizes.append(int(size)) - - return sizes diff --git a/torch/distributed/pipeline/sync/_balance/py.typed b/torch/distributed/pipeline/sync/_balance/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/_balance/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/batchnorm.py b/torch/distributed/pipeline/sync/batchnorm.py deleted file mode 100644 index 868ad50cf3fc..000000000000 --- a/torch/distributed/pipeline/sync/batchnorm.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks the running statistics per mini-batch instead of micro-batch.""" -from typing import TypeVar, Optional, cast - -import torch -from torch import Tensor, nn -from torch.nn.functional import batch_norm -from torch.nn.modules.batchnorm import _BatchNorm - -from .checkpoint import is_recomputing - -__all__ = ["DeferredBatchNorm"] - - -TModule = TypeVar("TModule", bound=nn.Module) - - -class DeferredBatchNorm(_BatchNorm): - """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch.""" - - sum: Tensor - sum_squares: Tensor - running_mean: Tensor - running_var: Tensor - num_batches_tracked: Tensor - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: Optional[float] = 0.1, - affine: bool = True, - chunks: int = 1, - ) -> None: - super().__init__(num_features, eps, momentum, affine, track_running_stats=True) - - self.register_buffer("sum", torch.zeros_like(self.running_mean)) - self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) - - self.counter = 0 - self.tracked = 0 - self.chunks = chunks - - def _check_input_dim(self, input: Tensor) -> None: - # It's the typical _check_input_dim() implementation in PyTorch. - if input.dim() <= 2: - raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) - - def _track(self, input: Tensor) -> bool: - """Tracks statistics of a micro-batch.""" - # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. - dim = [0] - dim.extend(range(2, input.dim())) - - with torch.no_grad(): - self.sum += input.sum(dim) - self.sum_squares += (input ** 2).sum(dim) - - size = input.size().numel() // input.size(1) - self.counter += size - self.tracked += 1 - - return self.tracked == self.chunks - - def _commit(self) -> None: - """Update the running statistics of a mini-batch.""" - exponential_average_factor = 0.0 - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum - - mean = self.sum / self.counter - var = self.sum_squares / self.counter - mean ** 2 - - # Calculate the exponential moving average here. 
- m = exponential_average_factor - - self.running_mean *= 1 - m - self.running_mean += mean * m - - self.running_var *= 1 - m - self.running_var += var * m - - self.sum.zero_() - self.sum_squares.zero_() - self.counter = 0 - self.tracked = 0 - - def forward(self, input: Tensor) -> Tensor: - if not self.training: - # Don't train parameters on the evaluation mode. - return batch_norm( - input, - running_mean=self.running_mean, - running_var=self.running_var, - weight=self.weight, - bias=self.bias, - training=False, - momentum=0.0, - eps=self.eps, - ) - - if not is_recomputing(): - # Track a micro-batch on the training mode - # but not under a recomputation. - tracked_enough = self._track(input) - - # Update the running statistics for a mini-batch - # if it has tracked enough micro-batches. - if tracked_enough: - self._commit() - - # Normalize a micro-batch and train the parameters. - return batch_norm( - input, - running_mean=None, - running_var=None, - weight=self.weight, - bias=self.bias, - training=True, - momentum=0.0, - eps=self.eps, - ) - - @classmethod - def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: - """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: - - from torchvision.models.resnet import resnet101 - from torchpipe.batchnorm import DeferredBatchNorm - model = resnet101() - model = DeferredBatchNorm.convert_deferred_batch_norm(model) - - """ - if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: - return cast(TModule, module) - - module_output: nn.Module = module - - if isinstance(module, _BatchNorm) and module.track_running_stats: - module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) - if module.affine: - module_output.register_parameter("weight", module.weight) - module_output.register_parameter("bias", module.bias) - module_output.register_buffer("running_mean", module.running_mean) - module_output.register_buffer("running_var", module.running_var) - module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) - - for name, child in module.named_children(): - module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) - - return cast(TModule, module_output) diff --git a/torch/distributed/pipeline/sync/checkpoint.py b/torch/distributed/pipeline/sync/checkpoint.py deleted file mode 100644 index e67da2499d57..000000000000 --- a/torch/distributed/pipeline/sync/checkpoint.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Checkpointing with preceding recomputation. - -PyTorch already provides the official checkpointing utilities in -:mod:`torch.utils.checkpoint`. The official checkpointing combines -recomputation and recursive backpropagation into one autograd function named -``CheckpointFunction``. Hence, the recomputation can be started only when the -gradients arrive to the function. In Pipe, the recomputation needs to precede -the gradient arrival to minimize the GPU idle time. - -We solve this problem by introducing separate autograd functions named -:class:`Recompute` and :class:`Checkpoint`. Each function represents -recomputation and recursive backpropagation, respectively. 
We can manipulate -the control flow in aspect of both the autograd engine and CUDA with a pair of -the functions. - -Specifically, we place CUDA stream synchronization between :class:`Recompute` -and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is -copied entirely. - -""" -from collections import deque -from contextlib import contextmanager -import threading -from typing import ( - Any, - Deque, - Generator, - List, - Optional, - Protocol, - Union, - Sequence, - Tuple -) - -import torch -from torch import Tensor -import torch.autograd - -from .dependency import fork, join -from .microbatch import Batch -from .phony import get_phony - -__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing", - "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states", - "restore_rng_states", "Checkpoint", "Recompute"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -# Types for shared memory between Checkpoint and Recompute. -Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) -RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state) - - -# Protocol with __call__ instead of Callable can be used as an attribute type. -# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 -class Function(Protocol): - def __call__(self, input: TensorOrTensors) -> TensorOrTensors: - ... - - -def checkpoint(function: Function, input): - """Make a checkpoint with a simple interface like - :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug - :class:`Checkpoint` and :class:`Recompute` without boilerplate. - """ - batch = Batch(input) - - chk = Checkpointing(function, batch) - batch = chk.checkpoint() - chk.recompute(batch) - - return batch.values - - -class Checkpointing: - """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" - - def __init__(self, function: Function, batch: Batch) -> None: - self.function = function - self.batch = batch - - # Shared memory between Checkpoint and Recompute. 1-length deque is - # used for mutability and length limitation. - self.recomputed: Deque[Recomputed] = deque(maxlen=1) - self.rng_states: Deque[RNGStates] = deque(maxlen=1) - - def checkpoint(self) -> Batch: - """Return a batch applied by :class:`Checkpoint`.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a phony which requires grad to ensure that Checkpoint can be - # tracked by the autograd engine even when none of the input tensors - # require grad. - phony = get_phony(self.batch.get_device(), requires_grad=True) - - output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - - # Gradients are only supported for float Tensors. - if isinstance(output, tuple): - output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output]) - - return Batch(output) - - def recompute(self, batch: Batch) -> None: - """Apply :class:`Recompute` to the batch in place.""" - input_atomic = self.batch.atomic - inputs = tuple(self.batch) - - # Use a tensor in the batch to tie together fork-join - tensor_idx = batch.find_tensor_idx() - # batch[tensor_idx] is always requiring grad, because it has been passed - # checkpoint with a phony requiring grad. 
- batch[tensor_idx], phony = fork(batch[tensor_idx]) - phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs) - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.is_checkpointing = False - self.is_recomputing = False - - -thread_local = ThreadLocal() - - -@contextmanager -def enable_checkpointing() -> Generator[None, None, None]: - """Make :func:`is_checkpointing` return :data:`True` within a context.""" - orig = thread_local.is_checkpointing - thread_local.is_checkpointing = True - try: - yield - finally: - thread_local.is_checkpointing = orig - - -@contextmanager -def enable_recomputing() -> Generator[None, None, None]: - """Makes :func:`is_recomputing` return :data:`True` within a context.""" - orig = thread_local.is_recomputing - thread_local.is_recomputing = True - try: - yield - finally: - thread_local.is_recomputing = orig - - -def is_checkpointing() -> bool: - """Whether the current forward propagation is under checkpointing. - - Returns: - bool: :data:`True` if it's under checkpointing. - - """ - return thread_local.is_checkpointing - - -def is_recomputing() -> bool: - """Whether the current forward propagation is under checkpoint recomputation. - - Use this to prevent duplicated side-effects at forward - propagation:: - - class Counter(nn.Module): - def __init__(self): - super().__init__() - self.counter = 0 - - def forward(self, input): - if not is_recomputing(): - self.counter += 1 - return input - - Returns: - bool: :data:`True` if it's under checkpoint recomputation. - - .. seealso:: :ref:`Detecting Recomputation` - - """ - return thread_local.is_recomputing - - -class Context: - """The common interface between the :class:`Checkpoint` and :class:`Recompute` context.""" - - recomputed: Deque[Recomputed] - rng_states: Deque[RNGStates] - function: Function - input_atomic: bool - inputs: Sequence[Any] - - saved_tensors: Tuple[Tensor, ...] - - def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover - pass - - -def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: - """: - Capture the current random number generator states. - - meth:`Checkpoint.forward` captures the current PyTorch's random number - generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. - - .. seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state = torch.get_rng_state() - - gpu_rng_state: Optional[Tensor] - if device.type == "cuda": - gpu_rng_state = torch.cuda.get_rng_state(device) - else: - gpu_rng_state = None - - rng_states.append((cpu_rng_state, gpu_rng_state)) - - -@contextmanager -def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: - """: - Restore the random number generator state. - - meth:`Recompute.backward` restores the random number generator states - captured by :func:`save_rng_states` within its context. - - .. 
seealso:: :ref:`Referential Transparency` - - """ - cpu_rng_state, gpu_rng_state = rng_states.pop() - - gpu_devices: List[torch.device] = [] - if device.type == "cuda": - gpu_devices.append(device) - - with torch.random.fork_rng(gpu_devices): - torch.set_rng_state(cpu_rng_state) - if gpu_rng_state is not None: - torch.cuda.set_rng_state(gpu_rng_state, device) - yield - - -class Checkpoint(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ): - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - save_rng_states(phony.device, ctx.rng_states) - - ctx.function = function - ctx.input_atomic = input_atomic - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - - ctx.save_for_backward(*tensors) - - with torch.no_grad(), enable_checkpointing(): - if input_atomic: - assert len(inputs) == 1 - output = function(inputs[0]) - else: - output = function(*inputs) - return output - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover - output, input_leaf = ctx.recomputed.pop() - - if isinstance(output, tuple): - outputs = output - else: - outputs = (output,) - if any(torch.is_tensor(y) and y.requires_grad for y in outputs): - tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad]) - torch.autograd.backward(tensors, grad_output) - - grad_input: List[Optional[Tensor]] = [None, None, None, None, None] - grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf) - return tuple(grad_input) - - -class Recompute(torch.autograd.Function): - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - phony: Tensor, - recomputed: Deque[Recomputed], - rng_states: Deque[RNGStates], - function: Function, - input_atomic: bool, - *inputs, - ) -> Tensor: - ctx.recomputed = recomputed - ctx.rng_states = rng_states - - ctx.function = function - ctx.input_atomic = input_atomic - ctx.inputs = inputs - if input_atomic: - tensors = [inputs[0]] - else: - tensors = [] - for input in inputs: - if torch.is_tensor(input): - tensors.append(input) - ctx.save_for_backward(*tensors) - - return phony - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover - inputs = ctx.inputs - inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs) - - # Get the device for the inputs from a tensor - device = None - for input in inputs: - if torch.is_tensor(input): - device = input.device - break - - if device is None: - raise RuntimeError(f'No tensors found in {inputs}') - - with restore_rng_states(device, ctx.rng_states): - with torch.enable_grad(), enable_recomputing(): - if ctx.input_atomic: - assert len(inputs_leaf) == 1 - output = ctx.function(inputs_leaf[0]) - else: - output = ctx.function(*inputs_leaf) - - ctx.recomputed.append((output, inputs_leaf)) - - grad_input: List[None] = [None, None, None, None, None] - grad_input.extend(None for _ in ctx.inputs) - return tuple(grad_input) diff --git a/torch/distributed/pipeline/sync/copy.py b/torch/distributed/pipeline/sync/copy.py deleted file mode 100644 index b717f0c2932c..000000000000 --- a/torch/distributed/pipeline/sync/copy.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2019 Kakao Brain 
-# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Autograd functions for stream-aware CUDA copy. - -It is used to overlap copy and computation on the same GPU. -""" -from collections import deque -from typing import Deque, List, Optional, Tuple, Sequence - -import torch -from torch import Tensor - -from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream - -__all__: List[str] = ["Context", "Copy", "Wait"] - - -Tensors = Sequence[Tensor] - - -# Common interface between :class:`Copy` and :class:`Wait`. -class Context: - prev_stream: AbstractStream - next_stream: AbstractStream - - -class Copy(torch.autograd.Function): - """Copies tensors on specific streams.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - output = [] - output_stream = current_stream(get_device(next_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in input: - if torch.is_tensor(x): - y = x.to(get_device(next_stream), non_blocking=True) - output.append(y) - - # 'prev_stream' is not where 'x' has been allocated. - record_stream(x, prev_stream) - # 'y' has been allocated on 'next_stream'. - # It might be used on the current stream captured as 'output_stream'. - record_stream(y, output_stream) - else: - output.append(x) - - return tuple(output) - - @staticmethod - def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) - input_stream = current_stream(get_device(prev_stream)) - - with use_stream(prev_stream), use_stream(next_stream): - for x in reversed(grad_output): - y = x.to(get_device(prev_stream), non_blocking=True) - grad_input.appendleft(y) - - # 'next_stream' is not where 'x' has been allocated. - record_stream(x, next_stream) - # 'y' has been allocated on 'prev_stream'. - # It might be used on the current stream captured as 'input_stream'. - record_stream(y, input_stream) - - grad_streams: Tuple[Optional[Tensor], ...] = (None, None) - return grad_streams + tuple(grad_input) - - -class Wait(torch.autograd.Function): - """Synchronizes a stream to another stream. - - Place it just before you want to start an operation on the next stream, - provided that all operations on the previous stream are done. - - """ - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors: - ctx.prev_stream = prev_stream - ctx.next_stream = next_stream - - wait_stream(next_stream, prev_stream) - - return tuple(x.detach() if torch.is_tensor(x) else x for x in input) - - @staticmethod - def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: - prev_stream = ctx.prev_stream - next_stream = ctx.next_stream - - wait_stream(prev_stream, next_stream) - - grad_streams: Tuple[Optional[Tensor], ...] 
= (None, None) - return grad_streams + grad_input diff --git a/torch/distributed/pipeline/sync/dependency.py b/torch/distributed/pipeline/sync/dependency.py deleted file mode 100644 index ca5c69e388fe..000000000000 --- a/torch/distributed/pipeline/sync/dependency.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Arbitrary dependency between two autograd lanes.""" -from typing import List, Tuple - -import torch -from torch import Tensor - -from .phony import get_phony - -__all__: List[str] = ["fork", "Fork", "join", "Join"] - - -def fork(input: Tensor) -> Tuple[Tensor, Tensor]: - """Branches out from an autograd lane of the given tensor.""" - if torch.is_grad_enabled() and input.requires_grad: - input, phony = Fork.apply(input) - else: - phony = get_phony(input.device, requires_grad=False) - - return input, phony - - -class Fork(torch.autograd.Function): - @staticmethod - def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override] - phony = get_phony(input.device, requires_grad=False) - return input.detach(), phony.detach() - - @staticmethod - def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override] - return grad_input - - -def join(input: Tensor, phony: Tensor) -> Tensor: - """Merge two autograd lanes.""" - if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): - input = Join.apply(input, phony) - - return input - - -class Join(torch.autograd.Function): - @staticmethod - def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override] - return input.detach() - - @staticmethod - def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] - return grad_input, None diff --git a/torch/distributed/pipeline/sync/microbatch.py b/torch/distributed/pipeline/sync/microbatch.py deleted file mode 100644 index 5b8aca257548..000000000000 --- a/torch/distributed/pipeline/sync/microbatch.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Manipulation of micro-batches.""" -import typing -from typing import Any, Callable, List, Union, cast, Sequence - -import torch -from torch import Tensor -import torch.cuda.comm - -__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] -Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]] - - -class NoChunk: - """ - Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor - should not be chunked on the batch dimension and instead be replicated - as-is across all micro-batches. This is useful for tensors which might - not have any 'batch' semantics for the model. - """ - def __init__(self, inp: Tensor): - if not torch.is_tensor(inp): - raise TypeError(f'NoChunk only supported for tensors, found: {inp}') - self._tensor = inp - - @property - def tensor(self): - return self._tensor - - -class Batch: - """ - An abstraction representing a microbatch in the pipeline. 
- """ - - def __init__(self, values: Union[List[Any], Tensor]) -> None: - self._values = values - self.atomic = torch.is_tensor(values) - - # Verify at least on tensor - if not self.atomic: - if not any(torch.is_tensor(value) for value in self._values): - raise TypeError(f'No tensors found in batch: {self._values}') - - @property - def tensor(self) -> Tensor: - """Retrieves the underlying tensor.""" - if not self.atomic: - raise AttributeError("not atomic batch") - return cast(Tensor, self._values) - - @property - def values(self): - """Retrieves the underlying values for the batch""" - return self._values - - def find_tensor_idx(self): - """ - Retrieves the index of first tensor found. - """ - if self.atomic: - return 0 - for i, value in enumerate(self._values): - if torch.is_tensor(value): - return i - - raise TypeError("No tensor found!") - - def get_device(self): - """ - Retrieves the device for this microbatch. - """ - if self.atomic: - return self._values.device # type: ignore[union-attr] - - for value in self._values: - if torch.is_tensor(value): - return value.device - - def call(self, function: Function) -> "Batch": - """Calls a function on the microbatch. It also wraps - the output with :class:`Batch`. - """ - if self.atomic: - return Batch(function(self._values)) - else: - return Batch(function(*self._values)) - - def __repr__(self) -> str: - return f"Batch[atomic={self.atomic!r}]({self._values!r})" - - def __iter__(self): - if self.atomic: - yield self._values - else: - yield from self._values - - def __len__(self) -> int: - return 1 if self.atomic else len(self._values) - - def __getitem__(self, index: int): - if not self.atomic: - return self._values[index] - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - return self._values - - # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". - @typing.overload - def __setitem__(self, index: int, value: Tensor) -> None: - ... - - @typing.overload - def __setitem__(self, index: slice, value: Tensors) -> None: - ... - - def __setitem__(self, index: Union[int, slice], value) -> None: - if isinstance(index, int): - self._setitem_by_index(index, value) - else: - self._setitem_by_slice(index, value) - - def _setitem_by_index(self, index: int, value) -> None: - if not self.atomic: - i = index - self._values = self._values[:i] + (value,) + self._values[i + 1 :] # type: ignore[operator] - return - - if index != 0: - raise IndexError("atomic batch allows index 0 only") - - self._values = value - - def _setitem_by_slice(self, index: slice, value) -> None: - if not (index.start is index.stop is index.step is None): # noqa: E714 - raise NotImplementedError("only slice [:] supported") - - if not self.atomic: - self._values = value - return - - if len(value) != 1: - raise IndexError("atomic batch cannot be replaced with multiple tensors") - - self._values = value[0] - - -def check(first_device, *inputs) -> None: - """ - Checks whether the input contains at least one tensor and each tensor is - on the same device as the first partition. 
- - Raises: - ValueError: input does not contain at least one tensor - - """ - - if not any(torch.is_tensor(input) for input in inputs): - raise TypeError(f'inputs do not have any tensors: {inputs}') - if any(torch.is_tensor(input) and input.device != first_device for input in inputs): - raise ValueError('All inputs should be on the same device as the first partition') - - -def scatter(*inputs, chunks: int) -> List[Batch]: - """Splits an input mini-batch into multiple micro-batches.""" - if len(inputs) == 1 and isinstance(inputs[0], Tensor): - return [Batch(x) for x in inputs[0].chunk(chunks)] - - batches: List[Any] = [[] for _ in range(chunks)] - # Actual number of chunks produced - num_chunks = -1 - for input in inputs: - if torch.is_tensor(input): - # Chunk only tensors. - tensors = input.chunk(chunks) - - # Validate number of chunks equal across all inputs. - if num_chunks != -1 and num_chunks != len(tensors): - raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') - num_chunks = len(tensors) - - for i, tensor in enumerate(tensors): - batches[i].append(tensor) - else: - # Replicate non-tensors or tensors wrapped with 'NoChunk'. - for i in range(chunks): - if isinstance(input, NoChunk): - # Extract the tensor out. - batches[i].append(input.tensor) - else: - batches[i].append(input) - - # Truncate to actual number of chunks - batches = batches[:num_chunks] - - return [Batch(x) for x in batches] - - -def gather(outputs: List[Batch]): - """Concatenates output micro-batches into a mini-batch.""" - output: Any - - if outputs[0].atomic: - tensors = tuple(b.tensor for b in outputs) - output = torch.cat(tensors) - else: - output_buf: List[Any] = [] - for i in range(len(outputs[0])): - output_type = type(outputs[0][i]) - current_outputs = [] - for batch in outputs: - if output_type != type(batch[i]): - raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}') - current_outputs.append(batch[i]) - - if torch.is_tensor(outputs[0][i]): - output_buf.append(torch.cat(current_outputs)) - else: - output_buf.append(current_outputs) - - output = tuple(output_buf) - - return output diff --git a/torch/distributed/pipeline/sync/phony.py b/torch/distributed/pipeline/sync/phony.py deleted file mode 100644 index 012926699cfb..000000000000 --- a/torch/distributed/pipeline/sync/phony.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides phony for arbitrary dependency in a autograd graph.""" -from typing import Dict, List, Tuple - -import torch -from torch import Tensor - -from .stream import default_stream, use_stream - -__all__: List[str] = ["get_phony"] - - -_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} - - -def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: - """Get a phony. Phony is tensor without space. - - It is useful to make arbitrary dependency in a autograd graph because it doesn't require any - gradient accumulation. - - .. note:: - - Phonies for each device are cached. If an autograd function gets a phony - internally, the phony must be detached to be returned. 
Otherwise, the - autograd engine will mutate the cached phony in-place:: - - class Phonify(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - phony = get_phony(input.device, requires_grad=False) - return phony.detach() # detach() is necessary. - - """ - key = (device, requires_grad) - - try: - phony = _phonies[key] - except KeyError: - with use_stream(default_stream(device)): - phony = torch.empty(0, device=device, requires_grad=requires_grad) - - _phonies[key] = phony - - return phony diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py deleted file mode 100644 index 5e61341d9ad9..000000000000 --- a/torch/distributed/pipeline/sync/pipe.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The Pipe interface.""" -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast - -import torch -from torch import Tensor, nn -from torch.distributed.rpc import RRef -import torch.autograd -import torch.cuda - -from . import microbatch -from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline -from .skip.layout import inspect_skip_layout -from .skip.skippable import verify_skippables -from .stream import AbstractStream, new_stream - -__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"] - - -Device = Union[torch.device, int, str] -Devices = Union[Iterable[Device], List[Device]] - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - Module = nn.Module[TensorOrTensors] # type: ignore[type-arg] - NamedModules = OrderedDict[str, Module] -else: - Module = nn.Module - NamedModules = OrderedDict - - -def _recommend_auto_balance(message: str) -> str: - """Expands a message with recommendation to :mod:`torchpipe.balance`.""" - return f"""{message} - -If your model is still under development, its optimal balance would change -frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for -naive automatic balancing: - - from torch.distributed.pipeline.sync import Pipe - from torch.distributed.pipeline.sync.balance import balance_by_time - - partitions = torch.cuda.device_count() - sample = torch.empty(...) - balance = balance_by_time(partitions, model, sample) - - model = Pipe(model, balance, ...) 
-""" - - -def _verify_module(module: nn.Sequential) -> None: - if not isinstance(module, nn.Sequential): - raise TypeError("module must be nn.Sequential to be partitioned") - - named_children = list(module.named_children()) - if len(named_children) != len(module): - raise ValueError("module with duplicate children is not supported") - - -def _verify_splitting( - module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device] -) -> None: - num_parameters = len(list(module.parameters())) - num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) - if num_parameters == num_child_parameters: - return - - for i in range(len(partitions)): - for j in range(i + 1, len(partitions)): - parti = partitions[i] - partj = partitions[j] - if devices[i] == devices[j]: - continue - for p in parti.parameters(): - for q in partj.parameters(): - if p is q: - raise ValueError("module with duplicate parameters on distinct devices is not supported") - - -class BalanceError(ValueError): - pass - - -def _retrieve_device(module: nn.Module) -> torch.device: - """Validates all parameters in the Module have the same device and returns - the appropriate device. - - Args: - An ``nn.Module`` to process. - - Returns: - ``torch.Device`` for the entire module. - - Raises: - ValueError: - If devices for ``nn.Module`` parameters are not all same. - """ - - device = None - for parameter in module.parameters(): - if device is None: - device = parameter.device - elif device != parameter.device: - raise ValueError( - f'nn.Module: {module}, should have all parameters on a single device,' - ' please use .to() to place the module on a single device') - - return device if device is not None else torch.device("cpu") - - -class PipeSequential(nn.Sequential): - """ - Pipe variant of ``nn.Sequential`` which supports multiple inputs. - """ - - def forward(self, *inputs): - for module in self: - if isinstance(inputs, Tuple): # type: ignore[arg-type] - inputs = module(*inputs) - else: - # Don't expand single variables (ex: lists/Tensor) - inputs = module(inputs) - return inputs - - -class WithDevice(nn.Module): - """ - Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe` - that overrides the device for that module. In cases where :class:`Pipe` - can't implicitly determine the device for the module and places it on CPU, - this wrapper can be used to override the implicit behavior and explicitly - specify which device a module should run on. - - The provided module is also moved to the given device via ``.to(device)`` - by :class:`Pipe` - - Args: - module(:class:`torch.nn.Module`): The module to be wrapped. - device(:class:`torch.device`): The device to run the module on. - - Example:: - >>> # xdoctest: +SKIP("distributed") - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> dropout = nn.Dropout() - >>> - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) - >>> # Dropout does not have any parameters/buffers, but we want to - >>> # run it on cuda:1 to avoid any GPU to CPU transfers. 
- >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')) - >>> # xdoctest: +SKIP("Needs RPC framework init") - >>> model = Pipe(model, chunks=8) - """ - def __init__(self, module: nn.Module, device: torch.device): - super().__init__() - self._module = module - self._device = torch.device(device) - - def forward(self, *args, **kwargs): - return self._module(*args, **kwargs) - - @property - def module(self): - return self._module - - @property - def device(self): - return self._device - - -def _assemble_partition(modules: List[nn.Module]): - modules_list: List[nn.Module] = [] - for module in modules: - if isinstance(module, nn.Sequential): - modules_list.extend(module.children()) - else: - modules_list.append(module) - return PipeSequential(*modules_list) - - -def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]: - partitions = [] - devices = [] - - current_partition = [] - current_device = None - for name, module in modules.named_children(): - if isinstance(module, WithDevice): - # Process device override and move module to appropriate device. - device = module.device - module = module.module - module.to(device) - else: - device = _retrieve_device(module) - if current_device is not None and (current_device != device or device.type == 'cpu'): - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - current_partition = [] - current_device = device - current_partition.append(module) - - if current_device is not None: - partitions.append(_assemble_partition(current_partition)) - devices.append(current_device) - - partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) - - return partitions, devices - - -MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement") - - -class Pipe(Module): - """Wraps an arbitrary :class:`nn.Sequential ` module - to train on using synchronous pipeline parallelism. If the module requires - lots of memory and doesn't fit on a single GPU, pipeline parallelism is a - useful technique to employ for training. - - The implementation is based on the torchgpipe_ paper. - - .. _torchgpipe: https://arxiv.org/abs/2004.09910 - - Pipe combines pipeline parallelism with checkpointing to reduce peak - memory required to train while minimizing device under-utilization. - - You should place all the modules on the appropriate devices and wrap them - into an :class:`nn.Sequential ` module defining the - desired order of execution. If a module does not contain any - parameters/buffers, it is assumed this module should be executed on CPU - and appropriate input tensors to the module are moved to CPU before - execution. This behavior can be overridden by the :class:`WithDevice` - wrapper which can be used to explicitly specify which device a module - should run on. - - Args: - module (:class:`nn.Sequential `): - sequential module to be parallelized using pipelining. Each module - in the sequence has to have all of its parameters on a single - device. Each module in the sequence has to either be an nn.Module - or :class:`nn.Sequential ` (to combine multiple - sequential modules on a single device) - chunks (int): - number of micro-batches (default: ``1``) - checkpoint (str): - when to enable checkpointing, one of ``'always'``, - ``'except_last'``, or ``'never'`` (default: ``'except_last'``). 
- ``'never'`` disables checkpointing completely, ``'except_last'`` - enables checkpointing for all micro-batches except the last one - and ``'always'`` enables checkpointing for all micro-batches. - deferred_batch_norm (bool): - whether to use deferred ``BatchNorm`` moving statistics (default: - :data:`False`). If set to :data:`True`, we track statistics across - multiple micro-batches to update the running statistics per - mini-batch. - - Raises: - TypeError: - the module is not a :class:`nn.Sequential `. - ValueError: - invalid arguments - - Example:: - Pipeline of two FC layers across GPUs 0 and 1. - - >>> # Need to initialize RPC framework first. - >>> # xdoctest: +SKIP - >>> os.environ['MASTER_ADDR'] = 'localhost' - >>> os.environ['MASTER_PORT'] = '29500' - >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1) - >>> - >>> # Build pipe. - >>> fc1 = nn.Linear(16, 8).cuda(0) - >>> fc2 = nn.Linear(8, 4).cuda(1) - >>> model = nn.Sequential(fc1, fc2) - >>> model = Pipe(model, chunks=8) - >>> input = torch.rand(16, 16).cuda(0) - >>> output_rref = model(input) - - .. note:: - You can wrap a :class:`Pipe` model with - :class:`torch.nn.parallel.DistributedDataParallel` only when the - checkpoint parameter of :class:`Pipe` is ``'never'``. - - .. note:: - :class:`Pipe` only supports intra-node pipelining currently, but - will be expanded to support inter-node pipelining in the future. - The forward function returns an :class:`~torch.distributed.rpc.RRef` - to allow for inter-node pipelining in the future, where the output - might be on a remote host. For intra-node pipelining you can use - :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the - output locally. - - .. warning:: - :class:`Pipe` is experimental and subject to change. - """ - - def __init__( - self, - module: nn.Sequential, - chunks: int = 1, - checkpoint: str = "except_last", - deferred_batch_norm: bool = False, - ) -> None: - super().__init__() - - # Check if RPC framework is initialized. - if not torch.distributed.rpc._is_current_rpc_agent_set(): - raise RuntimeError( - 'Please initialize RPC framework for Pipe using ' - 'torch.distributed.rpc.init_rpc') - - chunks = int(chunks) - checkpoint = str(checkpoint) - - if chunks <= 0: - raise ValueError("number of chunks must be positive integer") - if checkpoint not in ["always", "except_last", "never"]: - raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") - - _verify_module(module) - - # Verify if the underlying skippable modules satisfy integrity. The - # integrity can be verified before forward() because it is static. - verify_skippables(module) - - self.chunks = chunks - self.checkpoint = checkpoint - - if deferred_batch_norm: - module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) - - self.partitions, self.devices = _split_module(module) - _verify_splitting(module, self.partitions, self.devices) - - self._copy_streams: List[List[AbstractStream]] = [] - self._skip_layout = inspect_skip_layout(self.partitions) - - # Separate CUDA streams for copy. - copy_streams = self._ensure_copy_streams() - - # The micro-batch index where the checkpointing stops. 
- checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) - - def __len__(self) -> int: - """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) - - def __getitem__(self, index: int) -> nn.Module: - """Gets a layer in the underlying sequential module.""" - partitions = self.partitions - if index < 0: - partitions = partitions[::-1] - - for partition in partitions: - try: - return partition[index] - except IndexError: - pass - - shift = len(partition) - - if index < 0: - index += shift - else: - index -= shift - - raise IndexError - - def __iter__(self) -> Iterator[nn.Module]: - """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition - - # Pipe should manage the device of each partition. - # Deny cuda(), cpu(), and to() with device, by TypeError. - def cuda(self, device: Optional[Device] = None) -> "Pipe": - raise MOVING_DENIED - - def cpu(self) -> "Pipe": - raise MOVING_DENIED - - def to(self, *args: Any, **kwargs: Any) -> "Pipe": - # Deny these usages: - # - # - to(device[, dtype, non_blocking]) - # - to(tensor[, non_blocking]) - # - # But allow this: - # - # - to(dtype[, non_blocking]) - # - if "device" in kwargs or "tensor" in kwargs: - raise MOVING_DENIED - - if args: - if isinstance(args[0], (torch.device, int, str)): - raise MOVING_DENIED - if torch.is_tensor(args[0]): - raise MOVING_DENIED - - return super().to(*args, **kwargs) - - def _ensure_copy_streams(self) -> List[List[AbstractStream]]: - """Ensures that :class:`Pipe` caches CUDA streams for copy. - - It's worth to cache CUDA streams although PyTorch already manages a - pool of pre-allocated CUDA streams, because it may reduce GPU memory - fragmentation when the number of micro-batches is small. - - """ - if not self._copy_streams: - for device in self.devices: - self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) - - return self._copy_streams - - def forward(self, *inputs) -> RRef: - """ - Processes a single input mini-batch through the pipe and returns an - :class:`~torch.distributed.rpc.RRef` pointing to the output. - :class:`Pipe` is a fairly transparent module wrapper. It doesn't - modify the input and output signature of the underlying module. But - there's type restriction. Input and output have to contain at least one - tensor. This restriction is applied at partition boundaries too. - - The sequence of inputs are fed into the first stage of the pipeline as - ``*inputs``. As a result the positional args for this function should - match the positional args for the first stage of the pipeline. The same - condition applies for output of one stage of the pipeline which is the - input for the next stage. - - The input tensor is split into multiple micro-batches based on the - ``chunks`` parameter used to initialize :class:`Pipe`. The batch size - is assumed to be the first dimension of the tensor and if the batch - size is less than ``chunks``, the number of micro-batches is equal to - the batch size. - - Only tensors are split into multiple micro-batches, non-Tensor inputs - are just replicated as-is in each micro-batch. For non-Tensor outputs - in the last stage of the pipeline, they are aggregated as a ``List`` - and returned the user. 
For example, if you have 2 micro-batches - returning the integer 5, the user would receive the consolidated - output of `[5, 5]` - - All the input tensors need to be on the same device as the first - partition of the pipeline. - - If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor - is not split across micro-batches and is replicated as-is similar to - non-tensors. - - Args: - inputs: input mini-batch - - Returns: - :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch - - Raises: - TypeError: input doesn't contain at least one tensor - - """ - first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu") - microbatch.check(first_partition_device, *inputs) - - if not self.devices: - # Empty sequential module is not illegal. - return RRef(*inputs) - - # Divide a mini-batch into micro-batches. - batches = microbatch.scatter(*inputs, chunks=self.chunks) - - # Run pipeline parallelism. - self.pipeline.run(batches) - - # Merge the micro-batches into one mini-batch. - output = microbatch.gather(batches) - return RRef(output) diff --git a/torch/distributed/pipeline/sync/pipeline.py b/torch/distributed/pipeline/sync/pipeline.py deleted file mode 100644 index 7cd5e5831169..000000000000 --- a/torch/distributed/pipeline/sync/pipeline.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""The pipeline parallelism of Pipe.""" -from queue import Queue -from types import TracebackType -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence - -import torch -from torch import Tensor, nn -from torch.autograd.profiler import record_function - -from .checkpoint import Checkpointing -from .copy import Copy, Wait -from .dependency import fork, join -from .microbatch import Batch -from .skip.layout import SkipLayout -from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker -from .stream import AbstractStream, current_stream, use_device -from .worker import Task, create_workers - -__all__: List[str] = ["Pipeline"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -def _depend(fork_from: Batch, join_to: Batch) -> None: - fork_from_idx = fork_from.find_tensor_idx() - join_to_idx = join_to.find_tensor_idx() - - fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) - join_to[join_to_idx] = join(join_to[join_to_idx], phony) - - -def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Copy.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. - batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: - batch[:] = Wait.apply(prev_stream, next_stream, *batch) - # Gradients are only supported for float Tensors. 
- batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) - - -def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: - """Generate schedules for each clock cycle.""" - # m: number of micro-batches - # n: number of partitions - # i: index of micro-batch - # j: index of partition - # k: clock number - # - # k (i,j) (i,j) (i,j) - # - ----- ----- ----- - # 0 (0,0) - # 1 (1,0) (0,1) - # 2 (2,0) (1,1) (0,2) - # 3 (2,1) (1,2) - # 4 (2,2) - for k in range(m + n - 1): - yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] - - -class Pipeline: - """The pipeline parallelism for Pipe.""" - - def __init__( - self, - partitions: List[nn.Sequential], - devices: List[torch.device], - copy_streams: List[List[AbstractStream]], - skip_layout: SkipLayout, - checkpoint_stop: int, - ) -> None: - self.partitions = partitions - self.devices = devices - self.copy_streams = copy_streams - self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop - (self.in_queues, self.out_queues) = create_workers(devices) - - def run(self, batches: List[Batch]) -> None: - """Runs pipeline parallelism. - - It modifies the given batches in place. - - """ - partitions = self.partitions - devices = self.devices - skip_layout = self.skip_layout - - m = len(batches) - n = len(partitions) - - skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] - - for schedule in _clock_cycles(m, n): - self.fence(batches, schedule, skip_trackers) - self.compute(batches, schedule, skip_trackers) - - def fence( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Copy micro-batches after computation for the previous micro-batches.""" - copy_streams = self.copy_streams - skip_layout = self.skip_layout - - for i, j in schedule: - # Ensure that batches[i-1] is executed after batches[i] in - # backpropagation by an explicit dependency. - if i != 0 and j != 0: - _depend(batches[i - 1], batches[i]) - - next_stream = copy_streams[j][i] - - for prev_j, ns, name in skip_layout.copy_policy(j): - prev_stream = copy_streams[prev_j][i] - skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) - - if j != 0: - prev_stream = copy_streams[j - 1][i] - _copy(batches[i], prev_stream, next_stream) - - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], - ) -> None: - """Run tasks with synchronization to copy streams.""" - partitions = self.partitions - devices = self.devices - copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 - - n = len(partitions) - streams = [current_stream(d) for d in devices] - exc_info: Optional[ExcInfo] = None - - # With checkpointing, the autograd graph looks like this diagram: - # +-----+------+ - # | Copy | - # +-----+------+ (fence) - # - - - + - - - - - - - - - - # | (compute) - # +-----+------+ - # | Wait | [1] Synchronize the current stream with the copy stream. - # +-----+------+ - # +-----+------+ - # | Checkpoint | [2] Compute a partition within checkpointing. - # +-----+------+ - # +-----+------+ - # | Wait | [3] Synchronize the copy stream with the current stream. - # +-----+------+ - # + - - - + - # | +-----+-----+ - # | | Recompute | [4] Schedule the recomputation at backpropagation. 
- # | +-----+-----+ - # + - - - + - # | - # - - - + - - - - - - - - - - # +-----+------+ (fence) - # | Copy | - # +-----+------+ - for i, j in schedule: - batch = batches[i] - partition = partitions[j] - - # Synchronize with the copied input. ([1] in the diagram) - if j != 0: - _wait(batch, copy_streams[j][i], streams[j]) - - # Determine whether checkpointing or not. - checkpoint = i < checkpoint_stop - if checkpoint: - - def function( - *inputs, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(*inputs) - - chk = Checkpointing(function, batch) # type: ignore[arg-type] - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - del function, chk - - else: - - def compute( - batch: Batch = batch, - partition: nn.Module = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - task = Task(streams[j], compute=compute, finalize=None) - del compute - - # Compute tasks in parallel. ([2] in the diagram) - self.in_queues[j].put(task) - - for i, j in schedule: - ok, payload = self.out_queues[j].get() - - # Hold the first exception. - if exc_info is not None: - continue - elif not ok: - exc_info = cast(ExcInfo, payload) - continue - - task, batch = cast(Tuple[Task, Batch], payload) - - # The copy stream synchronizes to copy the output. ([3] in the - # diagram) - if j != n - 1: - _wait(batch, streams[j], copy_streams[j][i]) - - # Finalize tasks. If checkpointing is enabled, here the - # recomputation is scheduled at backpropagation. ([4] in the - # diagram) - with use_device(devices[j]): - task.finalize(batch) - - batches[i] = batch - - # Fail at the first exception. - if exc_info is not None: - raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/pipeline/sync/py.typed b/torch/distributed/pipeline/sync/py.typed deleted file mode 100644 index ab03724cafbf..000000000000 --- a/torch/distributed/pipeline/sync/py.typed +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/pipeline/sync/skip/__init__.py b/torch/distributed/pipeline/sync/skip/__init__.py deleted file mode 100644 index bdcb913867a7..000000000000 --- a/torch/distributed/pipeline/sync/skip/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
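Stepping back to the ``_clock_cycles`` schedule removed in pipeline.py above, the pattern it generates is easy to see by printing it for a hypothetical 4 micro-batches over 3 partitions::

    for cycle, slots in enumerate(_clock_cycles(4, 3)):
        print(cycle, slots)
    # 0 [(0, 0)]
    # 1 [(1, 0), (0, 1)]
    # 2 [(2, 0), (1, 1), (0, 2)]
    # 3 [(3, 0), (2, 1), (1, 2)]
    # 4 [(3, 1), (2, 2)]
    # 5 [(3, 2)]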
-"""Supports efficiency with skip connections.""" -from .namespace import Namespace -from .skippable import pop, skippable, stash, verify_skippables - -__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/pipeline/sync/skip/layout.py b/torch/distributed/pipeline/sync/skip/layout.py deleted file mode 100644 index 04d76d34ea16..000000000000 --- a/torch/distributed/pipeline/sync/skip/layout.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Static skip connection layout of ``@skippable`` modules.""" -from typing import Dict, Iterable, List, Tuple - -from torch import nn - -from .namespace import Namespace - -__all__: List[str] = [] - - -class SkipLayout: - """Represents a skip connection layout across partitions.""" - - # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} - by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] - - # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] - by_partition: List[List[Tuple[int, Namespace, str]]] - - def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: - # The skip routes are already indexed by 'ns, name'. - self.by_ns_name = skip_routes - - # Index skip routes by partition number 'j'. - self.by_partition = [[] for _ in range(num_partitions)] - - for (ns, name), (prev_j, next_j) in skip_routes.items(): - self.by_partition[next_j].append((prev_j, ns, name)) - - for p in self.by_partition: - p.sort() - - def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: - """Generates skip routes for the given destination partition number. - The skip routes are sorted by source partition number in ascending - order. - - Yields: - Each tuple of (source partition number, namespace, name). - - """ - for prev_j, ns, name in self.by_partition[next_j]: - if prev_j == next_j: - # This skip tensor will be popped at the same partition where - # it is stashed. In this case, copy is not required. - continue - - yield (prev_j, ns, name) - - def requires_copy(self, ns: Namespace, name: str) -> bool: - """Whether the given namespace and name requires partition-to-partition - copy or not. - """ - prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) - return prev_j != next_j - - -def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: - """Inspects the skip connection layout in the given partitions.""" - # NOTE(sublee): Hide circular import inside this subroutine. Circular - # import is not ideal but placing this logic near to SkipLayout may - # increase cohesion of code. 
- from .skippable import Skippable - - skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} - stashed_at: Dict[Tuple[Namespace, str], int] = {} - - for j, partition in enumerate(partitions): - def inspect_layer(layer): - if not isinstance(layer, Skippable): - return - - for ns, name in layer.stashable(): - stashed_at[(ns, name)] = j - - for ns, name in layer.poppable(): - prev_j = stashed_at.pop((ns, name)) - skip_routes[(ns, name)] = (prev_j, j) - - if isinstance(partition, nn.Sequential): - for layer in partition: - inspect_layer(layer) - else: - inspect_layer(partition) - - return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/pipeline/sync/skip/namespace.py b/torch/distributed/pipeline/sync/skip/namespace.py deleted file mode 100644 index 7d9c0d9b7d84..000000000000 --- a/torch/distributed/pipeline/sync/skip/namespace.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Provides isolated namespace of skip tensors.""" -import abc -from functools import total_ordering -from typing import Any -import uuid - -__all__ = ["Namespace"] - - -@total_ordering -class Namespace(metaclass=abc.ABCMeta): # noqa: B024 - """Namespace for isolating skip tensors used by :meth:`isolate() - `. - """ - - __slots__ = ("id",) - - def __init__(self) -> None: - self.id = uuid.uuid4() - - def __repr__(self) -> str: - return f"" - - def __hash__(self) -> int: - return hash(self.id) - - # Namespaces should support ordering, since SkipLayout will sort tuples - # including a namespace. But actual order between namespaces is not - # important. That's why they are ordered by version 4 UUID which generates - # random numbers. - def __lt__(self, other: Any) -> bool: - if isinstance(other, Namespace): - return self.id < other.id - return False - - def __eq__(self, other: object) -> bool: - if isinstance(other, Namespace): - return self.id == other.id - return False - - -# 'None' is the default namespace, -# which means that 'isinstance(None, Namespace)' is 'True'. -Namespace.register(type(None)) diff --git a/torch/distributed/pipeline/sync/skip/portal.py b/torch/distributed/pipeline/sync/skip/portal.py deleted file mode 100644 index 335793f4cc13..000000000000 --- a/torch/distributed/pipeline/sync/skip/portal.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the -autograd engine. The shared context of three functions (:class:`PortalBlue`, -:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is -one of the most important feature of :mod:`torchpipe.skip`. - -The metaphor is inspired by Portal(tm) from Valve. 
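Before the portal internals, a small sketch of how the SkipLayout removed in layout.py above indexes skip routes (the namespace and partition numbers are made up)::

    ns = Namespace()
    # "1to3" is stashed in partition 0 and popped in partition 2.
    layout = SkipLayout(num_partitions=3, skip_routes={(ns, "1to3"): (0, 2)})

    assert layout.requires_copy(ns, "1to3")
    assert list(layout.copy_policy(2)) == [(0, ns, "1to3")]
    assert list(layout.copy_policy(1)) == []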
- -""" -from typing import List, Optional, Tuple - -import torch -from torch import Tensor - -from ..copy import Context as CopyContext -from ..copy import Copy -from ..phony import get_phony -from ..stream import AbstractStream, get_device - -__all__: List[str] = [] - - -class Portal: - """A portal for a tensor.""" - - def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: - self.put_tensor(tensor, tensor_life) - self.grad: Optional[Tensor] = None - - def blue(self) -> Tensor: - """Creates a :class:`PortalBlue` which hides the underlying tensor from - the autograd engine. - - Join the returning phony to the main lane of the autograd graph to - assure the correct backpropagation:: - - PortalBlue --+ - | - ---------- Join -- - - """ - tensor = self.use_tensor() - - if tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalBlue.apply(self, tensor) - - def orange(self, phony: Tensor) -> Optional[Tensor]: - """Creates a :class:`PortalOrange` which retrieves the hidden tensor - without losing ability of backpropagation. - - Give a phony forked from the main lane of an autograd graph:: - - +-- PortalOrange --+ - | | - -- Fork --------- f(a, b) -- - - """ - self.check_tensor_life() - - if self.tensor is None: - return self.use_tensor() - - return PortalOrange.apply(self, phony) - - def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: - """Copies the hidden tensor by a :class:`PortalCopy`. - - Give a phony and use the returning phony to keep backpropagation:: - - +-- PortalCopy --+ - | | - -- Fork ---------- Join -- - - """ - if self.tensor is None: - return get_phony(torch.device("cpu"), requires_grad=False) - - return PortalCopy.apply(self, prev_stream, next_stream, phony) - - def check_tensor_life(self) -> None: - if self.tensor_life <= 0: - raise RuntimeError("tensor in portal has been removed") - - def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: - """Stores a tensor into this portal.""" - # [Life of Tensor through Portal] - # - # The tensor can be retrieved by use_tensor() up to 'tensor_life' - # times. When the life becomes 0, the tensor will be deleted for - # deallocation in CUDA memory. - # - # The below events participate in a tensor through a portal. - # Note that [x] denotes the events which call use_tensor(): - # - # 1. [x] blue() - # 2. [ ] PortalBlue.forward - # 3. [ ] copy() - # 4. [ ] PortalCopy.forward - # 5. [ ] orange() - # 6. [x] PortalOrange.forward - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 7. [ ] orange() (recomputed) - # 8. [x] PortalOrange.forward (recomputed) - # 9. [ ] PortalOrange.backward - # 10. [ ] PortalCopy.backward - # 11. [x] blue() (recomputed) - # 12. [ ] PortalBlue.forward (recomputed) - # 13. [ ] PortalBlue.backward - # - self.tensor_life = tensor_life - - if tensor_life > 0: - self.tensor = tensor - else: - self.tensor = None - - def use_tensor(self) -> Optional[Tensor]: - """Retrieves the underlying tensor and decreases the tensor life. When - the life becomes 0, it the tensor will be removed. - """ - self.check_tensor_life() - - tensor = self.tensor - - self.tensor_life -= 1 - - if self.tensor_life <= 0: - self.tensor = None - - return tensor - - def put_grad(self, grad: Tensor) -> None: - """Stores a gradient into this portal.""" - self.grad = grad - - def use_grad(self) -> Tensor: - """Retrieves and removes the underlying gradient. The gradient is - always ephemeral. 
- """ - if self.grad is None: - raise RuntimeError("grad in portal has been removed or never set") - - grad = self.grad - self.grad = None - return grad - - -# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and -# :class:`PortalCopy`. -class Context(CopyContext): - portal: Portal - - -class PortalBlue(torch.autograd.Function): - """Hides a tensor from the autograd engine by a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, - portal: Portal, - # This tensor must be retrieved by portal.use_tensor(). - tensor: Tensor, - ) -> Tensor: - ctx.portal = portal - - phony = get_phony(tensor.device, requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: - # The paired PortalOrange should keep the gradient. - grad = ctx.portal.use_grad() - return None, grad - - -class PortalOrange(torch.autograd.Function): - """Retrieves the hidden tensor from a :class:`Portal`.""" - - @staticmethod - # type: ignore[override] - def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: - ctx.portal = portal - - tensor = portal.use_tensor() - assert tensor is not None - - return tensor.detach() - - @staticmethod - def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore[override] - # The paired PortalBlue will use the gradient. - ctx.portal.put_grad(grad) - return None, None - - -class PortalCopy(torch.autograd.Function): - """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden - tensor with copied one. - """ - - @staticmethod - # type: ignore[override] - def forward( - ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, - ) -> Tensor: - ctx.portal = portal - - assert portal.tensor is not None - (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) - - phony = get_phony(get_device(next_stream), requires_grad=False) - return phony.detach() - - @staticmethod - # type: ignore[override] - def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: - portal = ctx.portal - - assert portal.grad is not None - _, _, portal.grad = Copy.backward(ctx, portal.grad) - - return None, None, None, None diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py deleted file mode 100644 index 9d4db76c6b67..000000000000 --- a/torch/distributed/pipeline/sync/skip/skippable.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
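A sketch of the tensor-life bookkeeping in the Portal class removed above (the tensor and the life count are arbitrary)::

    import torch

    portal = Portal(torch.ones(2), tensor_life=2)
    assert portal.use_tensor() is not None    # life 2 -> 1
    assert portal.use_tensor() is not None    # life 1 -> 0, the tensor is dropped
    try:
        portal.use_tensor()
    except RuntimeError:
        pass                                  # "tensor in portal has been removed"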
-"""The user interface to define skip connections.""" -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - FrozenSet, - Generator, - Iterable, - List, - Optional, - Set, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -from torch import Tensor, nn - -from ..microbatch import Batch -from .namespace import Namespace -from .tracker import current_skip_tracker - -__all__ = ["skippable", "stash", "pop", "verify_skippables"] - - -Tensors = Sequence[Tensor] -TensorOrTensors = Union[Tensor, Tensors] - -StashPop = Union["stash", "pop"] -StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] -if TYPE_CHECKING: - # Typechecking: nn.Module is not a Generic - SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] # type: ignore[type-arg] -else: - SkippableModule = nn.Module - -T = TypeVar("T", bound="Skippable") - - -class Skippable(nn.Module): - """The base class for skippable modules. - - Do not use this class directly. Define a subclass by :func:`skippable` - instead. - - """ - - module_cls: ClassVar[Type[SkippableModule]] - stashable_names: ClassVar[FrozenSet[str]] - poppable_names: ClassVar[FrozenSet[str]] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self.module = self.module_cls(*args, **kwargs) # type: ignore[call-arg] - self.namespaces: Dict[str, Namespace] = {} - - def __repr__(self) -> str: - return f"@skippable({self.module})" - - def namespaced(self, name: str) -> Tuple[Namespace, str]: - """Prepend namespace for the given skip name.""" - ns = self.namespaces.get(name) - ns = cast(Namespace, ns) - return (ns, name) - - def stashable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be stashed.""" - for name in self.stashable_names: - yield self.namespaced(name) - - def poppable(self) -> Iterable[Tuple[Namespace, str]]: - """Iterate over namespaced skip names to be popped.""" - for name in self.poppable_names: - yield self.namespaced(name) - - def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: - r"""Isolate a specified subset or the whole set of skip tensors. - - In a single sequential module, skip tensors with the same - name are not allowed unless they are isolated by different namespaces. - - Here's an example using the same name for skip tensors twice. Each pair - of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` - and ``ns2``. There is no conflict anymore:: - - ns1 = Namespace() - ns2 = Namespace() - - model = nn.Sequential( - Layer1().isolate(ns1), - Layer1().isolate(ns2), - Layer2(), - Layer3().isolate(ns2), - Layer3().isolate(ns1), - ) - - When `only` parameter is omitted, all skip tensors are isolated. You - can isolate a subset of skip tensors by passing `only` parameter:: - - ns_alice = Namespace() - ns_bob = Namespace() - - model = nn.Sequential( - ... - StashStashPop().isolate(ns_alice, only=['alice']) \ - .isolate(ns_bob, only=['bob']), - ... 
- ) - - Args: - ns (Namespace): - namespace for isolation - - Keyword Args: - only (iterable of strs): - names of specific skip tensors to be isolated (omit this option - to isolate all skip tensors declared in this module) - - Returns: - this module itself - - """ - names: Iterable[str] - - if only is None: - names = self.stashable_names | self.poppable_names - else: - names = set(only) - - for name in names: - self.namespaces[name] = ns - - return self - - def dispatch( - self, - input, - handle_stash: Callable[[str, Optional[Tensor]], None], - handle_pop: Callable[[str], Optional[Tensor]], - ): - """Dispatch :class:`stash` or :class:`pop` commands. - - The commands are generated by the module's ``forward()``. - """ - generator = self.module(input) - - if not isinstance(generator, Generator): - # The underlying module returned output without any yield. - output = generator - return output - - try: - op = next(generator) - - while True: - if isinstance(op, stash): - handle_stash(op.name, op.tensor) - op = next(generator) - continue - - if isinstance(op, pop): - tensor = handle_pop(op.name) - op = generator.send(tensor) - continue - - raise TypeError(f"{op!r} is not a command from @skippable") - - except StopIteration as stop: - output = stop.args[0] - return output - - def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors: - """Perform the forward propagation. - - :class:`stash` or :class:`pop` commands will be handled by portals - silently. The portals won't be exposed to users. - - Raises: - RuntimeError: - illegal 'stash' or 'pop' is found. - - """ - skip_tracker = current_skip_tracker() - stashed_tensors: Dict[str, Optional[Tensor]] = {} - - # Load skip tensors that might be popped. - poppable_tensors = {} - batch = Batch(input) - for ns, name in self.poppable(): - try: - poppable_tensors[name] = skip_tracker.load(batch, ns, name) - except KeyError as e: - raise RuntimeError(f"'{name}' has not been stashed") from e - input = batch.values - - # Handle skip commands. - def handle_stash(name: str, tensor: Optional[Tensor]) -> None: - if name not in self.stashable_names: - raise RuntimeError(f"'{name}' has not been declared as stashable") - stashed_tensors[name] = tensor - - def handle_pop(name: str) -> Optional[Tensor]: - if name not in self.poppable_names: - raise RuntimeError(f"'{name}' has not been declared as poppable") - return poppable_tensors.pop(name) - - output = self.dispatch(input, handle_stash, handle_pop) - - # All declared skips must be stashed or popped. - not_stashed = self.stashable_names - stashed_tensors.keys() - if not_stashed: - comma_names = ", ".join(f"'{n}'" for n in not_stashed) - raise RuntimeError(f"{comma_names} must be stashed but have not") - - not_popped = poppable_tensors.keys() - if not_popped: - comma_names = ", ".join(f"'{n}'" for n in not_popped) - raise RuntimeError(f"{comma_names} must be popped but have not") - - # Save stashed skip tensors. - batch = Batch(output) - for ns, name in self.stashable(): - tensor = stashed_tensors[name] - skip_tracker.save(batch, ns, name, tensor) - output = batch.values - - return output - - -# TODO(sublee): Move to above of Skippable class for better read flow. -def skippable( - stash: Iterable[str] = (), pop: Iterable[str] = (), -) -> Callable[[Type[SkippableModule]], Type[Skippable]]: - """Define a decorator to create :class:`nn.Module ` with skip connections. - - These decorated modules are called "skippable". 
This functionality works perfectly - fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`. - - Each skip tensor is managed by its name. Before manipulating skip tensors, - a skippable module must statically declare the names for skip tensors by - `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be - stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield - pop(name)``. - - Here is an example with three layers. A skip tensor named "1to3" is stashed - and popped at the first and last layer, respectively:: - - @skippable(stash=['1to3']) - class Layer1(nn.Module): - def forward(self, input): - yield stash('1to3', input) - return f1(input) - - class Layer2(nn.Module): - def forward(self, input): - return f2(input) - - @skippable(pop=['1to3']) - class Layer3(nn.Module): - def forward(self, input): - skip_1to3 = yield pop('1to3') - return f3(input) + skip_1to3 - - model = nn.Sequential(Layer1(), Layer2(), Layer3()) - - One skippable module can stash or pop multiple skip tensors:: - - @skippable(stash=['alice', 'bob'], pop=['carol']) - class StashStashPop(nn.Module): - def forward(self, input): - yield stash('alice', f_alice(input)) - yield stash('bob', f_bob(input)) - carol = yield pop('carol') - return input + carol - - Every skip tensor must be associated with exactly one pair of `stash` and - `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this - restriction automatically when wrapping a module. You can also check the - restriction by :func:`verify_skippables` - without :class:`~torch.distributed.pipeline.sync.Pipe`. - - """ - stashable_names = frozenset(stash) - poppable_names = frozenset(pop) - - def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: - name = module_cls.__name__ - bases = (Skippable,) - attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} - return type(name, bases, attrs) - - return extend_skippable - - -class stash: - """The command to stash a skip tensor. - - :: - - def forward(self, input): - yield stash('name', input) - return f(input) - - Args: - name (str): name of skip tensor - input (torch.Tensor or None): tensor to pass to the skip connection - - """ - - __slots__ = ("name", "tensor") - - def __init__(self, name: str, tensor: Optional[Tensor]) -> None: - self.name = name - self.tensor = tensor - - -class pop: - """The command to pop a skip tensor. - - :: - - def forward(self, input): - skip = yield pop('name') - return f(input) + skip - - Args: - name (str): name of skip tensor - - Returns: - the skip tensor previously stashed by another layer under the same name - - """ - - __slots__ = ("name",) - - def __init__(self, name: str) -> None: - self.name = name - - -def verify_skippables(module: nn.Sequential) -> None: - """Verify if the underlying skippable modules satisfy integrity. - - Every skip tensor must have only one pair of `stash` and `pop`. If there - are one or more unmatched pairs, it will raise :exc:`TypeError` with the - detailed messages. - - Here are a few failure cases. :func:`verify_skippables` will report failure - for these cases:: - - # Layer1 stashes "1to3". - # Layer3 pops "1to3". - - nn.Sequential(Layer1(), Layer2()) - # +---- ? - - nn.Sequential(Layer2(), Layer3()) - # ? 
----+ - - nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) - # +-------------------+ ^^^^^^ - - nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) - # ^^^^^^ +-------------------+ - - To use the same name for multiple skip tensors, they must be isolated by - different namespaces. See :meth:`isolate() - `. - - Raises: - TypeError: - one or more pairs of `stash` and `pop` are not matched. - - """ - stashed: Set[Tuple[Namespace, str]] = set() - popped: Set[Tuple[Namespace, str]] = set() - msgs: List[str] = [] - - for layer_name, layer in module.named_children(): - if not isinstance(layer, Skippable): - continue - - for name in layer.stashable_names & layer.poppable_names: - msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" - msgs.append(msg) - - for ns, name in layer.stashable(): - if name in layer.poppable_names: - continue - - if (ns, name) in stashed: - msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace" - msgs.append(msg) - continue - - stashed.add((ns, name)) - - for ns, name in layer.poppable(): - if name in layer.stashable_names: - continue - - if (ns, name) in popped: - msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace" - msgs.append(msg) - continue - - if (ns, name) not in stashed: - msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" - msgs.append(msg) - continue - - popped.add((ns, name)) - - for (_, name) in stashed - popped: - msg = f"no module declared '{name}' as poppable but stashed" - msgs.append(msg) - - if msgs: - raise TypeError( - "one or more pairs of stash and pop do not match:\n\n{}" "".format("\n".join(f"* {x}" for x in msgs)) - ) diff --git a/torch/distributed/pipeline/sync/skip/tracker.py b/torch/distributed/pipeline/sync/skip/tracker.py deleted file mode 100644 index 8ac82bc05dc9..000000000000 --- a/torch/distributed/pipeline/sync/skip/tracker.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Tracks skip tensors on a thread.""" -from contextlib import contextmanager -import threading -from typing import Dict, Generator, List, Optional, Tuple - -from torch import Tensor - -from ..checkpoint import is_checkpointing -from ..dependency import fork, join -from ..microbatch import Batch -from ..stream import AbstractStream -from .layout import SkipLayout -from .namespace import Namespace -from .portal import Portal - -__all__: List[str] = [] - - -class SkipTracker: - """Tracks saved skip tensors. - - It will update the given micro-batch in place. This is because when it - manipulates the underlying skip tensors, the current micro-batch also has - to be connected with the skip tensors. - - One thread has one skip tracker. Call :func:`current_skip_tracker` to get - the skip tracker on the current thread. 
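To tie the skippable machinery above together, a hypothetical stash/pop pair that the removed verify_skippables accepts, plus a truncated sequence it rejects (the layer bodies are made up)::

    import torch.nn as nn

    @skippable(stash=["skip"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("skip", x)
            return x + 1

    @skippable(pop=["skip"])
    class Pop(nn.Module):
        def forward(self, x):
            y = yield pop("skip")
            return x + y

    verify_skippables(nn.Sequential(Stash(), Pop()))  # passes silently
    try:
        verify_skippables(nn.Sequential(Stash()))     # "skip" is never popped
    except TypeError:
        pass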
- - """ - - def __init__(self) -> None: - self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - self.tensors[(ns, name)] = tensor - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - return self.tensors.pop((ns, name)) - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - raise TypeError("copy is not supported for non-portal skip tensors") - - -class SkipTrackerThroughPotals(SkipTracker): - """Tracks saved skip tensors through portals. The skip tensors will be - hidden in portals so that the autograd engine does not need to track them. - - This tracker is only used when the training or evaluating module is wrapped - with :class:`torchpipe.Pipe`. - - """ - - def __init__(self, skip_layout: SkipLayout) -> None: - super().__init__() - self.skip_layout = skip_layout - self.portals: Dict[Tuple[Namespace, str], Portal] = {} - - def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: - """Saves the stashed skip tensor in a portal. The portal is then - connected to the given micro-batch with :class:`Join`. - """ - if not self.skip_layout.requires_copy(ns, name): - super().save(batch, ns, name, tensor) - return - - # See [Tensor Life of Portal] at Portal.put_tensor() to understand the - # below tensor_life values. Here are the selected events which retrieve - # the tensor in portal: - # - # 1. [x] blue() - # ... - # 6. [x] PortalOrange.forward - # ... - # 8. [x] PortalOrange.forward (recomputed) - # ... - # 11. [x] blue() (recomputed) - # - if (ns, name) not in self.portals: - if is_checkpointing(): - # Under checkpointing, the tensor used by the first - # PortalOrange should be alive in the portal. This tensor will - # be used again by the second PortalOrange during the - # recomputation. - tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] - else: - tensor_life = 2 # Delete at [6. PortalOrange.forward] - - portal = Portal(tensor, tensor_life) - self.portals[(ns, name)] = portal - - else: - # Under recomputation, the portal already exists. - portal = self.portals[(ns, name)] - - # The existing tensor life already became 0. It should be reset as - # 1 to delete the tensor after the second PortalBlue immediately. - tensor_life = 1 # Delete at [11. blue() (recomputed)] - - portal.put_tensor(tensor, tensor_life) - - phony = portal.blue() - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx] = join(batch[tensor_idx], phony) - - def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: - """Loads a skip tensor from the corresponding portal to pop. The given - micro-batch is connected to the portal with :class:`Fork`. - """ - if not self.skip_layout.requires_copy(ns, name): - tensor = super().load(batch, ns, name) - return tensor - - portal = self.portals[(ns, name)] - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - tensor = portal.orange(phony) - return tensor - - def copy( - self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, - ) -> None: - """Copies the skip tensor in the corresponding portal. The given - micro-batch and the portal will be tied with :class:`Fork` and - :class:`Join`. 
- """ - assert self.skip_layout.requires_copy(ns, name) - - tensor_idx = batch.find_tensor_idx() - batch[tensor_idx], phony = fork(batch[tensor_idx]) - - portal = self.portals[(ns, name)] - phony = portal.copy(prev_stream, next_stream, phony) - - batch[tensor_idx] = join(batch[tensor_idx], phony) - - -class ThreadLocal(threading.local): - def __init__(self) -> None: - self.skip_tracker: Optional[SkipTracker] = None - - -thread_local = ThreadLocal() - - -@contextmanager -def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: - """Registers the given skip tracker on the current thread within a - context:: - - with use_skip_tracker(my_skip_tracker): - ... - - """ - orig = thread_local.skip_tracker - - thread_local.skip_tracker = skip_tracker - - try: - yield - finally: - thread_local.skip_tracker = orig - - -def current_skip_tracker() -> SkipTracker: - """Gets the skip tracker on the current thread.""" - skip_tracker = thread_local.skip_tracker - - if skip_tracker is None: - skip_tracker = SkipTracker() - thread_local.skip_tracker = skip_tracker - - return skip_tracker diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py deleted file mode 100644 index 59fedf865a42..000000000000 --- a/torch/distributed/pipeline/sync/stream.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Utilities for eliminating boilerplate code to handle abstract streams with -CPU device. -""" -from contextlib import contextmanager -from typing import Generator, List, Union, cast - -import torch - -__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", - "use_device", "use_stream", "get_device", "wait_stream", "record_stream", - "is_cuda", "as_cuda"] - - -class CPUStreamType: - pass - - -# The placeholder on place of streams for the CPU device instead of CUDA. -CPUStream = CPUStreamType() - -# It represents both CUDA streams and the CPU stream. 
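Returning to the tracker utilities above, the thread-local pattern they implement behaves roughly like this (a sketch)::

    tracker = SkipTracker()                   # plain tracker, no portals involved
    with use_skip_tracker(tracker):
        assert current_skip_tracker() is tracker
    # Outside the context the per-thread default tracker is restored.
    assert current_skip_tracker() is not tracker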
-AbstractStream = Union[torch.cuda.Stream, CPUStreamType] - - -def new_stream(device: torch.device) -> AbstractStream: - """Creates a new stream for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.Stream(device) - - -def current_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.current_stream(device) - - -def default_stream(device: torch.device) -> AbstractStream: - """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" - if device.type != "cuda": - return CPUStream - return torch.cuda.default_stream(device) - - -@contextmanager -def use_device(device: torch.device) -> Generator[None, None, None]: - """:func:`torch.cuda.device` for either CPU or CUDA device.""" - if device.type != "cuda": - yield - return - - with torch.cuda.device(device): - yield - - -@contextmanager -def use_stream(stream: AbstractStream) -> Generator[None, None, None]: - """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" - if not is_cuda(stream): - yield - return - - with torch.cuda.stream(as_cuda(stream)): - yield - - -def get_device(stream: AbstractStream) -> torch.device: - """Gets the device from CPU or CUDA stream.""" - if is_cuda(stream): - return as_cuda(stream).device - return torch.device("cpu") - - -def wait_stream(source: AbstractStream, target: AbstractStream) -> None: - """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It - makes the source stream wait until the target stream completes work queued. - """ - if is_cuda(target): - if is_cuda(source): - # A CUDA stream waits another CUDA stream. - as_cuda(source).wait_stream(as_cuda(target)) - else: - # CPU waits a CUDA stream. - as_cuda(target).synchronize() - - # If the target is CPU, synchronization is not required. - - -def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: - """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" - if is_cuda(stream): - # NOTE(sublee): record_stream() on a shifted view tensor throws - # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely - # protect the tensor against unexpected reallocation, here we use a - # temporal tensor associated with the same storage without shifting as - # a workaround. - # - # Issue: https://github.com/pytorch/pytorch/issues/27366 - # - tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) - - # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream - tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] - - -def is_cuda(stream: AbstractStream) -> bool: - """Returns ``True`` if the given stream is a valid CUDA stream.""" - return stream is not CPUStream - - -def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: - """Casts the given stream as :class:`torch.cuda.Stream`.""" - return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py deleted file mode 100644 index 210c475317e2..000000000000 --- a/torch/distributed/pipeline/sync/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from torch import nn -from typing import List, Optional - -__all__ = ["partition_model"] - -def partition_model( - module: nn.Sequential, - balance: List[int], - devices: Optional[List[int]] = None): - """ - Partions the model accross multiple GPU devices. 
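Circling back to the stream helpers above, their CPU fallback behaviour in a short sketch (no GPU required)::

    import torch

    cpu = torch.device("cpu")
    stream = new_stream(cpu)
    assert stream is CPUStream and not is_cuda(stream)
    assert get_device(stream) == cpu
    with use_stream(stream):                  # a no-op for the CPU placeholder
        pass
    wait_stream(stream, stream)               # also a no-op without CUDA streams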
- - Given an :class:`nn.Sequential ` module, partitions - the model across multiple GPU devices according the provided ``balance`` - and ``devices``. - - Args: - module (:class:`nn.Sequential `): - Sequential model representing the pipe. - balance (List[int]): - List indicating the number of layers in each partition. - devices (List[int], optional): - List indicating the device to use for each partition. Defaults to - ``range(len(balance))`` - """ - device_idx = 0 - pipe_idx = 0 - balanced_pipe = [] - for num_layers in balance: - layers = [] - for i in range(num_layers): - layers.append(module[pipe_idx]) - pipe_idx += 1 - device = device_idx if devices is None else devices[device_idx] - balanced_pipe.append(nn.Sequential(*layers).to(device)) - device_idx += 1 - - return nn.Sequential(*balanced_pipe) diff --git a/torch/distributed/pipeline/sync/worker.py b/torch/distributed/pipeline/sync/worker.py deleted file mode 100644 index 87b20c4a5551..000000000000 --- a/torch/distributed/pipeline/sync/worker.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2019 Kakao Brain -# -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. -"""Multithreading in pipeline parallelism.""" -from contextlib import contextmanager -from queue import Queue -import sys -from threading import Thread -from types import TracebackType -from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast - -import torch - -from .microbatch import Batch -from .stream import AbstractStream, use_device, use_stream - -__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"] - - -ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] - -# Queue is generic only in stubs. -# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - InQueue = Queue[Optional["Task"]] - OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] -else: - InQueue = Queue - OutQueue = Queue - - -class Task: - """A task represents how to compute a micro-batch on a partition. - - It consists of two parts: :meth:`compute` and :meth:`finalize`. - :meth:`compute` should be executed in worker threads concurrently. - :meth:`finalize` should be executed after when worker threads complete to - execute :meth:`compute`. - - :meth:`compute` might be boosted by worker threads. Because it produces - several CUDA API calls by user code. In PyTorch, parallel CUDA API calls - are not serialized through GIL. So more than one CUDA API call can be - produced at the same time. 
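The `Task` docstring above describes the split between `compute` (run on a worker thread) and `finalize` (run afterwards), with results handed back over queues. Below is a minimal, self-contained sketch of that queue-based worker pattern; it is illustrative only, not the module's actual API, and it omits CUDA streams and grad-mode capture.

```python
import queue
import threading
from typing import Any, Callable, Optional

def worker_loop(in_q: "queue.Queue", out_q: "queue.Queue") -> None:
    # Pull "compute" closures until a None sentinel arrives.
    while True:
        compute: Optional[Callable[[], Any]] = in_q.get()
        if compute is None:
            break
        try:
            result = compute()
        except Exception as exc:  # report the failure instead of killing the thread
            out_q.put((False, exc))
            continue
        out_q.put((True, result))

in_q: "queue.Queue" = queue.Queue()
out_q: "queue.Queue" = queue.Queue()
threading.Thread(target=worker_loop, args=(in_q, out_q), daemon=True).start()

in_q.put(lambda: sum(range(10)))   # the "compute" half of a task
ok, value = out_q.get()            # the caller later "finalizes" with the result
assert ok and value == 45
in_q.put(None)                     # shut the worker down
```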
- - """ - - def __init__( - self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], - ) -> None: - self.stream = stream - self._compute = compute - self._finalize = finalize - self._grad_enabled = torch.is_grad_enabled() - - def compute(self) -> Batch: - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - return self._compute() - - def finalize(self, batch: Batch) -> None: - if self._finalize is None: - return - with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): - self._finalize(batch) - - -def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: - """Main loop of a worker thread.""" - with use_device(device): - while True: - task = in_queue.get() - - if task is None: - break - - try: - batch = task.compute() - except Exception: - exc_info = cast(ExcInfo, sys.exc_info()) - out_queue.put((False, exc_info)) - continue - - out_queue.put((True, (task, batch))) - - done = (False, None) - out_queue.put(done) - - -def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: - """Spawns worker threads. A worker thread is bound to a device.""" - in_queues: List[InQueue] = [] - out_queues: List[OutQueue] = [] - - # Spawn workers. - workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} - - def normalize_device(device: torch.device) -> torch.device: - if device.type == "cuda" and device.index is None: - return torch.device("cuda", index=torch.cuda.current_device()) - - if device.type == "cpu" and device.index is not None: - return torch.device("cpu") - - return device - - for device in devices: - device = normalize_device(device) - - try: - in_queue, out_queue = workers[device] - except KeyError: - in_queue = Queue() - out_queue = Queue() - workers[device] = (in_queue, out_queue) - - t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) - t.start() - - in_queues.append(in_queue) - out_queues.append(out_queue) - - return (in_queues, out_queues) - -@contextmanager -def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: - try: - (in_queues, out_queues) = create_workers(devices) - yield (in_queues, out_queues) - finally: - pass diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py deleted file mode 100644 index 16940674b670..000000000000 --- a/torch/distributed/pipelining/PipelineSchedule.py +++ /dev/null @@ -1,854 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates - -import logging -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch.profiler import record_function - -from ._IR import Pipe -from .microbatch import merge_chunks, split_args_kwargs_into_chunks -from .PipelineStage import _PipelineStageBase - - -__all__ = [ - "PipelineScheduleSingle", - "PipelineScheduleMulti", - "Schedule1F1B", - "ScheduleGPipe", - "ScheduleInterleaved1F1B", - "ScheduleLoopedBFS", -] - -logger = logging.getLogger(__name__) - - -class _PipelineSchedule(ABC): - def __init__( - self, - n_microbatches: int, - loss_fn: Optional[Callable[..., torch.Tensor]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - ): - # From arguments - self._n_microbatches = n_microbatches - self._loss_fn = loss_fn - self._output_merge_spec = output_merge_spec - # Derived - self._has_backward = self._loss_fn is not None - # To be filled by subclasses - self._pipe_info: Optional[Pipe.PipeInfo] = None - - # Holds the losses for each microbatch. - self._internal_losses: List[torch.Tensor] = [] - logger.info(f"Using {self.__class__.__name__}") # noqa: G004 - - def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): - if stage.is_last and self._has_backward: - loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] - self._internal_losses.append(loss) - - def _maybe_get_loss(self, stage, mb_index): - valid_index = 0 <= mb_index < len(self._internal_losses) - if stage.is_last and self._has_backward and valid_index: - return self._internal_losses[mb_index] - elif len(self._internal_losses) != 0 and not valid_index: - raise RuntimeError( - f"Loss for microbatch {mb_index} is not available. " - f"Available losses for microbatches: {self._internal_losses}" - ) - else: - return None - - def _update_losses(self, stages, losses): - """ - Update the losses to those in the internal state - """ - # if stages not a list turn into a list - if not isinstance(stages, list): - stages = [stages] - contains_last_stage = any(stage.is_last for stage in stages) - - # Return losses if there is a container passed in - if contains_last_stage and losses is not None: - if len(self._internal_losses) != self._n_microbatches: - raise RuntimeError( - f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" - ) - - # Clean external container first - losses.clear() - # Copy internal losses to external container - losses.extend(self._internal_losses) - - self._internal_losses.clear() - - @abstractmethod - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Run one iteration of the pipeline schedule with list of microbatches. - Will go through all the microbatches according to the schedule - implementation. - - Args: - microbatches: list of microbatch args. - """ - raise NotImplementedError - - @abstractmethod - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): - """ - Run one iteration of the pipeline schedule with *whole-batch* input. - Will chunk the input into microbatches automatically, and go through the - microbatches according to the schedule implementation. - - args: positional arguments to the model (as in non-pipeline case). - kwargs: keyword arguments to the model (as in non-pipeline case). 
- target: target for the loss function. - losses: a list to store the losses for each microbatch. - """ - raise NotImplementedError - - def _check_inputs( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Pre-process/check inputs - """ - - def check_type_and_len(mbs, name: str): - if not isinstance(mbs, list): - raise TypeError(f"{name} must be a list but got a {type(mbs)}") - if len(mbs) != self._n_microbatches: - raise ValueError( - f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" - ) - - if arg_mbs is not None: - check_type_and_len(arg_mbs, "arg_mbs") - else: - arg_mbs = [()] * self._n_microbatches - - if kwarg_mbs is not None: - check_type_and_len(kwarg_mbs, "kwarg_mbs") - else: - kwarg_mbs = [{}] * self._n_microbatches - - if target_mbs is not None: - check_type_and_len(target_mbs, "target_mbs") - - if losses is not None: - if not isinstance(losses, list): - raise TypeError(f"losses must be a list but got a {type(losses)}") - - return arg_mbs, kwarg_mbs - - def _compute_loss(self, output, target): - return self._loss_fn(output, target) # type: ignore[misc] - - def _split_inputs( - self, - args: Tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Splits a full-batch input into chunks (i.e. microbatches) and returns - the chunks - """ - if self._pipe_info is not None: - # Use spec from `pipe_info` - args_chunk_spec = self._pipe_info.args_chunk_spec - kwargs_chunk_spec = self._pipe_info.kwargs_chunk_spec - else: - # Use default spec from `microbatch.py` (i.e. chunk dim 0 for each arg/kwarg) - args_chunk_spec = None - kwargs_chunk_spec = None - - if args or kwargs: - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self._n_microbatches, - args_chunk_spec, - kwargs_chunk_spec, - ) - return args_split, kwargs_split - else: - # Empty inputs (e.g. when called on middle stages) - # Return a list of empty tuples/dicts with matching length as chunks - return [()] * self._n_microbatches, [{}] * self._n_microbatches - - def _merge_outputs(self, output_chunks: List[Any]) -> Any: - """ - Merge output chunks back to a batch state. - If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). - """ - return merge_chunks( - output_chunks, - self._output_merge_spec, - ) - - -def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): - """ - Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. - """ - desc_str = f"{desc}, " if desc else "" - logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004 - return dist.batch_isend_irecv(p2p_ops).pop() - - -def _sorted_batch_p2p( - p2p_ops: List[dist.P2POp], desc: Optional[str] = None -) -> Dict[int, dist.Work]: - """ - Sorts the list of P2P ops by the peer rank, and then calls - batch_isend_irecv. Return a dictionary of works by peer rank. This function - helps us avoid hangs in case of skip connections. 
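The `_sorted_batch_p2p` docstring above captures the key idea: bucket the P2P ops by peer rank and launch one batched call per peer in ascending peer order, so every rank posts its batches in the same relative order. A minimal sketch of just the grouping step, with `(peer, kind)` tuples standing in for `dist.P2POp` objects (illustrative only):

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (peer_rank, kind) tuples stand in for dist.P2POp objects.
ops: List[Tuple[int, str]] = [(2, "recv"), (0, "send"), (2, "send"), (0, "recv")]

ops_by_peer: Dict[int, List[Tuple[int, str]]] = defaultdict(list)
for op in ops:
    ops_by_peer[op[0]].append(op)

# One batched call per peer, in sorted peer order, keeps the ordering
# deterministic across ranks and avoids the mismatch that can cause a hang.
for peer, peer_ops in sorted(ops_by_peer.items()):
    print(peer, peer_ops)
# 0 [(0, 'send'), (0, 'recv')]
# 2 [(2, 'recv'), (2, 'send')]
```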
- """ - # Arrange p2p_ops by peer rank: - # int is the peer rank; - # List is the list of ops towards the peer - ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) - work_by_peer: Dict[int, dist.Work] = {} - if len(p2p_ops) == 0: - return work_by_peer - - # Classify the ops by peer rank - for op in p2p_ops: - ops_by_peer[op.peer].append(op) - - # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) - for peer, ops in sorted(ops_by_peer.items()): - work_by_peer[peer] = _batch_p2p(ops, desc=desc) - - return work_by_peer - - -class PipelineScheduleSingle(_PipelineSchedule): - """ - Base class for single-stage schedules. - Implements the `step` method. - Derived classes should implement `_step_microbatches`. - """ - - def __init__( - self, - stage: _PipelineStageBase, - n_microbatches: int, - loss_fn: Optional[Callable] = None, - output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - ): - # Init parent - super().__init__( - n_microbatches=n_microbatches, - loss_fn=loss_fn, - output_merge_spec=output_merge_spec, - ) - self._pipe_info = ( - stage.pipe_info if hasattr(stage, "pipe_info") else None # type: ignore[attr-defined] - ) - # Self attributes - self._stage = stage - self._num_stages = stage.num_stages - # Set the same has_backward flag for stage object - self._stage.has_backward = self._has_backward - - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): - # Clean per iteration - self._stage.clear_runtime_states() - - # Split inputs into microbatches - args_split, kwargs_split = self._split_inputs(args, kwargs) - - # Split target into microbatches - if target is not None: - targets_split = list(torch.tensor_split(target, self._n_microbatches)) - else: - targets_split = None - - # Run microbatches - self._step_microbatches(args_split, kwargs_split, targets_split, losses) - - # Return merged results per original format - if self._stage.is_last: - return self._merge_outputs(self._stage.output_chunks) - else: - return None - - -class ScheduleGPipe(PipelineScheduleSingle): - """ - The GPipe schedule. - Will go through all the microbatches in a fill-drain manner. - """ - - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Run one iteration of the pipeline schedule with list of microbatches. - Will go through all the microbatches according to the GPipe schedule. - - Args: - microbatches: list of microbatch args. - """ - arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - - # Delay send waits - fwd_sends_to_wait: List[dist.Work] = [] - - # Run microbatches - for i in range(self._n_microbatches): - with record_function(f"Forward {i}"): - ops = self._stage.get_fwd_recv_ops() - works = _sorted_batch_p2p(ops, desc="fwd_recv") - for work in works.values(): - work.wait() - - output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - - ops = self._stage.get_fwd_send_ops() - works = _sorted_batch_p2p(ops, desc="fwd_send") - fwd_sends_to_wait.extend(works.values()) - - logger.debug( - f"[{self._stage.stage_index}] Forwarded microbatch {i}" # noqa: G004 - ) - - self._maybe_compute_loss(self._stage, output, target_mbs, i) - - # Wait for all forward sends to finish - # This should not have performance impact because by the time the first - # backward arrives all the forward sends should have been finished. 
- for work in fwd_sends_to_wait: - work.wait() - - # No loss function, no need to run backward - if not self._has_backward: - return - - # Run backward - # Delay send waits - bwd_sends_to_wait: List[dist.Work] = [] - for i in range(self._n_microbatches): - # set library-specific data-parallel config flags to ensure gradient accumulation across microbatches - self._stage._configure_data_parallel_mode(i == self._n_microbatches - 1) - - with record_function(f"Backward {i}"): - ops = self._stage.get_bwd_recv_ops() - works = _sorted_batch_p2p(ops, desc="bwd_recv") - for work in works.values(): - work.wait() - - loss = self._maybe_get_loss(self._stage, i) - self._stage.backward_one_chunk(loss=loss) - - ops = self._stage.get_bwd_send_ops() - works = _sorted_batch_p2p(ops, desc="bwd_send") - bwd_sends_to_wait.extend(works.values()) - - logger.debug( - f"[{self._stage.stage_index}] Backwarded microbatch {i}" # noqa: G004 - ) - - # Return losses if there is a container passed in - self._update_losses(self._stage, losses) - - # Wait for all backward sends to finish - for work in bwd_sends_to_wait: - work.wait() - - -class Schedule1F1B(PipelineScheduleSingle): - """ - The 1F1B schedule. - Will perform one forward and one backward on the microbatches in steady state. - """ - - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Run one iteration of the pipeline schedule with list of microbatches. - Will go through all the microbatches according to the 1F1B schedule. - - Args: - microbatches: list of microbatch args. - """ - arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - - # Example, 4 GPUs, 8 microbatches - # Stage 0: 6 warmup, 2 1f1b, 6 cooldown - # Stage 1: 4 warmup, 4 1f1b, 4 cooldown - # Stage 2: 2 warmup, 6 1f1b, 2 cooldown - # Stage 3: 0 warmup, 8 1f1b, 0 cooldown - # fwd only - warmup_steps = min( - self._n_microbatches, - 2 * (self._num_stages - self._stage.stage_index - 1), - ) - # fwd + bwd - main_1f1b_steps = self._n_microbatches - warmup_steps - # bwd only - cooldown_steps = (2 * self._n_microbatches) - ( - warmup_steps + (2 * main_1f1b_steps) - ) - total_steps = warmup_steps + main_1f1b_steps + cooldown_steps - logger.debug( - f"Stage {self._stage.stage_index}: " # noqa: G004 - f"Warmup steps: {warmup_steps}, " - f"Main 1F1B steps: {main_1f1b_steps}, " - f"Cooldown steps: {cooldown_steps}, " - f"Total steps: {total_steps}" - ) - - # Delay send waits - fwd_sends_to_wait: List[dist.Work] = [] - bwd_sends_to_wait: List[dist.Work] = [] - - def step_has_forward(i): - assert i >= 0, i - return i < self._n_microbatches - - def step_has_backward(i): - assert i < total_steps, i - return i >= warmup_steps and self._has_backward - - def is_1f1b_step(i): - return step_has_forward(i) and step_has_backward(i) - - def is_warmup_step(i): - return step_has_forward(i) and not step_has_backward(i) - - def is_cooldown_step(i): - return not step_has_forward(i) and step_has_backward(i) - - def should_coalesce_fwd_send_bwd_recv(step): - return ( - is_1f1b_step(step) - or (is_warmup_step(step) and is_cooldown_step(step + 1)) - or (step >= 1 and is_warmup_step(step - 1) and is_cooldown_step(step)) - ) - - def should_coalesce_bwd_send_fwd_recv(bwd_send_step): - # The backward send to prev stage should be coalesced with the fwd recv from the previous stage - return bwd_send_step >= warmup_steps and is_1f1b_step(bwd_send_step + 1) - - # bwd chunk 
counter - bwd_mb_index = 0 - self._stage._configure_data_parallel_mode(last_backward=False) - for i in range(total_steps): - if step_has_forward(i): - with record_function(f"Forward {i}"): - ops = self._stage.get_fwd_recv_ops() - desc = "fwd_recv" - if should_coalesce_bwd_send_fwd_recv(i - 1): - desc += "_bwd_send" - ops.extend(self._stage.get_bwd_send_ops()) - - works = _sorted_batch_p2p(ops, desc=desc) - for work in works.values(): - work.wait() - - output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] - - if not should_coalesce_fwd_send_bwd_recv(i): - ops = self._stage.get_fwd_send_ops() - works = _sorted_batch_p2p(ops, desc="fwd_send") - fwd_sends_to_wait.extend(works.values()) - - self._maybe_compute_loss(self._stage, output, target_mbs, i) - - if step_has_backward(i): - self._stage._configure_data_parallel_mode( - last_backward=(i == total_steps - 1) - ) - with record_function(f"Backward {bwd_mb_index}"): - ops = self._stage.get_bwd_recv_ops() - desc = "bwd_recv" - if should_coalesce_fwd_send_bwd_recv(i): - ops.extend(self._stage.get_fwd_send_ops()) - desc += "_fwd_send" - - works = _sorted_batch_p2p(ops, desc=desc) - for work in works.values(): - work.wait() - - loss = self._maybe_get_loss(self._stage, bwd_mb_index) - self._stage.backward_one_chunk(loss=loss) - - if not should_coalesce_bwd_send_fwd_recv(i): - # see Note: coalesced bwd-send/fwd-recv - ops = self._stage.get_bwd_send_ops() - works = _sorted_batch_p2p(ops, desc="bwd_send") - bwd_sends_to_wait.extend(works.values()) - - bwd_mb_index += 1 - - # Wait for all forward sends to finish - for work in fwd_sends_to_wait: - work.wait() - - # Wait for all backward sends to finish - for work in bwd_sends_to_wait: - work.wait() - - # Return losses if there is a container passed in - self._update_losses(self._stage, losses) - - -class PipelineScheduleMulti(_PipelineSchedule): - """ - Base class for multi-stage schedules. - Implements the `step` method. - Derived classes should implement `_step_microbatches`. 
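To make the base/derived contract above concrete, here is a stripped-down stand-in (not the real classes): the base owns `step`, which splits the whole batch and merges the results, while a subclass only supplies `_step_microbatches`.

```python
from abc import ABC, abstractmethod
from typing import List

class _MiniSchedule(ABC):
    """Toy stand-in for the split-run-merge contract described above."""

    def __init__(self, n_microbatches: int) -> None:
        self._n_microbatches = n_microbatches

    def step(self, batch: List[int]) -> List[int]:
        # The base class chunks the whole batch, runs the schedule, and merges.
        size = len(batch) // self._n_microbatches
        chunks = [batch[i * size:(i + 1) * size] for i in range(self._n_microbatches)]
        return self._step_microbatches(chunks)

    @abstractmethod
    def _step_microbatches(self, chunks: List[List[int]]) -> List[int]:
        ...

class _MiniGPipe(_MiniSchedule):
    def _step_microbatches(self, chunks: List[List[int]]) -> List[int]:
        # A real schedule orchestrates sends/recvs here; we just merge back.
        return [x for chunk in chunks for x in chunk]

print(_MiniGPipe(2).step([1, 2, 3, 4]))  # [1, 2, 3, 4]
```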
- """ - - def __init__( - self, - stages: List[_PipelineStageBase], - n_microbatches: int, - loss_fn: Optional[Callable] = None, - output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - ): - if len(stages) <= 1: - raise ValueError( - f"Multi-stage schedule expects at least two stages but got {len(stages)}" - ) - # Init parent - super().__init__( - n_microbatches=n_microbatches, - loss_fn=loss_fn, - output_merge_spec=output_merge_spec, - ) - self._pipe_info = ( - stages[0].pipe_info if hasattr(stages[0], "pipe_info") else None # type: ignore[attr-defined] - ) - # Self attributes - self._stages = stages - self._num_stages = stages[0].num_stages - # Set the same has_backward flag for stage object - for stage in self._stages: - stage.has_backward = self._has_backward - - self._should_compute_loss = ( - lambda stage: stage.is_last and self._loss_fn is not None - ) - - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): - # Clean per iteration - for stage in self._stages: - stage.clear_runtime_states() - - # Split inputs into microbatches - args_split, kwargs_split = self._split_inputs(args, kwargs) - - # Split target into microbatches - if target is not None: - targets_split = list(torch.tensor_split(target, self._n_microbatches)) - else: - targets_split = None - - # Run microbatches - self._step_microbatches(args_split, kwargs_split, targets_split, losses) - - # Return merged results per original format - for stage in self._stages: - if stage.is_last: - return self._merge_outputs(stage.output_chunks) - # Does not contain the last stage - return None - - -class ScheduleLoopedBFS(PipelineScheduleMulti): - """ - Breadth-First Pipeline Parallelism. - See https://arxiv.org/abs/2211.05953 for details. - Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. - What is different is that when microbatches are ready for multiple local - stages, Loops BFS will prioritizes the earlier stage, running all available - microbatches at once. - """ - - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, # TODO - losses: Optional[List] = None, # TODO - ): - """ - Run one iteration of the pipeline schedule with list of microbatches. - Will go through all the microbatches according to the Looped BFS schedule. - - Args: - microbatches: list of microbatch args. 
- """ - # Pre-process inputs - if arg_mbs is not None: - # TODO: fix this so it is preset - self._n_microbatches = len(arg_mbs) - assert len(arg_mbs) == self._n_microbatches - else: - arg_mbs = [()] * self._n_microbatches - - if kwarg_mbs is not None: - assert len(kwarg_mbs) == self._n_microbatches - else: - kwarg_mbs = [{}] * self._n_microbatches - - for stage in self._stages: - for i in range(self._n_microbatches): - with record_function(f"Stage {stage.stage_index} Forward"): - ops = stage.get_fwd_recv_ops() - if ops: - _batch_p2p(ops, desc="fwd_recv").wait() - - output = stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) - self._maybe_compute_loss(stage, output, target_mbs, i) - - ops = stage.get_fwd_send_ops() - if ops: - _batch_p2p(ops, desc="fwd_send") - - for stage in reversed(self._stages): - for i in range(self._n_microbatches): - stage._configure_data_parallel_mode(i == self._n_microbatches - 1) - with record_function(f"Stage {stage.stage_index} Backward"): - ops = stage.get_bwd_recv_ops() - if ops: - _batch_p2p(ops, desc="bwd_recv").wait() - - loss = self._maybe_get_loss(stage, i) - stage.backward_one_chunk(loss=loss) - - ops = stage.get_bwd_send_ops() - if ops: - _batch_p2p(ops, desc="bwd_send") - - self._update_losses(self._stages, losses) - - -class ScheduleInterleaved1F1B(PipelineScheduleMulti): - """ - The Interleaved 1F1B schedule. - Will perform one forward and one backward on the microbatches in steady - state and supports multiple stages per rank. When microbatches are ready for - multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch - (also called "depth first"). - """ - - def __init__( - self, - stages: List[_PipelineStageBase], - n_microbatches: int, - loss_fn: Optional[Callable] = None, - output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - ): - self.pp_group_size = stages[0].group_size - # TODO: is this limitation a must? - if n_microbatches % self.pp_group_size != 0: - raise ValueError( - "Interleaved 1F1B requires the number of microbatches to be a " - f"multiple of the number of pipeline ranks ({self.pp_group_size}), " - f"but got {n_microbatches}." - ) - - super().__init__( - stages=stages, - n_microbatches=n_microbatches, - loss_fn=loss_fn, - output_merge_spec=output_merge_spec, - ) - - self.n_local_stages = len(stages) - self.rank = stages[0].group_rank - - def _step_microbatches( - self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, - ): - """ - Operate on the microbatches for interleaved 1f1b schedule (https://arxiv.org/pdf/2104.04473.pdf). - - Highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks - and each rank away from highest rank adds 2 warmup steps due to: - - one happened before highest rank's warmup started, - - one waiting for backward result to trickle down from highest rank - - TODO: Interleaved 1F1B does not support using _sorted_batch_p2p() - because it requires recvs and sends from different peers - to execute in the same coalesced operation. As a result, this schedule does - not support models with skip connections. 
- """ - arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - - # increment warmup_steps by 2 for each hop away - warmup_steps = (self.n_local_stages - 1) * self.pp_group_size - warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank) - warmup_steps = min(warmup_steps, self._n_microbatches * self.n_local_stages) - fwd_bwd_steps = (self.n_local_stages * self._n_microbatches) - warmup_steps - cooldown_steps = (self.n_local_stages * self._n_microbatches) - fwd_bwd_steps - - assert ( - warmup_steps + fwd_bwd_steps * 2 + cooldown_steps - == self.n_local_stages * self._n_microbatches * 2 - ) - total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps - - logger.debug( - f"rank {self.rank}, warmup_steps {warmup_steps}, " # noqa: G004 - f"1f1b {fwd_bwd_steps}, cooldown_steps {cooldown_steps}" - ) - - def forward_stage_local_index(step): - return (step // self.pp_group_size) % self.n_local_stages - - def backward_stage_local_index(step): - return ( - self.n_local_stages - - 1 - - ((step - warmup_steps) // self.pp_group_size) % self.n_local_stages - ) - - fwd_stage_mb_index: Dict[_PipelineStageBase, int] = defaultdict(int) - bwd_stage_mb_index: Dict[_PipelineStageBase, int] = defaultdict(int) - - # Delay send waits - sends_to_wait: List[dist.Work] = [] - - # Store ops (potentially across steps) - ops: List[dist.P2POp] = [] - - # Warmup Phase (forward only) - for step in range(warmup_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - - # This will assign the current microbatch index and update it for future steps - fwd_stage_mb_index[fwd_stage] = ( - mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" # noqa: G004 - ) - - with record_function(f"Forward {step}"): - ops.extend(fwd_stage.get_fwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc="warmup_pre_fwd") - work.wait() - ops.clear() - - output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index] - - ops.extend(fwd_stage.get_fwd_send_ops()) - # If we are right before the fwd-bwd step, then we need to delay the send to the next step, - # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang. 
- # In the edge cases where there are no fwd_bwds and cooldown is immediate, then no delay is needed - if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0): - work = _batch_p2p(ops, desc="warmup_post_fwd") - sends_to_wait.append(work) - ops.clear() - - self._maybe_compute_loss(fwd_stage, output, target_mbs, mb_index) - - # 1F1B Phase (forward and backward) - for step in range(warmup_steps, warmup_steps + fwd_bwd_steps): - fwd_stage = self._stages[forward_stage_local_index(step)] - bwd_stage = self._stages[backward_stage_local_index(step)] - - fwd_stage_mb_index[fwd_stage] = ( - fwd_mb_index := fwd_stage_mb_index[fwd_stage] - ) + 1 - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - logger.debug( - f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, " # noqa: G004 - f"{bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}" - ) - desc = f"1F1B {step}" - with record_function(desc): - ops.extend(fwd_stage.get_fwd_recv_ops()) - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc=desc) - work.wait() - ops.clear() - - # Forward - output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] - ops.extend(fwd_stage.get_fwd_send_ops()) - self._maybe_compute_loss(fwd_stage, output, target_mbs, fwd_mb_index) - - # Backward - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - ops.extend(bwd_stage.get_bwd_send_ops()) - - # Cooldown Phase (backward only) - for step in range(warmup_steps + fwd_bwd_steps, total_steps): - bwd_stage = self._stages[backward_stage_local_index(step)] - bwd_stage_mb_index[bwd_stage] = ( - bwd_mb_index := bwd_stage_mb_index[bwd_stage] - ) + 1 - bwd_stage._configure_data_parallel_mode( - bwd_mb_index == self._n_microbatches - 1 - ) - - logger.debug( - f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" # noqa: G004 - ) - desc = f"Cooldown {step}" - with record_function(desc): - ops.extend(bwd_stage.get_bwd_recv_ops()) - if ops: - work = _batch_p2p(ops, desc=desc + " pre_bwd") - work.wait() - ops.clear() - - loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) - bwd_stage.backward_one_chunk(loss=loss) - - ops.extend(bwd_stage.get_bwd_send_ops()) - if ops: - work = _batch_p2p(ops, desc=desc + " post_bwd") - sends_to_wait.append(work) - ops.clear() - - # Make sure all sends are finished - for work in sends_to_wait: - work.wait() - - # Return losses if there is a container passed in - self._update_losses(self._stages, losses) diff --git a/torch/distributed/pipelining/README.md b/torch/distributed/pipelining/README.md index 46a05a22c8ce..d4c9aaafa5b3 100644 --- a/torch/distributed/pipelining/README.md +++ b/torch/distributed/pipelining/README.md @@ -1,178 +1,7 @@ # Pipeline Parallelism for PyTorch -> [!NOTE] -> `torch.distributed.pipelining` is a package migrated from the [PiPPy](https://github.com/pytorch/PiPPy) project. It is currently in alpha state and under extensive development. If you need examples that work with our APIs, please refer to PiPPy's [examples](https://github.com/pytorch/PiPPy/tree/main/examples) directory. +`torch.distributed.pipelining` is a package for implementing pipeline parallelism on your model. 
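Since the README now defers to the rendered documentation, a compact sketch of the front-end API as reshaped by this PR may help readers of the diff: `pipeline()` now takes micro-batch-sized example inputs (`mb_args`/`mb_kwargs`) plus an optional `split_spec`, rather than a `num_chunks` argument. The model, shapes, and split point below are illustrative and not part of this PR.

```python
import torch
from torch.distributed.pipelining import pipeline, SplitPoint

# Toy model; "1" is the name of the second submodule inside the Sequential.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.Linear(16, 16),
)

# Example input sized for ONE micro-batch, not the whole batch.
mb_x = torch.randn(4, 16)

pipe = pipeline(
    model,
    mb_args=(mb_x,),
    split_spec={"1": SplitPoint.BEGINNING},  # split before the second Linear
)
print(pipe.num_stages)  # expected: 2, if export splits as requested
```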
-[**Why Pipeline Parallel?**](#why-pipeline-parallel) -| [**What is `torch.distributed.pipelining`?**](#what-is-torchdistributedpipelining) -| [**Examples**](#examples) -| [**Techniques Explained**](#techniques-explained) - -# Why Pipeline Parallel? - -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include _data parallelism_, _tensor/operation parallelism_, and _pipeline parallelism_. In many cases, pipeline parallelism in particular can be an effective technique for scaling, however it is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. `torch.distributed.pipelining` aims to provide a toolkit that does said things automatically to allow high-productivity scaling of models. - -# What is `torch.distributed.pipelining`? - -`torch.distributed.pipelining` consists of a compiler and runtime stack for automated pipelining of PyTorch models. Pipelining, or _pipeline parallelism_, is a technique in which the _code_ of the model is partitioned and multiple _micro-batches_ execute different parts of the model code concurrently. To learn more about pipeline parallelism, see [this article](https://www.deepspeed.ai/tutorials/pipeline/). +Our documentation is available [here](https://pytorch.org/docs/main/distributed.pipelining.html). ![pipeline_diagram_web](https://github.com/pytorch/PiPPy/assets/6676466/c93e2fe7-1cd4-49a2-9fd8-231ec9905e0c) - -Figure: Pipeline parallel. "F", "B" and "U" denote forward, backward and weight update, respectively. Different colors represent different micro-batches. - -`torch.distributed.pipelining` provides the following features that make pipeline parallelism easier: - -* Automatic splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple. -* Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. More schedules will be added and it will be easy to customize your own schedule under `torch.distributed.pipelining`'s framework. -* First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). -* Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (overall, known as "3d parallelism"). - -# Examples - -In the [PiPPy](https://github.com/pytorch/PiPPy) repo where this package is migrated from, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the [HuggingFace examples directory](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface). Popular examples include: [GPT2](https://github.com/pytorch/PiPPy/tree/main/examples/huggingface/pippy_gpt2.py), and [LLaMA](https://github.com/pytorch/PiPPy/tree/main/examples/llama). - -# Techniques Explained - -`torch.distributed.pipelining` consists of two parts: a _compiler_ and a _runtime_. The compiler takes your model code, splits it up, and transforms it into a `Pipe`, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. 
The runtime executes the `PipelineStage`s in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. - -## Splitting a Model with `pipeline` - -To see how we can split a model into a pipeline, let's first take an example trivial neural network: - -```python -import torch - -class MyNetworkBlock(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.lin = torch.nn.Linear(in_dim, out_dim) - - def forward(self, x): - x = self.lin(x) - x = torch.relu(x) - return x - - -class MyNetwork(torch.nn.Module): - def __init__(self, in_dim, layer_dims): - super().__init__() - - prev_dim = in_dim - for i, dim in enumerate(layer_dims): - setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim)) - prev_dim = dim - - self.num_layers = len(layer_dims) - # 10 output classes - self.output_proj = torch.nn.Linear(layer_dims[-1], 10) - - def forward(self, x): - for i in range(self.num_layers): - x = getattr(self, f'layer{i}')(x) - - return self.output_proj(x) - - -in_dim = 512 -layer_dims = [512, 1024, 256] -mn = MyNetwork(in_dim, layer_dims).to(device) -``` - -This network is written as free-form Python code; it has not been modified for any specific parallelism technique. - -Let us see our first usage of the `torch.distributed.pipelining` interfaces: - -```python -from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint - -annotate_split_points(mn, {'layer0': SplitPoint.END, - 'layer1': SplitPoint.END}) - -batch_size = 32 -example_input = torch.randn(batch_size, in_dim, device=device) -chunks = 4 - -pipe = pipeline(mn, chunks, example_args=(example_input,)) -print(pipe) - -""" -************************************* pipe ************************************* -GraphModule( - (submod_0): GraphModule( - (layer0): InterpreterModule( - (lin): InterpreterModule() - ) - ) - (submod_1): GraphModule( - (layer1): InterpreterModule( - (lin): InterpreterModule() - ) - ) - (submod_2): GraphModule( - (layer2): InterpreterModule( - (lin): InterpreterModule() - ) - (output_proj): InterpreterModule() - ) -) - -def forward(self, arg8_1): - submod_0 = self.submod_0(arg8_1); arg8_1 = None - submod_1 = self.submod_1(submod_0); submod_0 = None - submod_2 = self.submod_2(submod_1); submod_1 = None - return (submod_2,) -""" -``` - -So what's going on here? First, `pipeline` turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. - -We used `annotate_split_points` to specify that the code should be split and the end of `layer0` and `layer1`. Our code has thus been split into _three_ pipeline stages. Our library also provides `SplitPoint.BEGINNING` if a user wants to split before certain annotation point. - -While the `annotate_split_points` API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: `pipe_split()`. For details, you can read [this example](https://github.com/pytorch/PiPPy/blob/main/test/test_pipe.py). - -This covers the basic usage of the `Pipe` API. For more information, please see the documentation. - - - -## Using PipelineStage for Pipelined Execution - -Given the above `Pipe` object, we can use one of the `PipelineStage` classes to execute our model in a pipelined fashion. 
First off, let us instantiate a `PipelineStage` instance: - -```python -# We are using `torchrun` to run this example with multiple processes. -# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) - -# Initialize distributed environment -import torch.distributed as dist -dist.init_process_group(rank=rank, world_size=world_size) - -# Pipeline stage is our main pipeline runtime. It takes in the pipe object, -# the rank of this process, and the device. -from torch.distributed.pipelining import PipelineStage -stage = PipelineStage(pipe, rank, device) -``` - -We can now run the pipeline by attaching the `PipelineStage` to a pipeline schedule, GPipe for example: - -```python -from torch.distributed.pipelining import ScheduleGPipe -schedule = ScheduleGPipe(stage, chunks) - -# Input data -x = torch.randn(batch_size, in_dim, device=device) - -# Run the pipeline with input `x`. Divide the batch into 4 micro-batches -# and run them in parallel on the pipeline -if rank == 0: - schedule.step(x) -else: - output = schedule.step() -``` - -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](https://github.com/pytorch/PiPPy/tree/main/examples/basic) and then run it with `torchrun` like so: - -``` -torchrun --nproc_per_node=3 example.py -``` diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 68465cc6cd0b..7d0aede8943e 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1,23 +1,30 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import copy import logging import operator -from dataclasses import dataclass +from collections import defaultdict from enum import Enum from inspect import Parameter, signature, Signature from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.fx as fx +from torch.distributed import ProcessGroup from torch.export import ExportedProgram -from torch.export.unflatten import _assign_attr, _AttrKind, _sink_params +from torch.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module - from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules -from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec +from ._utils import PipeInfo +from .stage import _PipelineStage logger = logging.getLogger(__name__) @@ -323,13 +330,13 @@ def pipe_split(): no-op if your annotated module is run eagerly. Example: - >>> # xdoctest: +SKIP - >>> def forward(self, x): - >>> x = torch.mm(x, self.mm_param) - >>> x = torch.relu(x) - >>> pipe_split() - >>> x = self.lin(x) - >>> return x + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = torch.mm(x, self.mm_param) + >>> x = torch.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x The above example will be split into two stages. 
""" @@ -479,30 +486,42 @@ def _direct_serialization_reduce(self): ) -class Pipe(torch.nn.Module): - # Class variables +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): """ - args_chunk_spec: - Chunking specification for positional inputs. (default: `None`) - kwargs_chunk_spec: - Chunking specification for keyword inputs. (default: `None`) + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. """ - # args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk - # inputs. They are used to create microbatched examples before tracing. - # See context managers `ArgsChunkSpec` and `KwargsChunkSpec`. - # TODO: Do we need to support `_Replicate`? It's unclear, dropping for now. - args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None - - @dataclass - class PipeInfo: - graph: fx.Graph - num_stages: int - num_chunks: int - has_loss_and_backward: bool - args_chunk_spec: Optional[Tuple[Any, ...]] = None - kwargs_chunk_spec: Optional[Dict[str, Any]] = None + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + +class Pipe(torch.nn.Module): def __init__( self, split_gm: fx.GraphModule, @@ -517,7 +536,6 @@ def __init__( self.num_stages: int = num_stages self.has_loss_and_backward = has_loss_and_backward self.loss_spec = loss_spec - self.pipe_info: Optional[Pipe.PipeInfo] = None for node in split_gm.graph.nodes: assert ( @@ -616,6 +634,9 @@ def forward(self, *args, **kwargs): return res def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ if stage_idx < 0 or stage_idx >= self.num_stages: raise ValueError(f"Invalid stage index {stage_idx}!") return getattr(self.split_gm, f"submod_{stage_idx}") @@ -751,6 +772,18 @@ def delete_user_reference(node, user): # To be accumulated in `move_param_to_callee`. 
to_delete = list() + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + def move_param_to_callee( root, callee_name, @@ -766,12 +799,7 @@ def move_param_to_callee( # `atoms` is a list of strings representing the path to the # parameter in the original model atoms = param_fqn.split(".") - # Recursively find the parent of the parameter - mod_itr = split - for atom in atoms[:-1]: - mod_itr = getattr(mod_itr, atom) - # Get the parameter (it is still under the root module) - param_val = getattr(mod_itr, atoms[-1]) + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) # Check whether the parameter is a buffer or a parameter is_buffer = atoms[-1] in mod_itr._buffers @@ -837,6 +865,41 @@ def move_param_to_callee( node.target, ) + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: Dict[int, Set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: Dict[str, List[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: Dict[str, List[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + # Deferral deletion: Remove the original attributes (to params) from the # root GraphModule for mod_itr, last_atom in to_delete: @@ -846,12 +909,6 @@ def move_param_to_callee( # This is expected if the parameter is used in multiple stages pass - # After moving the params to their corresponding hierarchies, we also - # need to move the `get_attr` nodes from the root of the graph to those - # hierarchies. - inputs_to_state: Dict[str, List[str]] = { - attr.name: [attr.target] for attr in attr_nodes - } # This is done by (1) `_sink_params` at each submodule; for name, submod in split.named_children(): if isinstance(submod, fx.GraphModule): @@ -859,6 +916,32 @@ def move_param_to_callee( submod.graph.lint() submod.recompile() + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. 
+ # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + if fqn in unused_attributes: # used, remove it + unused_attributes.remove(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + for node in attr_nodes: # And (2): remove `get_attr` node from submod's arg list for user in copy.copy(node.users): @@ -919,17 +1002,26 @@ def _trace_with_export( example_kwargs: Optional[Dict[str, Any]] = None, ) -> ExportedProgram: logger.info("Tracing model ...") - ep = torch.export.export( - mod, - example_args, - example_kwargs, - ) + try: + ep = torch.export.export( + mod, + example_args, + example_kwargs, + ) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pytorch.org/docs/stable/export.html" + ) from e + return ep @staticmethod def from_tracing( mod: torch.nn.Module, - num_chunks: int, example_args: Tuple[Any, ...], example_kwargs: Optional[Dict[str, Any]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, @@ -948,19 +1040,11 @@ def from_tracing( ) """ - args_split, kwargs_split = split_args_kwargs_into_chunks( - example_args, - example_kwargs, - num_chunks, - Pipe.args_chunk_spec, - Pipe.kwargs_chunk_spec, - ) - # Trace with export exported_program = Pipe._trace_with_export( mod, - example_args=args_split[0], - example_kwargs=kwargs_split[0], + example_args, + example_kwargs, ) pipe = Pipe._from_traced( @@ -1000,15 +1084,6 @@ def from_tracing( ) submod0.recompile() - # Create pipe info - pipe.pipe_info = Pipe.PipeInfo( - graph=pipe.split_gm.graph, - num_stages=pipe.num_stages, - num_chunks=num_chunks, - has_loss_and_backward=pipe.has_loss_and_backward, - args_chunk_spec=Pipe.args_chunk_spec, - kwargs_chunk_spec=Pipe.kwargs_chunk_spec, - ) return pipe def __str__(self): @@ -1018,11 +1093,53 @@ def __repr__(self): return self.split_gm.__repr__() def info(self) -> PipeInfo: - if self.pipe_info is None: - raise RuntimeError( - "Pipe info is not available. Please use the `pipeline` method to create the `Pipe` object." + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: torch.device, + group: Optional[ProcessGroup] = None, + ) -> _PipelineStage: + """ + Create a `PipelineStage` given a stage index and distributed group. 
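The `build_stage` method introduced here turns one partition of the `Pipe` into a runtime `PipelineStage` that a schedule can drive. The sketch below continues the `pipe` object from the earlier front-end sketch and assumes, consistent with the rest of this PR, that the GPipe schedule takes the stage and the number of microbatches; the launcher environment variables come from `torchrun`, and shapes are illustrative.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(rank=rank, world_size=world_size)
device = torch.device("cpu")  # or this rank's GPU

# `pipe` is the object built by pipeline(...) in the earlier sketch.
stage = pipe.build_stage(rank, device)

n_microbatches = 4
schedule = ScheduleGPipe(stage, n_microbatches)

if rank == 0:
    x = torch.randn(4 * n_microbatches, 16)  # the whole batch enters at rank 0
    schedule.step(x)
else:
    out = schedule.step()
```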
+ The `PipelineStage` can run with `PipelineSchedule`s. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + if isinstance(stage_module, torch.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 ) - return self.pipe_info + + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self.info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) class SplitPoint(Enum): @@ -1074,29 +1191,26 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): def pipeline( module: torch.nn.Module, - num_chunks: int, - example_args: Tuple[Any, ...], - example_kwargs: Optional[Dict[str, Any]] = None, + mb_args: Tuple[Any, ...], + mb_kwargs: Optional[Dict[str, Any]] = None, split_spec: Optional[Dict[str, SplitPoint]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, ) -> Pipe: """ - Creates a pipeline representation for the provided module. + Split a module based on a specification. See `Pipe` for more details. Arguments --------- module: - The module to be transformed into a `Pipe`. - num_chunks: - The number of microbatches to be run with this pipeline. - example_args: - Example positional inputs to be used with this pipeline. - example_kwargs: - Example keyword inputs to be used with this pipeline. (default: `None`) + The module to be splitted. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) split_spec: - A dictionary mapping module names to `SplitPoint`s. (default: `None`) + A dictionary using submodule names as split marker. (default: `None`) split_policy: The policy to use for splitting the module. 
(default: `None`) @@ -1114,75 +1228,14 @@ def pipeline( annotate_split_points(module, split_spec) return Pipe.from_tracing( mod=module, - num_chunks=num_chunks, - example_args=example_args, - example_kwargs=example_kwargs, + example_args=mb_args, + example_kwargs=mb_kwargs, ) else: # Use split policy return Pipe.from_tracing( mod=module, - num_chunks=num_chunks, - example_args=example_args, - example_kwargs=example_kwargs, + example_args=mb_args, + example_kwargs=mb_kwargs, split_policy=split_policy, ) - - -# Context manager for setting `args_chunk_spec` during creation of Pipe -class ArgsChunkSpec: - """ - Example: - >>> # xdoctest: +SKIP - >>> # There are three positional arguments to the model, and - >>> # we are chunking them along dimension 0, 0 and 1, respectively - >>> with ArgsChunkSpec((0, 0, 1)): - >>> pipe = pipeline(model, num_chunks, example_args) - """ - - def __init__( - self, - chunk_dims: Tuple[int, ...], - ): - self.args_chunk_spec = map_aggregate( - chunk_dims, - lambda dim: TensorChunkSpec(dim), - ) - - def __enter__(self): - # Inject into the Pipe class - Pipe.args_chunk_spec = self.args_chunk_spec - return self.args_chunk_spec - - def __exit__(self, exc_type, exc_val, traceback): - # Remove from the Pipe class - Pipe.args_chunk_spec = None - - -# Context manager for setting `kwargs_chunk_spec` during creation of Pipe -class KwargsChunkSpec: - """ - Example: - >>> # xdoctest: +SKIP - >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument - >>> with KwargsChunkSpec({"id": 0, "mask": 1}): - >>> pipe = pipeline(model, num_chunks, (), example_kwargs) - """ - - def __init__( - self, - chunk_dims: Dict[str, int], - ): - self.kwargs_chunk_spec = map_aggregate( - chunk_dims, - lambda dim: TensorChunkSpec(dim), - ) - - def __enter__(self): - # Inject into the Pipe class - Pipe.kwargs_chunk_spec = self.kwargs_chunk_spec - return self.kwargs_chunk_spec - - def __exit__(self, exc_type, exc_val, traceback): - # Remove from the Pipe class - Pipe.kwargs_chunk_spec = None diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index c192c314e802..18b3191add5b 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -1,31 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from ._IR import ( - annotate_split_points, - ArgsChunkSpec, - KwargsChunkSpec, - Pipe, - pipe_split, - pipeline, - SplitPoint, -) -from .PipelineSchedule import ( +from ._IR import Pipe, pipe_split, pipeline, SplitPoint +from .schedules import ( Schedule1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleLoopedBFS, ) -from .PipelineStage import ManualPipelineStage, PipelineStage +from .stage import build_stage, PipelineStage __all__ = [ "Pipe", "pipe_split", "SplitPoint", - "annotate_split_points", "pipeline", - "ArgsChunkSpec", - "KwargsChunkSpec", - "ManualPipelineStage", "PipelineStage", + "build_stage", "Schedule1F1B", "ScheduleGPipe", "ScheduleInterleaved1F1B", diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index c3aa9060502b..6ba12899e838 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import List, Optional diff --git a/torch/distributed/pipelining/_debug.py b/torch/distributed/pipelining/_debug.py index 7067a39b39d1..6b153ec78d89 100644 --- a/torch/distributed/pipelining/_debug.py +++ b/torch/distributed/pipelining/_debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index 27241d17874c..659c9804a966 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index f4680530d29f..cf7097795868 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,10 +1,11 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging +from dataclasses import dataclass from typing import List, Tuple, Union import torch from torch import fx -from torch.export.unflatten import InterpreterModule logger = logging.getLogger(__name__) @@ -53,41 +54,6 @@ def extract_tensor_args(a): return flat_args -def modify_graph_op_device( - gm: torch.fx.GraphModule, - new_device: torch.device, -): - """ - Modify the device argument of all "call_function" nodes in the graph. This - is useful for moving the graph to a different device. In particular for - generator ops, like torch.ones. - """ - modified = False - for node in gm.graph.nodes: - if node.op == "call_function": - if "device" in node.kwargs and node.kwargs["device"] != new_device: - logger.debug( - f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 - ) - node.update_kwarg("device", new_device) - modified = True - elif node.op == "call_module": - # Recursively modify "device" in submodules - submod = gm.get_submodule(node.target) - if isinstance(submod, torch.fx.GraphModule): - modify_graph_op_device(submod, new_device) - elif isinstance(submod, InterpreterModule): - # If unflattening has been performed, we need to access its graph module by `.graph_module` - modify_graph_op_device(submod.graph_module, new_device) - else: - logger.warning( - f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 - ) - - if modified: - gm.recompile() - - class PipeliningShapeError(RuntimeError): """Shape mismatch between configured and runtime values.""" @@ -120,3 +86,14 @@ def validate_tensors_metadata( validate_tensor_metadata( f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] ) + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 1201e235d036..8360951b43eb 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -1,11 +1,19 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates import logging from typing import Any, Dict, List, Optional, Tuple import torch +from torch.fx.node import map_aggregate from torch.utils._pytree import tree_flatten, tree_unflatten +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + logger = logging.getLogger(__name__) """ @@ -45,8 +53,11 @@ class _LossReducer(_CustomReducer): DEFAULT_CHUNK_DIM = 0 -# Class used to specify chunking of inputs class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + def __init__(self, split_dim): self.split_dim = split_dim @@ -60,6 +71,43 @@ def __repr__(self): def __str__(self): return f"TensorChunkSpec({self.split_dim})" + @staticmethod + def from_tuple( + chunk_dims: Tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: Dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), + ) + return kwargs_chunk_spec + # Class used to specify replication of inputs class _Replicate: diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py new file mode 100644 index 000000000000..6990ea983edb --- /dev/null +++ b/torch/distributed/pipelining/schedules.py @@ -0,0 +1,957 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.profiler import record_function + +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec +from .stage import _PipelineStageBase + + +__all__ = [ + "PipelineScheduleSingle", + "PipelineScheduleMulti", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", +] + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + FORWARD = 1 + BACKWARD = 2 + + def __str__(self): + if self == _ComputationType.FORWARD: + return "F" + else: + return "B" + + +class _Action(NamedTuple): + computation_type: _ComputationType + microbatch_index: int + stage_index: int + + def __repr__(self): + return f"{self.computation_type}{self.microbatch_index}_s{self.stage_index}" + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Optional[Callable[..., torch.Tensor]] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: List[torch.Tensor] = [] + logger.info(f"Using {self.__class__.__name__}") # noqa: G004 + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._has_backward: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._has_backward and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. 
" + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + """ + raise NotImplementedError + + @abstractmethod + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + raise NotImplementedError + + def _check_inputs( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. 
when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: List[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): + """ + Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return None + desc_str = f"{desc}, " if desc else "" + logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004 + return dist.batch_isend_irecv(p2p_ops).pop() + + +def _sorted_batch_p2p( + p2p_ops: List[dist.P2POp], desc: Optional[str] = None +) -> Dict[int, dist.Work]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) + work_by_peer: Dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
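+
+        Example (illustrative; ``stage``, ``loss_fn``, ``x`` and ``y`` are placeholders):
+            >>> # xdoctest: +SKIP
+            >>> schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=loss_fn)
+            >>> losses = []
+            >>> out = schedule.step(x, target=y, losses=losses)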
+ """ + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + if self._stage.is_last: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + work.wait() + + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug( + f"[{self._stage.stage_index}] Forwarded microbatch {i}" # noqa: G004 + ) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + work.wait() + + # No loss function, no need to run backward + if not self._has_backward: + return + + # Run backward + # Delay send waits + bwd_sends_to_wait: List[dist.Work] = [] + for i in range(self._n_microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + work.wait() + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk(i, loss=loss) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug( + f"[{self._stage.stage_index}] Backwarded microbatch {i}" # noqa: G004 + ) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + work.wait() + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. 
+ + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work = None + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"): + recv_work.wait() + + # Compute + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + if send_work: + send_work.wait() + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last foward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"): + fuse_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk(bwd_mb_index, loss=loss) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"): + fuse_work.wait() + + # Now do the fwd + output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? 
Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"): + recv_work.wait() + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk(bwd_mb_index, loss=loss) + + # Clear previous chunk's backward sends (hopefully they have well finished) + if send_work: + send_work.wait() + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + # Wait for the last backward send to finish + if send_work: + send_work.wait() + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + if len(stages) <= 1: + raise ValueError( + f"Multi-stage schedule expects at least two stages but got {len(stages)}" + ) + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + + self._should_compute_loss = ( + lambda stage: stage.is_last and self._loss_fn is not None + ) + + # This will be set during init of derived schedules + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + + # TODO: later replace this with lazy shape inference during forward + # Prepare forward send/recv infrastructure for stage + for stage in self._stages: + stage._prepare_forward_infra(n_microbatches) + if self._has_backward: + stage._prepare_backward_infra(n_microbatches) + + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. 
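+
+        Example (illustrative; ``stages``, ``loss_fn``, ``x`` and ``y`` are placeholders):
+            >>> # xdoctest: +SKIP
+            >>> # `stages` is the list of stage objects owned by this rank
+            >>> schedule = ScheduleLoopedBFS(stages, n_microbatches=8, loss_fn=loss_fn)
+            >>> out = schedule.step(x, target=y)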
+ """ + + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage + return None + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + prev_rank: int = (self.rank - 1) % self.pp_group_size + next_rank: int = (self.rank + 1) % self.pp_group_size + + for time_step, action in enumerate(self.pipeline_order[self.rank]): + prev_rank_ops = self.pipeline_order[prev_rank] + next_rank_ops = self.pipeline_order[next_rank] + ops: List[dist.P2POp] = [] + if action is not None: + computation_type, mb_index, stage_index = action + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk(mb_index, loss=loss) + ops.extend(stage.get_bwd_send_ops(mb_index)) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type, mb_index, stage_index = prev_rank_action + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index != self._num_stages - 1: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD: + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError(f"Unknown computation type {computation_type}") + + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if 
next_rank_action is not None: + computation_type, mb_index, stage_index = next_rank_action + # Only handle receives for the backwards from a next rank + if computation_type == _ComputationType.FORWARD: + # Next rank doing forward has no influence for the current rank backward recv + pass + elif computation_type == _ComputationType.BACKWARD: + # If not the first stage, then receive bwd gradients + if stage_index != 0: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # do the communication + if ops: + _batch_p2p(ops).wait() + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(PipelineScheduleMulti): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Simliar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + What is different is that when microbatches are ready for multiple local + stages, Loops BFS will prioritizes the earlier stage, running all available + microbatches at once. + """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. + for _ in range(rank): + rank_ops.append(None) + + for stage_index in stage_indices: + for mb_index in range(self._n_microbatches): + rank_ops.append( + _Action(_ComputationType.FORWARD, mb_index, stage_index) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + for mb_index in reversed(range(self._n_microbatches)): + rank_ops.append( + _Action(_ComputationType.BACKWARD, mb_index, stage_index) + ) + return rank_ops + + +class ScheduleInterleaved1F1B(PipelineScheduleMulti): + """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). 
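+
+    Note: ``n_microbatches`` must be a multiple of the pipeline-parallel group
+    size; the constructor raises a ``ValueError`` otherwise.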
+ """ + + def __init__( + self, + stages: List[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, + ): + self.pp_group_size = stages[0].group_size + # TODO: is this limitation a must? + if n_microbatches % self.pp_group_size != 0: + raise ValueError( + f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \ + to be a multiple of the number of pipeline ranks ({self.pp_group_size})." + ) + + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + ) + + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.group = stages[0].group + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size + # Increment warmup operations by 2 for each hop away from the last stage + warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank) + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.pp_group_size) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + # Dictionary for tracking {stage index : current microbatch index} + # All stages start with handling microbatch 0 + fwd_stage_mb_index: Dict[int, int] = defaultdict(int) + bwd_stage_mb_index: Dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + rank_ops: List[Optional[_Action]] = [] + # Pre-padding, rank starts with no-ops based on the warmup. 
+ for _ in range(rank): + rank_ops.append(None) + + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + self.n_local_stages * self.pp_group_size + + 2 * (self.pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index) + ) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index) + ) + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + rank_ops.append(None) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index) + ) + + # Post padding + for _ in range(self.pp_group_size - rank - 1): + rank_ops.append(None) + return rank_ops diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/stage.py similarity index 88% rename from torch/distributed/pipelining/PipelineStage.py rename to torch/distributed/pipelining/stage.py index 93b67696bc79..c2c5582d6854 100644 --- a/torch/distributed/pipelining/PipelineStage.py +++ b/torch/distributed/pipelining/stage.py @@ -1,8 +1,9 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates import logging import operator from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -15,13 +16,12 @@ from ._backward import stage_backward from ._debug import map_debug_info -from ._IR import Pipe -from ._utils import flatten_args, modify_graph_op_device, validate_tensors_metadata +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata __all__ = [ "PipelineStage", - "ManualPipelineStage", + "build_stage", ] logger = logging.getLogger(__name__) @@ -80,7 +80,8 @@ def _make_tensor_from_meta( class _PipelineStageBase(ABC): """ Base class for pipeline stages. - Implements common methods used by both the `PipelineStage` used by the tracing frontend and `ManualPipelineStage`. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. """ def __init__( @@ -89,7 +90,6 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - num_microbatches: int, group: Optional[dist.ProcessGroup] = None, ): """ @@ -98,7 +98,6 @@ def __init__( stage_index (int): The index of this stage. num_stages (int): The total number of stages in this pipeline. device (torch.device): The device to run this stage on. - num_microbatches (int): The number of microbatches to be run with this stage. group (Optional[dist.ProcessGroup]): The process group to use for communication. If `None`, the default process group will be used. Default: `None`. @@ -113,7 +112,6 @@ def __init__( self.stage_index = stage_index self.num_stages = num_stages self.device = device - self.chunks = num_microbatches self.group = group # `group_rank` is rank in process group `group`. @@ -128,10 +126,6 @@ def __init__( self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None # map microbatch ID to list of forward tensor args self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - # Current forward chunk id - self.fwd_chunk_id: int = 0 - # Current backward chunk id - self.bwd_chunk_id: int = 0 # Caching chunk outputs for final output merge or reduction self.output_chunks: List[Any] = [] @@ -159,6 +153,13 @@ def __init__( self.grad_recv_info: Dict = {} self.grad_send_info: Optional[List] = None + # Number of backward chunks seen. This is used to determine when to do + # grad reduction in DDP or FSDP. + self._seen_bwd_chunks = 0 + + # To be populated later + self.chunks: Optional[int] = None + @property def has_backward(self) -> bool: """ @@ -184,6 +185,16 @@ def is_last(self): """ return self.stage_index == self.num_stages - 1 + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." 
+ ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]): """ Track the output shapes/dtype of this stage since they determine the send operation(s) which must match @@ -230,6 +241,20 @@ def map_recv_to_send(a): ) return grad_send_info + @abstractmethod + def _prepare_forward_infra(self, num_microbatches: int): + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + @abstractmethod def _create_grad_recv_info( self, @@ -262,23 +287,23 @@ def _get_recv_ops( return ops - def get_fwd_recv_ops(self) -> List[dist.P2POp]: + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: """ Returns a list of ops that are needed to receive the input arguments for this stage. """ - recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[self.fwd_chunk_id] + recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] # In case there is backward pass, set requires_grad for receive buffers # before first forward - if self.has_backward and not self.set_requires_grad[self.fwd_chunk_id]: + if self.has_backward and not self.set_requires_grad[fwd_chunk_id]: for a in recv_infos: if isinstance(a, _RecvInfo): a.buffer.requires_grad_(True) return self._get_recv_ops(recv_infos) - def get_bwd_recv_ops(self) -> List[dist.P2POp]: + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: """ Returns a list of ops that are needed to receive the gradients for this stage. @@ -286,21 +311,14 @@ def get_bwd_recv_ops(self) -> List[dist.P2POp]: if not self.has_backward or self.is_last: return [] - # Create bwd recv infra lazily - recv_infos = self.grad_recv_info.setdefault( - self.bwd_chunk_id, - # `grad_recv_info` is a mirror of `act_send_info` - self._create_grad_recv_info(self.act_send_info), - ) - + recv_infos = self.grad_recv_info[bwd_chunk_id] return self._get_recv_ops(recv_infos) - def get_fwd_send_ops(self) -> List[dist.P2POp]: + def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: """ Get the activation send ops for current stage's forward. """ - # Use "-1" to get the outputs created by the last chunk - output = self.output_chunks[-1] + output = self.output_chunks[fwd_chunk_id] # Unify output form to tuple for easy correspondance with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) @@ -326,10 +344,12 @@ def get_fwd_send_ops(self) -> List[dist.P2POp]: return ops - def get_bwd_send_ops(self) -> List[dist.P2POp]: + def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ + self._check_chunk_id(bwd_chunk_id) + if not self.has_backward or self.is_first: return [] @@ -358,7 +378,7 @@ def get_bwd_send_ops(self) -> List[dist.P2POp]: else: if not (grad is None and grad_recv_stage is None): raise RuntimeError( - f"[{self.stage_index}] for chunk {self.bwd_chunk_id - 1} has gradients {grad} " + f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} " f"and is expecting to send gradients to stage {grad_recv_stage}" ) return ops @@ -367,13 +387,12 @@ def clear_runtime_states(self) -> None: """ Clear runtime states of the stage. 
""" - # Reset pointers - self.fwd_chunk_id = 0 - self.bwd_chunk_id = 0 # map microbatch ID to list of forward tensor args self.fwd_cache.clear() # Caching chunk outputs for final output merge or reduction self.output_chunks.clear() + # Reset bwd chunk counter + self._seen_bwd_chunks = 0 # Clear grad of input buffers in between schedule steps. This is because # `torch.autograd.backward()` will accumulate gradients into leaf @@ -407,37 +426,25 @@ def get_recv_tensor(info): return tensors - def _retrieve_recv_activations( - self, - ): + def _retrieve_recv_activations(self, fwd_chunk_id: int): """ Retrieve the activations received for the current stage during forward. """ - recv_infos = self.args_recv_info[self.fwd_chunk_id] + recv_infos = self.args_recv_info[fwd_chunk_id] activations = self._map_tensor_from_recv_info(recv_infos) return activations def _retrieve_recv_grads( self, + bwd_chunk_id: int, ): """ Retrieve the gradients received for the current stage during backward. """ - recv_infos = self.grad_recv_info[self.bwd_chunk_id] + recv_infos = self.grad_recv_info[bwd_chunk_id] grads = self._map_tensor_from_recv_info(recv_infos) return grads - def _configure_data_parallel_mode(self, last_backward: bool): - """ - Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the - other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but - there are additional state-variables and performance considerations depending on the data parallelism used. - This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. - """ - if isinstance(self.submod, FSDPModule): - self.submod.set_is_last_backward(last_backward) - self.submod.set_requires_gradient_sync(last_backward) - def forward_maybe_with_nosync(self, *args, **kwargs): # If submod is wrapped with DDP, we use the `no_sync` context manager to # avoid gradient all-reduce per microbatch @@ -448,9 +455,18 @@ def forward_maybe_with_nosync(self, *args, **kwargs): out_val = self.submod(*args, **kwargs) return out_val - def backward_maybe_with_nosync(self, bwd_kwargs: Dict, bwd_chunk_id: int): + def backward_maybe_with_nosync(self, bwd_kwargs: Dict): + """ + Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + + # If submod is wrapped by DDP if isinstance(self.submod, DistributedDataParallel): - if bwd_chunk_id == self.chunks - 1: + if last_backward: # Last chunk, prepare for gradient reduction # HACK: reaching into DDP implementation details here. Is there a better way? 
self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] @@ -464,14 +480,21 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict, bwd_chunk_id: int): else: with self.submod.no_sync(): # type: ignore[operator] grads_input = stage_backward(**bwd_kwargs) + # If submod is a FSDP module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(last_backward) + self.submod.set_requires_gradient_sync(last_backward) + grads_input = stage_backward(**bwd_kwargs) else: - # Non-DDP submodule, regular backward + # Non-DP submodule, regular backward grads_input = stage_backward(**bwd_kwargs) + self._seen_bwd_chunks += 1 return grads_input def forward_one_chunk( self, + fwd_chunk_id: int, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, ): @@ -488,7 +511,7 @@ def forward_one_chunk( else: # Receive activations for this chunk # Activations only come in args form - composite_args = self._retrieve_recv_activations() + composite_args = self._retrieve_recv_activations(fwd_chunk_id) composite_kwargs = {} self._validate_fwd_input(args, kwargs) @@ -520,30 +543,32 @@ def forward_one_chunk( flat_args = flatten_args(composite_args) flat_kwargs = flatten_args(composite_kwargs) flatten_input_tensors = flat_args + flat_kwargs - self.fwd_cache[self.fwd_chunk_id] = ( + self.fwd_cache[fwd_chunk_id] = ( output_tuple, # stage_output flatten_input_tensors, # input_values ) logger.debug( - f"{self.log_prefix} Forwarded chunk {self.fwd_chunk_id}, outputs: {map_debug_info(output)}" # noqa: G004 + f"{self.log_prefix} Forwarded chunk {fwd_chunk_id}, outputs: {map_debug_info(output)}" # noqa: G004 ) self._validate_fwd_outputs(output_tuple) - self.fwd_chunk_id += 1 return output def backward_one_chunk( self, + bwd_chunk_id: int, loss=None, ): """ Perform backward pass on the module. This should only be called once per microbatch. """ + self._check_chunk_id(bwd_chunk_id) + ( stage_output, input_values, - ) = self.fwd_cache.pop(self.bwd_chunk_id) + ) = self.fwd_cache.pop(bwd_chunk_id) # Compute backward if self.is_last: @@ -556,7 +581,7 @@ def backward_one_chunk( } else: # Otherwise, receive gradients from next stage - grads_output = self._retrieve_recv_grads() + grads_output = self._retrieve_recv_grads(bwd_chunk_id) # If an input to the pipeline requires gradient, # `torch.autograd.backward` will accumulate the gradient into the # `.grad` field of such input @@ -566,20 +591,17 @@ def backward_one_chunk( "input_values": input_values, } - self.grads_input = self.backward_maybe_with_nosync( - bwd_kwargs, self.bwd_chunk_id - ) - logger.debug( - f"{self.log_prefix} Backwarded chunk {self.bwd_chunk_id}" # noqa: G004 - ) - self.bwd_chunk_id += 1 + self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) + logger.debug(f"{self.log_prefix} Backwarded chunk {bwd_chunk_id}") # noqa: G004 def _validate_fwd_input(self, args, kwargs): """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" if self.is_first: # TODO why is there a separate recv_info for each pipeline chunk? 
- expected_args = self.args_recv_info[self.fwd_chunk_id] + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] else: # We don't check inputs for non-0 stages assuming they don't accept # user inputs in canonical pipeline scenarios @@ -619,13 +641,20 @@ def __init__( self, stage_module: torch.nn.Module, stage_index: int, - pipe_info: Pipe.PipeInfo, + pipe_info: PipeInfo, device: torch.device, group: Optional[dist.ProcessGroup] = None, ): """ Create a pipeline stage given a stage_module to be wrapped by this stage and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage """ _PipelineStageBase.__init__( self, @@ -633,7 +662,6 @@ def __init__( stage_index, pipe_info.num_stages, device, - pipe_info.num_chunks, group, ) self.pipe_info = pipe_info @@ -660,13 +688,8 @@ def __init__( for i, node in enumerate(submod_nodes): self.submod_to_stage_index.setdefault(node.name, i) - # Prepare forward send/recv infrastructure - self._prepare_forward_infra() - # Cast submodule to device self._move_submod_to_device() - # Move ops argument to device - self._move_ops_to_device() def _move_submod_to_device(self): # Move submodule to indicated device if possible @@ -681,22 +704,13 @@ def _move_submod_to_device(self): else: self.submod.to(self.device) - def _move_ops_to_device(self): - # Today PT2 tracer does not treat `x.device` as a symbolic device; - # instead, the device of tracing time got burned into the generated - # code. Here we provide a workaround for users to manually modify the - # "device" kwarg of operations. Such operation may include: - # `torch.ones`, `torch.zeros`, `torch.rand`, etc. - if isinstance(self.submod, torch.fx.GraphModule): - modify_graph_op_device(self.submod, self.device) - - def _prepare_forward_infra(self): + def _prepare_forward_infra(self, num_microbatches: int): """ Create send/recv infrastructures for activations (during forward) """ # Flag per chunk to keep track of whether we have set `requires_grad` # for receive buffers. Format: {chunk : Boolean} - for chunk in range(self.chunks): + for chunk in range(num_microbatches): self.args_recv_info[chunk] = self._create_act_recv_info() self.set_requires_grad[chunk] = False @@ -885,22 +899,35 @@ def _create_grad_recv_info( return grad_recv_info_tuple -class PipelineStage(_PipelineStage): - def __init__( - self, - pipe: Pipe, - stage_index: int, - device: torch.device, - group: Optional[dist.ProcessGroup] = None, - ): - """ - Create a pipeline stage given a `Pipe` (representing the whole pipeline) and a stage index. 
- """ - # Find my stage module - stage_module = pipe.get_stage_module(stage_index) - # Get my pipe info - pipe_info = pipe.info() - super().__init__(stage_module, stage_index, pipe_info, device, group) +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) # Manual PipelineStage functions and definition @@ -910,7 +937,7 @@ def __init__( def _create_empty_tensors( - tensor: Union[torch.Tensor, List[torch.Tensor]], device: torch.device + tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device ) -> List[torch.Tensor]: """ Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s), @@ -1060,21 +1087,21 @@ def _get_stage_shapes( return stage_id_to_shapes -class ManualPipelineStage(_PipelineStageBase): +class PipelineStage(_PipelineStageBase): """ A class representing a pipeline stage in a pipeline parallelism setup. This class is created manually by providing a example input (and optionally output) as opposed to the PipelineStage class that is outputed from pipeline(). This class extends the `_PipelineStageBase` class and can similarly be used in `PipelineScheule`. + Args: submodule (nn.Module): The PyTorch module wrapped by this stage. stage_index (int): The ID of this stage. num_stages (int): The total number of stages. device (torch.device): The device where this stage is located. - num_microbatches (int): The number of microbatches to use. - input_args (Union[torch.Tensor, List[torch.tensor]], optional): The input arguments for the submodule. - output_args (Union[torch.Tensor, List[torch.tensor]], optional): The output arguments for the submodule. + input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. + output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. 
""" @@ -1084,17 +1111,13 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - num_microbatches: int, - input_args: Union[torch.Tensor, List[torch.Tensor]], - output_args: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, group: Optional[dist.ProcessGroup] = None, ): - super().__init__( - submodule, stage_index, num_stages, device, num_microbatches, group - ) + super().__init__(submodule, stage_index, num_stages, device, group) self.submod.to(self.device) # When we materialize the model partition on cuda, we call reset_parameters() if it is available - # logger.info(f"input args {input_args=}") self.inputs: List[torch.Tensor] = [] self.outputs: List[torch.Tensor] = [] @@ -1124,9 +1147,17 @@ def stage_global_rank(peer_rank): self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size) self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size) + logger.debug( + f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + f"inputs: {[inp.shape for inp in self.inputs]}, " + f"output: {[output.shape for output in self.outputs]}" + ) + + def _prepare_forward_infra(self, num_microbatches: int) -> None: # Receive info during forward # TODO: create args_recv_info lazily? (same needed for PipelineStage) - for chunk_id in range(self.chunks): + for chunk_id in range(num_microbatches): self.set_requires_grad[chunk_id] = False if not self.is_first: # We assume that we always receive from stage - 1 @@ -1157,13 +1188,6 @@ def stage_global_rank(peer_rank): else: self.act_send_info[idx] = [] - logger.debug( - f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 - f"{self.is_last=}, {self.num_stages=}, " - f"inputs: {[inp.shape for inp in self.inputs]}, " - f"output: {[output.shape for output in self.outputs]}" - ) - def _create_grad_recv_info( self, act_send_info: Dict, @@ -1209,7 +1233,7 @@ def _init_p2p_neighbors(self): return True -def _validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]): +def _validate_stage_shapes(pipeline_stages: List[PipelineStage]): """ Check that the buffer shapes match between stages was expected by performing an all_gather between all stages. diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index e26d398bf786..da664f7408bb 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Union import torch diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 19936f910b8a..e3266cb238ac 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs try: from urllib.parse import urlparse, urlunparse except ImportError as e: @@ -58,6 +59,12 @@ def _query_to_dict(query: str) -> Dict[str, str]: return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} +def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: + # libuv is the default backend for TCPStore. To enable the non-libuv backend, + # user can explicitly specify ``use_libuv=0`` in the URL parameter. 
+ return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" + + def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): result = urlparse(url) if world_size_opt is None: @@ -145,13 +152,16 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=False) -> Store: +def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. The TCPStore server is assumed to be hosted on ``hostname:port``. + By default, the TCPStore server uses the asynchronous implementation + ``LibUVStoreDaemon`` which utilizes libuv. + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that the agent leader (node rank 0) hosts the TCPStore server (for which the endpoint is specified by the given ``hostname:port``). Hence @@ -194,7 +204,8 @@ def _error(msg): rank = int(query_dict["rank"]) world_size = int(query_dict["world_size"]) - use_libuv = query_dict.get("use_libuv", "0") == "1" + use_libuv = _get_use_libuv_from_query_dict(query_dict) + assert result.hostname is not None store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv) @@ -242,7 +253,7 @@ def _get_env_or_raise(env_var: str) -> str: master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) - use_libuv = query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "0")) == "1" + use_libuv = _get_use_libuv_from_query_dict(query_dict) store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv) diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index de8153e19c01..581433d220c6 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from datetime import timedelta import logging import os diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 5755b99c7571..640c4d09f062 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index b02a6a2ff8ac..9e8660989e5a 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs import torch.distributed as dist import torch.distributed.rpc as rpc diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index a532897969d4..6499a80e0e17 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import cast import logging diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 0f317829b207..a33358eb0dc6 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] diff --git 
a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index d09ec399e390..6290f9e8e205 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] import collections diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index b1c85c47853d..c9e92980cf56 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 6e00a4d18521..2fc647c414d9 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import copyreg import io diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 67892d14e075..70328f345969 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List, Optional, Union import torch diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 89986be8b928..cdb0a5d22b74 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial from . import functions diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index dc3f4c19ef1e..0543ab56a877 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +# mypy: allow-untyped-defs import itertools diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 399c9c39ec61..9e418c708f03 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 651a4cc9a847..394fde457bb2 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Tuple, Union @@ -25,6 +26,7 @@ def _deprecate_warnings(func_name: str, extra_msg: str) -> None: warnings.warn( f"{func_name} is deprecated and will be removed soon. {extra_msg}", FutureWarning, + stacklevel=3, ) @@ -45,7 +47,7 @@ def _validate_tp_mesh_dim( """ if device_mesh.ndim > 1: raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" 
- "If you have a 2-D or N-D device_mesh, consider passing in device_mesh[\"tp\"]") + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]') parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh: diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 474e542551ae..baa9d638037d 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, List, Tuple import torch.nn as nn diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index d7eae93a7258..c38771ae86e2 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import Any, cast, List, Optional, Tuple @@ -112,9 +113,7 @@ def _create_sharded_tensor_md_from_dt( def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup: mesh = dt.device_mesh assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" - dim_groups = mesh.get_group() - assert isinstance(dim_groups, list) - return dim_groups[0] + return mesh.get_group() def _rewrite_spec_if_needed( diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index f7144a38e923..f2776c5123b4 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib from typing import cast, Dict, Optional, Tuple @@ -14,7 +15,7 @@ Reduction, replicate_reduction_dims, ) -from torch.distributed._tensor.placement_types import Placement, TensorMeta +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -164,14 +165,16 @@ def _log_softmax_handler( res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) - return DTensor( - res, + res_spec = DTensorSpec( spec.mesh, spec.placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + tensor_meta=output_tensor_meta, + ) + + return DTensor( + res, + res_spec, requires_grad=res.requires_grad, - stride=output_tensor_meta.stride, ) @@ -317,16 +320,13 @@ def _nll_loss_forward_handler( spec.mesh, mesh_dim, ) + out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) return ( DTensor( result, - spec.mesh, - output_placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + out_spec, requires_grad=result.requires_grad, - stride=output_tensor_meta.stride, ), total_weight, ) @@ -452,16 +452,17 @@ def _nll_loss_backward_handler( spec.mesh, mesh_dim, ) + # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim + out_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) return DTensor( result, - spec.mesh, - # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim - spec.placements, - shape=output_tensor_meta.shape, - dtype=output_tensor_meta.dtype, + out_spec, requires_grad=result.requires_grad, - stride=output_tensor_meta.stride, ) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 2720f9dca7d0..f532b97e97d0 100644 --- a/torch/distributed/tensor/parallel/style.py +++ 
b/torch/distributed/tensor/parallel/style.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from typing import Optional, Union, Tuple, Dict diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index f47908d96c74..7c135cbbacf8 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -1,6 +1,18 @@ +# mypy: allow-untyped-defs import dataclasses import traceback -from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload +from typing import ( + Any, + Callable, + Container, + Dict, + List, + Optional, + OrderedDict, + overload, + Tuple, + TypeVar, +) import torch import torch.distributed as dist @@ -40,6 +52,7 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, return tuple(flat_args), tuple(kwarg_keys) + def _cast_forward_inputs( dtype: Optional[torch.dtype], *args: Any, @@ -60,7 +73,10 @@ def cast_fn(x: torch.Tensor) -> torch.Tensor: return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) -def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + +def _unpack_kwargs( + flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...] +) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """See _pack_kwargs.""" assert len(kwarg_keys) <= len( flat_args @@ -77,12 +93,16 @@ def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> T @overload -def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]: +def _recursive_to( + inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> List[S]: ... @overload -def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]: +def _recursive_to( + inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> Tuple[T]: ... @@ -155,9 +175,7 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: storage was already allocated. """ with torch.no_grad(): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): already_allocated = tensor._typed_storage()._size() == size.numel() if not already_allocated: tensor_storage_size = tensor._typed_storage()._size() @@ -177,9 +195,7 @@ def _free_storage(tensor: torch.Tensor): storage was already freed. 
""" with torch.no_grad(): - if ( - not torch.distributed._functional_collectives.is_torchdynamo_compiling() - ): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): already_freed = tensor._typed_storage()._size() == 0 if not already_freed: _p_assert( @@ -192,7 +208,6 @@ def _free_storage(tensor: torch.Tensor): tensor._typed_storage()._resize_(0) - Q = TypeVar("Q") R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) @@ -264,7 +279,9 @@ def _to_kwargs( def _verify_param_shape_across_processes( - process_group: dist.ProcessGroup, tensors: List[torch.Tensor], logger: Optional[dist.Logger] = None + process_group: dist.ProcessGroup, + tensors: List[torch.Tensor], + logger: Optional[dist.Logger] = None, ): return dist._verify_params_across_processes(process_group, tensors, logger) diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 75c2882dbc15..701d24ecd68c 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index a802301a47ed..79b2f5e79ae0 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number, Real import torch diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 9243da7b6bf4..95e7baeb906e 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 08d2fb3ac8e8..cc35689bee99 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nan from torch.distributions import constraints diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 1a95dfe0d762..ed42d183a7fd 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index 16d0d6d60fbe..11f8127169a3 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.gamma import Gamma diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 83192f69547f..ae30348dd2d7 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" PyTorch provides two global :class:`ConstraintRegistry` objects that link :class:`~torch.distributions.constraints.Constraint` objects to diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index df94bbd7b14f..5dc9b46519a3 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" The following constraints are implemented: diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 3e7f1a53a47f..34eb75b9b6f8 100644 --- 
a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index b7175aa61628..c8a5ec485b1a 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.autograd import Function from torch.autograd.function import once_differentiable diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 2fb05828a8b3..b329a277174d 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Any, Dict, Optional, Tuple from typing_extensions import deprecated diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index e60f6489d5bf..6d422aeacf08 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions.distribution import Distribution diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index 020b5215bbdb..e557f6a6bccc 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 788f74b58556..3e70aa7f5c70 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index c189fb24e070..c115a8d71bf9 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index 0bf2f3dbacc6..918d97885738 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index e0ed5d8f8690..af886f65e833 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index ef0edc6f0fe8..0afedbc9d5d7 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 6526170b24ee..4cf977376ea3 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 35b705fd0f29..36946e798f6b 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index 
5a66138b6f04..cff64d0a9e49 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.gamma import Gamma diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index 923f1edcdf41..20adf1cdad2a 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings from functools import total_ordering diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 9de3c422dc4c..25393f7177c5 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import nan from torch.distributions import constraints diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 7b830cc76f9b..8069a41ab6fb 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index c1cb46f02fc2..38f5235ed278 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index f6694cf9507f..bde09b88ecb4 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index a9ef4dd26564..6cdd4f8db515 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index a3acaa990966..6f09de1f5177 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 8db242e33253..ab507f9f60a2 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 3f316e823a79..50699a592a31 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import inf from torch.distributions import Categorical, constraints diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 2784eeb214d5..4edff9c69b57 100644 --- 
a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 59edee589f9a..230b404c3fb0 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from torch.distributions import constraints diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 3364474ba68f..0f73c8facf29 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from numbers import Number, Real diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 37e62e874f5e..957a7d6bdf7f 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.categorical import Categorical @@ -119,7 +120,7 @@ class OneHotCategoricalStraightThrough(OneHotCategorical): through gradient estimator from [1]. [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation - (Bengio et al, 2013) + (Bengio et al., 2013) """ has_rsample = True diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 07cfb417a814..76dbe29b67b6 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.distributions import constraints from torch.distributions.exponential import Exponential from torch.distributions.transformed_distribution import TransformedDistribution diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index 81c0898a577b..4ecf85dc825b 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 05e0995e4a33..ca5b6fd46b5b 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch @@ -30,10 +31,10 @@ class LogitRelaxedBernoulli(Distribution): logits (Number, Tensor): the log-odds of sampling `1` [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random - Variables (Maddison et al, 2017) + Variables (Maddison et al., 2017) [2] Categorical Reparametrization with Gumbel-Softmax - (Jang et al, 2017) + (Jang et al., 2017) """ arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 245ab87aa2a7..719c0c15d38e 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.categorical import Categorical @@ -26,10 +27,10 @@ class ExpRelaxedCategorical(Distribution): logits (Tensor): unnormalized log probability for each event [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random 
Variables - (Maddison et al, 2017) + (Maddison et al., 2017) [2] Categorical Reparametrization with Gumbel-Softmax - (Jang et al, 2017) + (Jang et al., 2017) """ arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = ( diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index 553144e2643b..b49e56c2e313 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index b2201278ea8d..8c7cba61fb14 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict import torch diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index f2907caa6018..b81b19441335 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import numbers diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index e939bb4aae39..8b3497b4e313 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from numbers import Number import torch diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 7a6d31a05722..c6a10088fdd8 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import update_wrapper from numbers import Number from typing import Any, Dict @@ -90,6 +91,27 @@ def logits_to_probs(logits, is_binary=False): def clamp_probs(probs): + """Clamps the probabilities to be in the open interval `(0, 1)`. + + The probabilities would be clamped between `eps` and `1 - eps`, + and `eps` would be the smallest representable positive number for the input data type. + + Args: + probs (Tensor): A tensor of probabilities. + + Returns: + Tensor: The clamped probabilities. 
+ + Examples: + >>> probs = torch.tensor([0.0, 0.5, 1.0]) + >>> clamp_probs(probs) + tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) + + >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64) + >>> clamp_probs(probs) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64) + + """ eps = torch.finfo(probs.dtype).eps return probs.clamp(min=eps, max=1 - eps) diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 17f52fad25b3..8be9ffb7778c 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index 39e07d580bc5..607190df1e1e 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.distributions import constraints from torch.distributions.exponential import Exponential diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 733efbbeb95f..3ec13c25017f 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings from numbers import Number diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index c1cea8ec005f..930915f96f9b 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 235b43b969aa..20411dc87cce 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from typing import List diff --git a/torch/export/_safeguard.py b/torch/export/_safeguard.py index 92fb9b434041..76f22f369c56 100644 --- a/torch/export/_safeguard.py +++ b/torch/export/_safeguard.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode from torch.overrides import TorchFunctionMode diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 31c933b4518a..ee25dbc2e1ea 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import functools import inspect @@ -481,6 +482,7 @@ def _export_to_torch_ir( *, preserve_module_call_signature: Tuple[str, ...] 
= (), disable_constraint_solver: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, restore_fqn: bool = True, _log_export_usage: bool = True, same_signature: bool = True, @@ -513,6 +515,10 @@ def _export_to_torch_ir( assume_static_by_default=True, tracing_mode="symbolic", disable_constraint_solver=disable_constraint_solver, + # currently the following 2 flags are tied together for export purposes, + # but untangle for sake of dynamo export api + prefer_deferred_runtime_asserts_over_guards=_allow_complex_guards_as_runtime_asserts, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, _log_export_usage=_log_export_usage, same_signature=same_signature, )( @@ -666,19 +672,22 @@ def make_argument_spec(i, node) -> ArgumentSpec: fake_mode = detect_fake_mode(flat_args) - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - fake_mode.shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, + from torch._dynamo import config as _dynamo_config + + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + fake_mode.shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) if pre_dispatch: from torch._export.passes.replace_set_grad_with_hop_pass import ( @@ -989,17 +998,17 @@ def _temp_disable_texpr_fuser(): torch._C._jit_set_texpr_fuser_enabled(original_state) -def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): - with _temp_disable_texpr_fuser(): +class _WrapperModule(torch.nn.Module): + def __init__(self, f): + super().__init__() + self.f = f - class _WrapperModule(torch.nn.Module): - def __init__(self, f): - super().__init__() - self.f = f + def forward(self, *args, **kwargs): + return self.f(*args, **kwargs) - def forward(self, *args, **kwargs): - return self.f(*args, **kwargs) +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + with _temp_disable_texpr_fuser(): from torch.jit._trace import TopLevelTracedModule export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) @@ -1024,6 +1033,7 @@ def forward(self, *args, **kwargs): strict=False, _is_torch_jit_trace=True, ).module() + else: return _export( _WrapperModule(traced_callable), @@ -1043,6 +1053,7 @@ def _strict_export( pre_dispatch: bool, original_state_dict: Dict[str, Any], orig_in_spec: TreeSpec, + _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, ): @@ -1053,6 +1064,7 @@ def _strict_export( dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, restore_fqn=False, # don't need to restore because we will do it later + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, _log_export_usage=False, ) @@ -1215,6 +1227,7 @@ def _non_strict_export( pre_dispatch: bool, original_state_dict: Dict[str, Any], orig_in_spec: TreeSpec, + _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, ): @@ 
-1283,7 +1296,12 @@ def forward(self, *args, **kwargs): equalities_inputs, original_signature, ) = make_fake_inputs( - mod, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace + mod, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=_is_torch_jit_trace, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, # for shape env initialization ) fake_params_buffers = make_fake_params_buffers(fake_mode, _get_params_buffers(mod)) @@ -1346,6 +1364,7 @@ def _export( strict: bool = True, preserve_module_call_signature: Tuple[str, ...] = (), pre_dispatch: bool = False, + _allow_complex_guards_as_runtime_asserts: bool = False, _disable_forced_specializations: Optional[bool] = False, _is_torch_jit_trace: bool = False, ) -> ExportedProgram: @@ -1378,13 +1397,23 @@ def _export( preserve_module_call_signature: A list of submodule paths for which the original calling conventions are preserved as metadata. + _allow_complex_guards_as_runtime_asserts: + With the current dynamic shapes language for dims and derived dims, we can run into constraints + that are not expressible with the language. For example, flattening a matrix and adding to a vector, + both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. + By default, we either raise a constraint violation error or specialize to static values. + If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime + assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops + required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar). + Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints + while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. + _disable_forced_specializations: - By default, some inferred dynamic shapes guards/constraints that are not expressible with the current - dynamic shapes language will lead to specialization to the concrete input values provided. - If _disable_forced_specializations is set to True, we will not specialize, and will not perform runtime - checks on such produced guards. Instead, we allow the user to specify arbitrary shapes, - and fail during runtime if the inputs are invalid. Constraints expressible with the language - (e.g. ranges, linear derived dims) will still be enforced. + Similar to _allow_complex_guards_as_runtime_asserts, but only avoids specializing to static values if set to True. + For complex guards that don't specialize, this flag doesn't have any effect. Ideally this would be subsumed by + _allow_complex_guards_as_runtime_asserts, but this handles one additional case: single-variable equalities where + the symbol is solvable for a concrete value (e.g. Eq(s0 // 4, 400) -> s0 = 1600). If set to True, this flag will + avoid specializations. Direct equalities (e.g. s0 = 4), will still specialize. Returns: An ExportedProgram containing the traced method. 
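As a concrete sketch of the kind of complex guard described above (the wrapper module and dim names are illustrative, and the private _export helper defined in this file is assumed to accept the keyword added in this patch):

    import torch
    from torch.export import Dim
    from torch.export._trace import _export

    class FlattenAdd(torch.nn.Module):
        def forward(self, x, y):
            # with x of shape (s0, s1) and y of shape (s2,), this emits the guard s0*s1 == s2
            return x.reshape([-1]) + y

    dx, dy, dz = Dim("dx"), Dim("dy"), Dim("dz")
    ep = _export(
        FlattenAdd(),
        (torch.randn(4, 8), torch.randn(32)),
        dynamic_shapes={"x": (dx, dy), "y": (dz,)},
        _allow_complex_guards_as_runtime_asserts=True,  # keep s0*s1 == s2 as a graph assert
    )
    # With the flag left at its default of False, the same call would instead specialize or
    # raise a constraint violation error, as described above.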
@@ -1432,6 +1461,7 @@ def _export( pre_dispatch, original_state_dict, orig_in_spec, + _allow_complex_guards_as_runtime_asserts, _disable_forced_specializations, _is_torch_jit_trace, ) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 2fdb7916eeeb..97df0562caa7 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from itertools import chain from typing import Any, Dict, List, Optional, Tuple diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a4ed16e975b8..a5ce066faa47 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,14 +1,20 @@ +# mypy: allow-untyped-defs import builtins import dataclasses import inspect -import math import sys import weakref from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import torch -from torch.utils._pytree import _get_node_type, BUILTIN_TYPES, SUPPORTED_NODES, tree_map +from torch.utils._pytree import ( + _get_node_type, + BUILTIN_TYPES, + SUPPORTED_NODES, + tree_flatten, + tree_map, +) from .exported_program import ExportedProgram @@ -19,7 +25,13 @@ from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint -__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"] +__all__ = [ + "Constraint", + "Dim", + "dims", + "dynamic_dim", + "refine_dynamic_shapes_from_suggested_fixes", +] class _Dim(type): @@ -254,11 +266,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=math.inf): + def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges + if upper is None: + upper = sys.maxsize - 1 + constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -486,7 +501,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally - import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -496,7 +510,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False ), debug_name=debug_name, ) @@ -896,3 +910,156 @@ def assoc_shape(t, dynamic_shape): constraints.append(primary) return constraints # type: ignore[return-value] + + +def _get_dim_name_mapping( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] +): + name_to_dim = {} + for dim in tree_flatten( + dynamic_shapes, + is_leaf=lambda x: isinstance(x, _Dim), + )[0]: + if dim is None or isinstance(dim, int): + continue + name_to_dim[dim.__name__] = dim + if isinstance(dim, _DerivedDim): + name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] + return name_to_dim + + +def refine_dynamic_shapes_from_suggested_fixes( + msg: str, + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], +) -> Union[Dict[str, Any], Tuple[Any], List[Any]]: + """ + For working with export's dynamic shapes suggested fixes, 
and/or automatic dynamic shapes. + Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes. + + For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range, + or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such. + + e.g. + Suggested fixes: + + dim = Dim('dim', min=3, max=6) -> this just refines the dim's range + dim = 4 -> this specializes to a constant + dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation + + However, suggested fixes associated with derived dims can be more complicated. + For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root. + + e.g. + dx = Dim('dx') + dy = dx + 2 + dynamic_shapes = {"x": (dx,), "y": (dy,)} + + Suggested fixes: + + dx = 4 # specialization will lead to dy also specializing = 6 + dx = Dim('dx', max=6) # dy now has max = 8 + + Derived dims suggested fixes can also be used to express divisibility constraints. + This involves creating new root dims that aren't tied to a particular input shape. + In this case the root dims won't appear directly in the new spec, but as a root of + one of the dims. + + e.g. + Suggested fixes: + + _dx = Dim('_dx', max=1024) # this won't appear in the return result, but dx will + dx = 4*_dx # dx is now divisible by 4, with a max value of 4096 + """ + + import re + + import sympy + + from torch._dynamo.exc import UserError, UserErrorType + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + try: + shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() + except Exception as exc: + raise UserError( + UserErrorType.INVALID_INPUT, + "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", + ) from exc + + # build shape_fixes dictionary + shape_fixes = {} + for fix in shape_fixes_msg.split("\n"): + fix = fix.strip() + if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): + name = match.group(1) + _min, _max = None, None + if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): + _min = int(match_min.group(1)) + if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): + _max = int(match_max.group(1)) + shape_fixes[name] = Dim(name, min=_min, max=_max) + else: + name, expr = fix.split(" = ") + expr = sympy.sympify(expr) + if isinstance(expr, sympy.Number): + shape_fixes[name] = int(expr) # static, integer + else: + shape_fixes[name] = expr # relation or derived dim + + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + # track derived dim roots + roots: Set[str] = set() + for k, c in shape_fixes.items(): + assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr)) + if isinstance(c, sympy.Expr): # check dim/derived dim expression + assert _is_supported_equivalence(c) + shape_fixes[k] = c + roots.add(str(next(iter(c.free_symbols)))) + if isinstance(c, _DerivedDim): + roots.add(c.root.__name__) # type: ignore[attr-defined] + + # check keys are existing dims or new roots + for k, c in shape_fixes.items(): + assert k in name_to_dim or k in roots + + # cache so we don't produce multiple derived dim objects + derived_dim_cache: Dict[str, _DerivedDim] = {} + + def apply_fixes(dim, dummy): + if dim is None or isinstance(dim, int): # not dynamic + return dim + elif dim.__name__ in shape_fixes: # directly fix + fix = shape_fixes[dim.__name__] + if isinstance(fix, sympy.Expr): # now derived 
or related + if str(fix) in derived_dim_cache: + return derived_dim_cache[str(fix)] + else: + symbol = next(iter(fix.free_symbols)) + # try to locate symbol + if symbol.name in shape_fixes: # type: ignore[attr-defined] + root = shape_fixes[symbol.name] # type: ignore[attr-defined] + else: + assert symbol.name in name_to_dim # type: ignore[attr-defined] + root = name_to_dim[symbol.name] # type: ignore[attr-defined] + # figure out value of fix + modulus, remainder = sympy.polys.polytools.div(fix, symbol) + dim = root + if modulus != 1: + dim = int(modulus) * dim + if remainder != 0: + dim = dim + int(remainder) + derived_dim_cache[str(fix)] = dim + return dim + else: + return fix + elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] + if dim.__name__ in derived_dim_cache: + return derived_dim_cache[dim.__name__] + else: # evaluate new derived value based on root + _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] + derived_dim_cache[dim.__name__] = _dim + return _dim + return dim # unchanged dim + + return _tree_map(apply_fixes, dynamic_shapes, dynamic_shapes) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index cc6a9e65dd34..7b29251ca4ae 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import dataclasses import functools @@ -33,10 +34,13 @@ import torch.utils._pytree as pytree from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.fx._compatibility import compatibility + +from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from .graph_signature import ( # noqa: F401 _sig_to_specs, @@ -660,6 +664,29 @@ def update_arg(old_arg, new_ph): _replace_sym_size_ops_pass(gm) + from torch._dynamo import config as _dynamo_config + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if not _dynamo_config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + exported_program = ExportedProgram( root=gm, graph=gm.graph, @@ -799,30 +826,31 @@ def _update( ) +def _get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + def _get_updated_range_constraints( gm: torch.fx.GraphModule, old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None, _is_executorch: bool = True, ) -> "Dict[sympy.Symbol, Any]": - def get_shape_env(gm): - vals = [ - node.meta["val"] - for node in gm.graph.nodes - if node.meta.get("val", None) is not None - ] - from torch._guards import 
detect_fake_mode - - fake_mode = detect_fake_mode(vals) - if fake_mode is not None: - return fake_mode.shape_env - for v in vals: - if isinstance(v, torch.SymInt): - return v.node.shape_env - # FIXME(tmanlaibaatar) Remove this whole branch once https://github.com/pytorch/pytorch/pull/123764 if _is_executorch: assert old_range_constraints is None - shape_env = get_shape_env(gm) + shape_env = _get_shape_env(gm) if shape_env is None: return {} range_constraints = { @@ -840,7 +868,7 @@ def get_shape_env(gm): assert old_range_constraints is not None - shape_env = get_shape_env(gm) + shape_env = _get_shape_env(gm) if shape_env is None: return {} diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index ecfd7853400d..ce62e8793941 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses from enum import auto, Enum from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 2bb38fccc378..11075058a0e9 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import copy import operator @@ -337,16 +338,35 @@ def add_to_consts_map(obj_id, node_name, target_name): inputs_to_state[n] = targets _sink_params(self, inputs_to_state, []) - # Check all input nodes has been processed. - for name, module in self.named_modules(): - if not hasattr(module, "graph"): - continue - for node in module.graph.nodes: - if node.op != "placeholder": - continue - assert ( - node.name not in inputs_to_state - ), f"{node.name} was not sunk into the module {name} which has the graph: {module.graph}" + + # Helper function to check input nodes of `module` have been processed. + def check_module_inputs(module, scope): + if hasattr(module, "graph"): + for node in module.graph.nodes: + # sink_params() should turn placeholders into get_attr nodes + # for attributes that are within scope of the current + # module. We allow attributes to remain as placeholders if + # they are inputs in the original module signature, meaning + # they are a parent module's attribute, and therefore out of + # scope of the current module. + if ( + node.op == "placeholder" + and node.name in inputs_to_state + and any( + fqn.split(".")[: len(scope)] == scope + for fqn in inputs_to_state[node.name] + ) # matching scope to avoid wrong assert + ): + raise AssertionError( + f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}" + ) + # Recursively check the submodules. + for name, submod in module.named_children(): + scope.append(name) + check_module_inputs(submod, scope) + + # Recursively check all input nodes have been processed. + check_module_inputs(self, []) # Cache so we don't have to compute this every time.
# NOTE: this needs to be kept in sync with the placeholders in @@ -711,14 +731,20 @@ def __init__( ) if isinstance(arg, ConstantArgument): continue - flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[ + self.seen_nodes[arg.name] + ] = flat_arg_node with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: List[Optional[torch.fx.Node]] = [] for input in signature.inputs: if isinstance(input, ConstantArgument) and input.value is None: input_nodes.append(None) + elif input.name not in self.seen_nodes: + input_nodes.append(None) else: assert isinstance(input, (TensorArgument, SymIntArgument)) input_nodes.append( @@ -781,18 +807,32 @@ def finalize_outputs(self): if signature is not None and self.parent is not None: for output in signature.outputs: if isinstance(output, (TensorArgument, SymIntArgument)): - orig_outputs.append(self.seen_nodes[output.name]) + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) else: raise RuntimeError( f"Unsupported data type for output node: {output}" ) + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + tree_out_node = _generate_unflatten( self.module, - tuple( - self.node_map[self.seen_nodes[output.name]] - for output in orig_outputs - ), + tuple(get_actual_output_node(output) for output in orig_outputs), signature.out_spec, ) parent_out: Optional[torch.fx.Node] = _generate_flatten( @@ -832,6 +872,8 @@ def finalize_outputs(self): self.parent.node_map[orig_outputs[0]] = parent_out else: for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue # Use Proxy to record getitem access. proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] proxy_out.meta["val"] = orig_output.meta.get("val") @@ -1010,14 +1052,23 @@ def _sink_params( scope: tracks where we are in the module hierarchy, so that we can emit the right `getattr(self, "foo.bar")` calls, etc. """ + # This dict records inputs removed by child modules. + # Maps the module object id to the list of placeholder node names + # in the child module that were removed. + module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list) + # We need to use _modules here instead of named_children(), because we # explicitly want duplicate modules to show up in the traversal. 
for name, submodule in module._modules.items(): - _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]) + submod_id_to_inputs_removed = _sink_params( + cast(torch.nn.Module, submodule), inputs_to_state, scope + [name] + ) + for k, v in submod_id_to_inputs_removed.items(): + module_id_to_inputs_removed[k].extend(v) if not hasattr(module, "graph"): # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) - return + return module_id_to_inputs_removed graph = module.graph inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) @@ -1026,32 +1077,49 @@ def _sink_params( # Also remove from call_module nodes call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) for node in call_module_nodes: - node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) + submodule = _recursive_getattr(module, node.target.split(".")) + # remove placeholder from call_module node arguments, only if we've + # erased the placeholder node in the corresponding _sink_params() call + if submodule is not None and id(submodule) in module_id_to_inputs_removed: + node.args = tuple( + filter( + lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], + node.args, + ) + ) + # Filter out inputs_to_state corresponding to current scope. + inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {} for node in inputs: if node.name not in inputs_to_state: continue - if len(node.users) > 0: - state_name = None - for sn in inputs_to_state[node.name]: - sn_split = sn.split(".") - if sn_split[: len(scope)] == scope: - state_name = sn_split - break - - # If there's a mismatch beteewn scope name and state name, then - # there must be multuple scopes pointing to the same state name, - # meaning some modules are shared. In such case, we can simply skip - # updating the current node because another later iteration will - # take care of this input node when the unique match between scope - # and state name occurs. To make sure this always happen, we should - # enforce the invariant that no placeholder node in the unflattened - # graph appears in inputs_to_state dict, which means all the extra - # input nodes have been handled. - if state_name is None: - continue + state_name = None + for sn in inputs_to_state[node.name]: + sn_split = sn.split(".") + if sn_split[: len(scope)] == scope: + state_name = sn_split + break + + # If there's a mismatch between scope name and state name, then + # there must be multiple scopes pointing to the same state name, + # meaning some modules are shared. In such case, we can simply skip + # updating the current node because another later iteration will + # take care of this input node when the unique match between scope + # and state name occurs. To make sure this always happens, we should + # enforce the invariant that no placeholder node in the unflattened + # graph appears in inputs_to_state dict, which means all the extra + # input nodes have been handled. + if state_name is None: + continue + inputs_to_state_of_scope[node] = state_name + + # Record names of removed inputs for return purposes.
+ inputs_removed: List[str] = [] + + for node, state_name in inputs_to_state_of_scope.items(): + if len(node.users) > 0: attr_path = state_name[len(scope) :] state_attr = _recursive_getattr(module, attr_path) assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) @@ -1061,13 +1129,20 @@ def _sink_params( new_node = graph.create_node("get_attr", ".".join(attr_path)) node.replace_all_uses_with(new_node, propagate_meta=True) + graph.erase_node(node) + inputs_removed.append(node.name) + if isinstance(module, InterpreterModule): module.finalize() + return {id(module): inputs_removed} + def _recursive_getattr(obj, attr_path): for attr in attr_path: + if not hasattr(obj, attr): + return None obj = getattr(obj, attr) return obj diff --git a/torch/functional.py b/torch/functional.py index 7c07ae348631..a836c06f028d 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING ) diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 6a398bebb599..e1623c44f193 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index 14588fad9a09..4258979eb3e7 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict import textwrap diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index a4b4bc0d69d7..79a18de12f31 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from torch.fx import GraphModule diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 29ab0c867911..da02e21528de 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 5725c4c6a05c..25a342f064c8 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import copy import functools diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 5f99d698586c..36c831dfdee0 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Optional import torch @@ -5,7 +6,7 @@ from torch._logging import LazyString -def lazy_format_graph_code(name, gm, maybe_id=None): +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): """ Returns a LazyString that formats the graph code. 
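# Minimal stand-alone sketch of the None-returning recursive getattr shown in this hunk,
# demonstrated on a plain object tree instead of an nn.Module hierarchy.
class _Leaf:
    pass

class _Root:
    def __init__(self):
        self.sub = _Leaf()

def recursive_getattr(obj, attr_path):
    for attr in attr_path:
        if not hasattr(obj, attr):
            return None          # missing attribute: signal "not found" instead of raising
        obj = getattr(obj, attr)
    return obj

assert recursive_getattr(_Root(), ["sub"]) is not None
assert recursive_getattr(_Root(), ["sub", "missing"]) is None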
""" @@ -16,11 +17,14 @@ def format_name(): else: return name + if "print_output" not in kwargs: + kwargs["print_output"] = False + return LazyString( lambda: _format_graph_code( f"===== {format_name()} =====\n", gm.forward.__code__.co_filename, - gm.print_readable(print_output=False), + gm.print_readable(**kwargs), ) ) diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index 032ce14b6ec7..d1b5b5f2d376 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,10 +1,21 @@ +# mypy: allow-untyped-defs from torch.fx.proxy import Proxy from ._compatibility import compatibility @compatibility(is_backward_compatible=False) def annotate(val, type): - # val could be either a regular value (not tracing) - # or fx.Proxy (tracing) + """ + Annotates a Proxy object with a given type. + + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ if isinstance(val, Proxy): if val.node.type: raise RuntimeError(f"Tried to annotate a value that already had a type on it!" diff --git a/torch/fx/experimental/_sym_dispatch_mode.py b/torch/fx/experimental/_sym_dispatch_mode.py index c3385de61683..6e48a8ca18f4 100644 --- a/torch/fx/experimental/_sym_dispatch_mode.py +++ b/torch/fx/experimental/_sym_dispatch_mode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Type __all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"] diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index fc28f112323f..9b347762dedb 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from collections import deque from typing import Dict, List, Set, NamedTuple, Tuple, Deque diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 8176ccb562fa..dca495b7f691 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Callable, Dict, Optional, Set, Union @@ -258,7 +259,7 @@ def mod_partition(node: torch.fx.Node): # worry about whether this is one or more tensors because the original graph # correctly uses getitem to extract individual tensors if there are multiple folded. 
fx_const_folded_attrs_name = get_unique_attr_name_in_module( - split, "_FX_CONST_FOLDED_ATTRS" + mod_traced, "_FX_CONST_FOLDED_ATTRS" ) setattr( split, diff --git a/torch/fx/experimental/debug.py b/torch/fx/experimental/debug.py index bd6fed690914..d3c482319f2e 100644 --- a/torch/fx/experimental/debug.py +++ b/torch/fx/experimental/debug.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.fx as fx def set_trace(gm: fx.GraphModule) -> fx.GraphModule: diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index e44a75ddad08..a6ac80fd72fb 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce import torch import operator diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index bd56694773e9..c1a634b2602a 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.node import Node diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index be19e7b93ac8..b09e221f6b36 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx import warnings diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 3c1f724d26a5..45038837cae6 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ op_mod, op_gt, op_lt, op_neq, op_eq from torch.fx.tensor_type import TensorType, Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 031562393edc..e04fc26b408e 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import operator import warnings diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py index 15af0241ec5b..c8cf70006cd8 100644 --- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index a43d8f3ebbe0..99f94609f265 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ BVar from 
torch.fx.experimental.migrate_gradual_types.operation import op_leq diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index 06bc2309975c..30b076a72bee 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator from typing import Any, Callable, Dict, Tuple, Optional diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index be411d9b6eff..8362c0cb88ac 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch.fx as fx from torch.fx.node import Argument, Target from torch.nn.utils.fusion import fuse_conv_bn_eval diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index d96c6b40667f..796c65a43022 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum from typing import NamedTuple, Dict, List, Set diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 4bf9ebab17b3..1c384d9dfbeb 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import inspect import itertools @@ -277,7 +278,13 @@ def wrapper(*args, **kwargs): raise except Exception: - log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) raise return wrapper diff --git a/torch/fx/experimental/refinement_types.py b/torch/fx/experimental/refinement_types.py index 762e4340f12b..a33ddf3710a4 100644 --- a/torch/fx/experimental/refinement_types.py +++ b/torch/fx/experimental/refinement_types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs class Equality: def __init__(self, lhs, rhs): self.lhs = lhs diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 85a95895f7c9..8cfb030b9f77 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import inspect import textwrap diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index a2a840408618..5c7ab78706cb 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.fx import inspect diff --git a/torch/fx/experimental/shape_inference/infer_shape.py b/torch/fx/experimental/shape_inference/infer_shape.py index 3c2e0c22bd89..10f5d53712ae 100644 --- a/torch/fx/experimental/shape_inference/infer_shape.py +++ b/torch/fx/experimental/shape_inference/infer_shape.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from collections import defaultdict diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 98cba67a73a1..8f270c56e6c1 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file does three things: - Contains the definition of SymNode @@ -267,8 +268,11 @@ def mul(self, other) -> "SymNode": def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def 
pow(self, other) -> "SymNode": - return self._pow(other) # type: ignore[attr-defined] + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -276,11 +280,14 @@ def and_(self, other) -> "SymNode": def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def truediv(self, other) -> "SymNode": - return self._truediv(other) # type: ignore[attr-defined] + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] - def floordiv(self, other) -> "SymNode": - return self._floordiv(other) # type: ignore[attr-defined] + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -361,6 +368,17 @@ def sym_or(self, other): def sym_and(self, other): return self.and_(other) + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -477,7 +495,7 @@ def is_constant(self): "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "floordiv": operator.floordiv, + "int_floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -489,7 +507,8 @@ def is_constant(self): "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "pow": operator.pow, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -498,12 +517,14 @@ def is_constant(self): "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "truediv": operator.truediv, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", + "sym_int", "ceil", "floor", "neg", @@ -559,20 +580,20 @@ def fn(self): bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer"} +only_float_magic_methods = {"is_integer", "round", "sym_int"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"truediv", "sym_float", "pow"} +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc"} +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} always_bool_magic_methods = { "eq", "ne", @@ -590,10 +611,16 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_truediv(a, b): - from torch.utils._sympy.functions import TrueDiv +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv - return TrueDiv(a, b) + return 
FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) def _sympy_floordiv(a, b): @@ -603,15 +630,24 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + - return Mod(a, b) +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + return PowByNatural(a, b) -def _sympy_pow(a, b): - from torch.utils._sympy.functions import Pow - return Pow(a, b) +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) def _sympy_and(a, b): @@ -643,11 +679,13 @@ def _sympy_rshift(a, b): "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow": _sympy_pow, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, "and": _sympy_and, "or": _sympy_or, - "truediv": _sympy_truediv, - "floordiv": _sympy_floordiv, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -672,21 +710,23 @@ def _floor_ceil_helper(a, fn): def _sympy_floor(a): - import sympy + from torch.utils._sympy.functions import FloorToInt - return _floor_ceil_helper(a, sympy.floor) + return FloorToInt(a) +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import Trunc + from torch.utils._sympy.functions import TruncToInt - return Trunc(a) + return TruncToInt(a) def _sympy_ceil(a): - import sympy + from torch.utils._sympy.functions import CeilToInt - return _floor_ceil_helper(a, sympy.ceiling) + return CeilToInt(a) def _sympy_eq(a, b): @@ -771,26 +811,28 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import Round, RoundDecimal + from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: - return Round(number) + return RoundToInt(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - # Cannot use sympy.Float(a) here, coz it expects python literals - # Multiply by 1.0 to cast to float. This is needed when the input - # is a SymInt which has the assumption that it is integer and - # SymPy will otherwise assume that return value cannot be a float. 
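# Plain-Python illustration of why the _sympy_mod hunk above keeps Mod only when both
# operands are known non-negative and otherwise switches to PythonMod: the two common
# modulus conventions only agree when no operand is negative, so the sign check decides
# whether the simpler form is safe.
import math

def floor_mod(a, b):      # Python's % semantics (sign follows the divisor)
    return a - b * math.floor(a / b)

def trunc_mod(a, b):      # truncation-based remainder (sign follows the dividend)
    return a - b * math.trunc(a / b)

assert floor_mod(7, 3) == trunc_mod(7, 3) == 1            # both non-negative: identical
assert floor_mod(-7, 3) == 2 and trunc_mod(-7, 3) == -1   # mixed signs: they diverge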
- return a * 1.0 + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) def _sympy_is_integer(a): import sympy - return sympy.Eq(sympy.floor(a), a) + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) magic_methods = { @@ -989,9 +1031,26 @@ def binary_magic_impl(self, other): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) - # TODO: consider constant prop here try: - out = func(self.expr, other.expr) + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1122,9 +1181,13 @@ def round_impl(self, ndigits=None): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise + out = safe_expand(out) - pytype = int if ndigits is None else self.pytype + if ndigits is None: + pytype = int + else: + pytype = self.pytype out_hint = None if self.hint is not None: @@ -1136,6 +1199,7 @@ def round_impl(self, ndigits=None): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1259,6 +1323,32 @@ def is_constant(x): return x.node.is_constant() return False + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + if method in bool_becomes_int_magic_methods: def promote(x): @@ -1272,6 +1362,41 @@ def promote(x): def promote(x): return x + def promote2(self, other): + # TODO: Remove eq and other relations from this list. 
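# Pure-Python sketch of the promotion rule described in the comment block above and
# implemented by promote2 just below: for the listed arithmetic methods, if either operand
# is a float the other is promoted to float before dispatch, mirroring CPython's mixed
# int/float arithmetic (int/int true division is handled separately as int_truediv).
def promote_pair(a, b):
    if isinstance(a, float) or isinstance(b, float):
        return float(a), float(b)
    return a, b

assert promote_pair(3, 2) == (3, 2)
assert promote_pair(3, 2.0) == (3.0, 2.0)
assert all(isinstance(v, float) for v in promote_pair(3, 2.0))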
+ # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1286,9 +1411,12 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1300,8 +1428,11 @@ def binary_magic_impl(self, other): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9a9d7baa21ef..e1170fd49f8c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,9 @@ from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.functions import ( + FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt +) from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -450,20 +452,21 @@ def free_unbacked_symbols(x): # setup! 
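# Small self-contained illustration of the NotImplemented early-return added to
# binary_magic_impl above: returning NotImplemented lets Python retry the other operand's
# reflected method instead of raising right away. Sym and Wrapper are hypothetical stand-ins.
class Sym:
    def __init__(self, v):
        self.v = v
    def __add__(self, other):
        if not isinstance(other, (int, float, Sym)):
            return NotImplemented
        o = other.v if isinstance(other, Sym) else other
        return Sym(self.v + o)

class Wrapper:
    def __init__(self, v):
        self.v = v
    def __radd__(self, other):   # tried only after Sym.__add__ returns NotImplemented
        return Wrapper(other.v + self.v)

assert isinstance(Sym(1) + Wrapper(2), Wrapper)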
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: if ( - node.op == "placeholder" and "val" in node.meta and isinstance(node.meta["val"], torch.SymInt) and - isinstance(node.meta["val"].node.expr, sympy.Symbol) + isinstance(node.meta["val"].node.expr, sympy.Symbol) and + (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr)) ): return node.meta["val"].node.expr return None def find_symbol_binding_fx_nodes(graph): - return { - node.meta["val"].node.expr: node - for node in graph.nodes - if is_symbol_binding_fx_node(node) - } + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r: + r[node.meta["val"].node.expr] = node + return r # Analogous to ConvertIntSource @@ -869,9 +872,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sympy.oo + min = -sys.maxsize - 1 if max is None: - max = sympy.oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -979,16 +982,6 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) -def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): - """ - We assert that the bounds are either Boolean, or not finite, or can be computed - in exact prevision via rational arithmetic. - The only exception to this is the rare case when the user calls `sqrt(s0)` - sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) - """ - assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) - assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) - class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1191,6 +1184,19 @@ def _assert_symbol_context(symbolic_context): assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" +def _is_supported_equivalence(expr): + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) 
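# Hedged usage-level illustration of the equivalences _is_supported_equivalence (continued
# just below) accepts: torch.export derived Dims that are linear in a single root symbol
# with integer coefficients. The dimension names here are made up.
from torch.export import Dim

dx = Dim("dx", min=1, max=64)
derived = 2 * dx + 1        # supported: integer-linear in one symbol
# dx * dx or dx * dy        # not expressible as a supported equivalence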
+ if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return ( + (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)) + ) + return isinstance(expr, sympy.Symbol) @dataclass(frozen=True) class SymbolicContext: @@ -1374,14 +1380,21 @@ def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: 'Min': min, 'Max': max, 'Mod': operator.mod, + 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, + 'FloorToInt': math.floor, + 'CeilToInt': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'Round': builtins.round, + 'RoundToInt': builtins.round, 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, + 'FloatTrueDiv': operator.truediv, + 'ToFloat': builtins.float, } @@ -1526,7 +1539,14 @@ class DimConstraints: Solutions are "static" values or simplified "dynamic" constraints. """ - def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name): + def __init__( + self, + symbol_to_source, + var_to_val, + marked_dynamic, + source_name_to_debug_name, + _allow_complex_guards_as_runtime_asserts=False, + ): # We try to solve systems of inequalities with 1 free variable. self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) # Among them, we prioritize solving for a free variable that has equalities. @@ -1568,6 +1588,9 @@ def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_ # symbols that are marked dynamic self._marked_dynamic = marked_dynamic + # for constraints we can't express with the dynamic shapes language, defer as runtime asserts in export + self._allow_complex_guards_as_runtime_asserts = _allow_complex_guards_as_runtime_asserts + def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1619,10 +1642,17 @@ def floor_div_handler(*args): congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. 
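# Numeric sanity check of the floor-division rewrite used by rewrite_with_congruences
# above: once b % d is pinned by a congruence, b // d can be replaced by the exact division
# (b - b % d) / d without changing values.
def rewrite_floordiv(b, d):
    mod_reduced = b % d            # recorded as the congruence  b % d == mod_reduced
    return (b - mod_reduced) / d   # exact: (b - b % d) is divisible by d

for b in range(-10, 11):
    for d in (1, 2, 3, 5):
        assert rewrite_floordiv(b, d) == b // d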
+ if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -1831,7 +1861,7 @@ def solve( symbolic_equivalences = self._symbolic_equivalences self._symbolic_equivalences = [] for source, expr in symbolic_equivalences: - if not _disable_forced_specializations and not self._is_supported_equivalence(expr): + if not _disable_forced_specializations and not _is_supported_equivalence(expr): for s in expr.free_symbols: self._force_specialization(s) sexpr = self._dcp._print_Symbol(s) @@ -1842,19 +1872,6 @@ def solve( for source, expr in self._symbolic_equivalences: self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}") - @classmethod - def _is_supported_equivalence(cls, expr): - # Currently supported Dim ops are linear expressions with integer coefficients. - # So check that expr only contains +, *, ints, and a single occurrence of a symbol. - # (See also documentation of dynamic_shapes._DerivedDim.) - if isinstance(expr, (sympy.Add, sympy.Mul)): - lhs, rhs = expr.args - return ( - (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or - (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs)) - ) - return isinstance(expr, sympy.Symbol) - @classmethod def _is_supported_congruence(cls, congruence): base, divisor = congruence.args @@ -1976,7 +1993,10 @@ def _check_same_range(c, dim): return ( self._is_dim(dim) and ("min" in c or "max" in c) - and (dim.min < 2 or dim.min == c.get("min", 2)) # let pass if min < 2 + and ( + (dim.min < 2 and c.get("min", 2) == 2) + or dim.min == c.get("min", 2) + ) # let pass if analysis min = 2 and specified min = 0/1 and dim.max == c.get("max", sys.maxsize - 1) ) @@ -2104,6 +2124,7 @@ def prettify_results( forced_specializations=None, ): """Format a message for constraint violation erros""" + from torch.export.dynamic_shapes import _get_dim_name_mapping if self._dcp.source_name_to_debug_name: def transform(s, inverse=False): @@ -2141,16 +2162,7 @@ def relation_with_digit(expr, op, digit): results[expr]["eq"] = digit # retrieve dynamic shapes - name_to_dim = {} - for dim in pytree.tree_flatten( - dynamic_shapes, - is_leaf=lambda x: self._is_derived_dim(x) or self._is_dim(x), - )[0]: - if dim is None or isinstance(dim, int): - continue - name_to_dim[dim.__name__] = dim - if self._is_derived_dim(dim): - name_to_dim[dim.root.__name__] = dim.root + name_to_dim = _get_dim_name_mapping(dynamic_shapes) for s in self._static_results.union(self._dynamic_results): t = transform(s) @@ -2189,7 +2201,7 @@ def relation_with_digit(expr, op, digit): buf += ( f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! 
" - "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + 'For more information, run with TORCH_LOGS="+dynamic".\n' ) for s, val in forced_specializations.items(): buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n" @@ -2211,7 +2223,7 @@ def relation_with_digit(expr, op, digit): other = c["eq"] if isinstance(other, int): others.append(f"{k} = {other}") - elif self._is_supported_equivalence(other): + elif _is_supported_equivalence(other): others.append(f"{k} = {other}") else: min_ = c.get("min", None) @@ -2339,6 +2351,7 @@ class ShapeEnvSettings: specialize_zero_one: bool duck_shape: bool prefer_deferred_runtime_asserts_over_guards: bool + _allow_complex_guards_as_runtime_asserts: bool class ShapeEnv: @@ -2432,6 +2445,10 @@ def _init( # in guards is helpful, since these guards in some sense are overly # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 prefer_deferred_runtime_asserts_over_guards=False, + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + _allow_complex_guards_as_runtime_asserts=False, # XXX Add any new settings that could affect FakeTensor evaluation # to: torch._subclasses.fake_tensor._ShapeEnvSettings ): @@ -2444,6 +2461,7 @@ def _init( specialize_zero_one=specialize_zero_one, duck_shape=duck_shape, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, ) self.guards: List[ShapeGuard] = [] @@ -2629,6 +2647,10 @@ def duck_shape(self): def prefer_deferred_runtime_asserts_over_guards(self): return self.settings.prefer_deferred_runtime_asserts_over_guards + @property + def _allow_complex_guards_as_runtime_asserts(self): + return self.settings._allow_complex_guards_as_runtime_asserts + def check_equal(self, other: "ShapeEnv") -> None: """Compare another ShapeEnv for equivalence """ @@ -3310,6 +3332,7 @@ def create_unbacked_symfloat(self): self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3328,6 +3351,7 @@ def create_unbacked_symint(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3351,6 +3375,7 @@ def create_unbacked_symbool(self): self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. 
fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3496,6 +3521,7 @@ def create_symbol( self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] + assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3504,6 +3530,7 @@ def create_symbol( elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3519,7 +3546,7 @@ def create_symbol( if not is_debug: maybe_more_info = ( ", for more info run with " - f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{sympy_expr}\"" + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"' ) fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) self.log.info( @@ -3553,6 +3580,7 @@ def create_symbol( def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -3584,6 +3612,7 @@ def produce_guards( sources, source_ref=lambda n: n.name(), *, + guards: List[ShapeGuard] = None, input_contexts: Optional[DimList[SymbolicContext]] = None, # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). # (See docs on EqualityConstraint for details of the encoding.) @@ -3932,6 +3961,7 @@ def track_symfloat(source, val): self.var_to_val, set(symbol_to_constraints.keys()), self.source_name_to_debug_name, + self._allow_complex_guards_as_runtime_asserts, ) if not _simplified: @@ -4051,7 +4081,7 @@ def issue_guard(guard: ShapeGuard) -> None: # First, issue all guards. # This removes all the checks that follow from bounds # We could simply emit those and also the bounds 2 <= size when necessary - for guard in self.guards: + for guard in (guards if guards is not None else self.guards): if self._maybe_evaluate_static(guard.expr, axioms=()) is not None: continue issue_guard(guard) @@ -4135,7 +4165,7 @@ def issue_guard(guard: ShapeGuard) -> None: err = '\n'.join(error_msgs) raise ConstraintViolationError( f"Constraints violated ({debug_names})! " - "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + 'For more information, run with TORCH_LOGS="+dynamic".\n' f"{err}" ) elif len(warn_msgs) > 0: @@ -4179,10 +4209,18 @@ def issue_guard(guard: ShapeGuard) -> None: with fx_traceback.preserve_node_meta(): PopulateValidator(self.graph, self.validator).run() - self._check_translation_validate() + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() return exprs - def produce_guards_expression(self, placeholders, ignore_static=True): + def produce_guards_expression( + self, + placeholders, + *, + guards: Optional[List[ShapeGuard]] = None, + ignore_static=True + ): """ Expected to be used with evaluate_guards_expression(). 
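# Plain-sympy sketch (torch not required) of the pruning rule implemented by
# get_pruned_guards a bit further below: a guard is kept only when every free symbol it
# mentions comes from the provided input SymInts, and the pruned list can then be passed to
# produce_guards_expression(..., guards=...).
import sympy

s0, s1, u0 = sympy.symbols("s0 s1 u0")
guards = [sympy.Eq(s0, s1), s0 < 10, u0 >= 0]
inputs = {s0, s1}
pruned = [g for g in guards if all(s in inputs for s in g.free_symbols)]
assert pruned == [sympy.Eq(s0, s1), s0 < 10]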
Produces the guards for the given placeholders and returns a string expression to be evaluated @@ -4190,9 +4228,14 @@ def produce_guards_expression(self, placeholders, ignore_static=True): """ from torch._dynamo.source import LocalSource arg_names = [f"t{i}" for i in range(len(placeholders))] - guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static) - if guards: - return " and ".join(guards) + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) return None def evaluate_guards_expression(self, code, args): @@ -4211,6 +4254,18 @@ def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True): return self.evaluate_guards_expression(code, args) return True + def get_pruned_guards(self, symints): + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + symints = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)} + guards = [] + for g in self.guards: + if all(s in symints for s in g.expr.free_symbols): + guards.append(g) + return guards + def bind_symbols(self, placeholders, args): """ Given a paired list of placeholders (fake tensors with @@ -4280,7 +4335,8 @@ def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRa # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sympy.oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @_lru_cache @@ -4390,9 +4446,18 @@ def _maybe_evaluate_static( # Skip var_ranges logic for SingletonInt which is only used # for jagged layout NestedTensors today continue - vr = var_ranges[k] + try: + vr = var_ranges[k] + except KeyError: + log.warning("%s is not in var_ranges, defaulting to unknown range.", k) + vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4400,10 +4465,17 @@ def _maybe_evaluate_static( # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) + (unbacked_only and k in self.var_to_val) or + not vr.is_int ): new_range_env[k] = vr continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4435,6 +4507,14 @@ def replace(expr, repl): self.counter["sympy_recursion_error"] += 1 return None + new_expr = safe_expand(new_expr) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! 
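# Sympy illustration of the "needless complication" described above: a rational coefficient
# distributes over a sum immediately, so rewriting FloorDiv(a, b) as floor(a / b) turns a
# single integer floor-division into an expression full of rationals.
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
print((s0 + s1) / 2)                # s0/2 + s1/2
print(sympy.floor((s0 + s1) / 2))   # floor(s0/2 + s1/2)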
+ """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4443,13 +4523,12 @@ def replace(expr, repl): # are still free symbols if new_expr.is_number: return new_expr + """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if expect_rational: - _assert_bound_is_rational(new_expr, out) - if out.is_singleton(): - return out.lower + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4501,7 +4580,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = base / divisor + div_replacements[fd] = CleanDiv(base, divisor) new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4584,7 +4663,7 @@ def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_resu f"{size_oblivious_result_msg}" "Potential framework code culprit (scroll up for full backtrace):\n" f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n" - "For more information, run with TORCH_LOGS=\"dynamic\"\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' "For extended logs when we create symbols, also add " f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" @@ -4610,9 +4689,15 @@ def _update_var_to_range(self, symbol, vr): # Updates the range and the guards corresponding to each bound of the symbol. if symbol not in self.var_to_range: - self.var_to_range[symbol] = ValueRanges(lower, upper) + r = ValueRanges(lower, upper) + self.log.debug("_update_var_to_range %s = %s (new)", symbol, r) + self.var_to_range[symbol] = r else: - self.var_to_range[symbol] &= ValueRanges(lower, upper) + old = self.var_to_range[symbol] + new = old & ValueRanges(lower, upper) + if new != old: + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: """ @@ -4626,6 +4711,9 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No # Precondition: a == tgt assert isinstance(a, sympy.Symbol) + if self._allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + # Handles nested tensor symbolic variables which don't have # var_to_range bounds tgt_bound = None @@ -4642,7 +4730,10 @@ def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> No int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - return (x & int_range).issubset(y & int_range) + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) # First, refine the value range of a based on the computed value range # of tgt. 
This is always OK to do, even if we decide not to do the @@ -4661,8 +4752,15 @@ def issubset(x, y): # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) if r is not None: - b_bound = self.bound_sympy(r[1]) - self.var_to_range[b] = b_bound & self.var_to_range[b] + self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) + self._update_var_to_range(b, b_bound) tgt_bound = self.bound_sympy(tgt) assert issubset(tgt_bound, src_bound) @@ -4871,12 +4969,12 @@ def trivial_solve(lhs, rhs): ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) + d = q / sympy.gcd(q, c) # TODO: CleanDiv? i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -4982,7 +5080,7 @@ def _log_guard(self, prefix: str, g, forcing_spec: bool): if not is_debug: maybe_more_info = ( ", for more info run with " - f"TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"{str_g}\"" + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' ) self.log.info( "%s %s [guard added]%s (%s)%s%s", @@ -5153,34 +5251,43 @@ def compute_concrete_val(): # is no longer necessary) self._maybe_guard_rel(g) - stack = CapturedTraceback.extract(skip=1) - guard = ShapeGuard(g, stack) - self.guards.append(guard) + if not self._allow_complex_guards_as_runtime_asserts: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. 
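# Numeric sketch of the bound requantization in the set_replacement hunk above: inverting a
# replacement such as s0 = 2*s1 gives s1 = s0/2, so the bound for s1 is first computed in
# exact (rational) arithmetic and then tightened back to integers with a ceiling on the
# lower end and a floor on the upper end.
import math

s0_lower, s0_upper = 3, 11                          # hypothetical integer range for s0
rat_lower, rat_upper = s0_lower / 2, s0_upper / 2   # rational bound for s1
s1_bound = (math.ceil(rat_lower), math.floor(rat_upper))
assert s1_bound == (2, 5)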
+ stack = CapturedTraceback.extract(skip=1) + guard = ShapeGuard(g, stack) + self.guards.append(guard) + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + except Exception: if fresh: self._remove_fx_node(node) raise else: if not self._suppress_guards_tls(): - assert guard is not None - - self._log_guard("eval", g, forcing_spec=forcing_spec) + if guard is not None: # we might have deferred this to runtime assert + self._log_guard("eval", g, forcing_spec=forcing_spec) - for s in g.free_symbols: - self.symbol_guard_counter[s] += 1 - # Forcing_spec to avoid infinite recursion - if ( - not forcing_spec and - config.symbol_guard_limit_before_specialize is not None and - self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize - ): - # Force specialization - self.log.info( - "symbol_guard_limit_before_specialize=%s exceeded on %s", - config.symbol_guard_limit_before_specialize, - s - ) - self.evaluate_expr(s, forcing_spec=True) + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec and + config.symbol_guard_limit_before_specialize is not None and + self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s + ) + self.evaluate_expr(s, forcing_spec=True) else: self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) @@ -5304,7 +5411,6 @@ def _refine_ranges(self, expr: sympy.Expr) -> None: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) - _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. 
diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 560ceb588924..0893c385bbc9 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections.abc import Iterator # type: ignore[import] from functools import partial diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index dd459726917f..96583ef324de 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .core import unify, reify # type: ignore[attr-defined] from .variable import isvar from .utils import _toposort, freeze diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index 2b074235f14a..2228448a71a1 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .core import unify, reify # type: ignore[attr-defined] from .dispatch import dispatch diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 6c247bd98111..7187330ead25 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .utils import _toposort, groupby from .variadic import isvariadic import operator diff --git a/torch/fx/experimental/unification/multipledispatch/core.py b/torch/fx/experimental/unification/multipledispatch/core.py index 2a8ed78e52e3..5b5bdbc96301 100644 --- a/torch/fx/experimental/unification/multipledispatch/core.py +++ b/torch/fx/experimental/unification/multipledispatch/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import sys diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index c46e47e5d35b..a1d28201d041 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from warnings import warn import inspect from typing_extensions import deprecated diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 4b5ec2ed6315..0e90241cf69c 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py index 0f046ba55bd3..49e546e1ea26 100644 --- a/torch/fx/experimental/unification/multipledispatch/variadic.py +++ b/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .utils import typename __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index ae159b937ec0..472cd487f62f 100644 --- 
a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import operator from functools import reduce diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 56cde39319e3..2147d6175136 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] def hashable(x): try: diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py index 8f7efda3328b..66e97a3a7663 100644 --- a/torch/fx/experimental/unification/variable.py +++ b/torch/fx/experimental/unification/variable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from .utils import hashable from .dispatch import dispatch diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index 532d2784fb49..cad0a33425bf 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.graph_gradual_typechecker import Refine from torch.fx.tensor_type import TensorType from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 6dcb59db7979..871b8dd4709b 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import logging import math @@ -216,10 +217,7 @@ def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - if ndigits is not None: - raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") - + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -284,7 +282,7 @@ def wrapper(*args): operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round), + builtins.round: lift(ops.round_to_int), # Math module. 
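# Small z3 sketch (needs the z3-solver package) of the primitives the validator changes
# above and below lean on: ToReal casts an Int term to Real (as in to_dtype for float64)
# and ToInt maps a Real term back to an Int.
import z3

x = z3.Int("x")
r = z3.Real("r")
solver = z3.Solver()
solver.add(z3.ToReal(x) == 2.0, r == 3.5, z3.ToInt(r) == 3)
print(solver.check(), solver.model())   # sat, with x = 2 and r = 7/2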
math.ceil: lift(ops.ceil), @@ -350,6 +348,7 @@ def __init__( self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -358,6 +357,20 @@ def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -370,11 +383,17 @@ def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - return self._ops.round(number, ndigits) + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 96b186cc6c48..9e034278ccb1 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,8 +1,10 @@ +# mypy: allow-untyped-defs from collections import defaultdict from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name import torch.utils._pytree as pytree from . 
import _pytree as fx_pytree from ._compatibility import compatibility +from torch._C import _NodeIter import os import contextlib @@ -270,20 +272,8 @@ def __len__(self): return self.graph._len def __iter__(self): - root = self.graph._root - if self.direction == "_next": - cur = root._next - while cur is not root: - if not cur._erased: - yield cur - cur = cur._next - else: - assert self.direction == "_prev" - cur = root._prev - while cur is not root: - if not cur._erased: - yield cur - cur = cur._prev + assert self.direction == "_prev" or self.direction == "_next" + yield from _NodeIter(self.graph._root, self.direction == "_prev") def __reversed__(self): return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') @@ -1504,7 +1494,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: if node.graph is not self: raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') if node not in self._find_nodes_lookup_table: - raise RuntimeError(f"Node \'{node}\' is not added to the side table") + raise RuntimeError(f"Node '{node}' is not added to the side table") map_arg(node.args, lambda arg: check_arg(arg, node)) map_arg(node.kwargs, lambda arg: check_arg(arg, node)) seen_values.add(node) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index c5d0df29b903..5fb6691dda7c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy import itertools diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 7ad3807f23bb..2ff29cba474d 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Iterable, List, Tuple from torch.utils._pytree import ( diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 23c006fbbd5f..61f3a6919015 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .graph_module import GraphModule from ._lazy_graph_module import _make_graph_module from .graph import Graph diff --git a/torch/fx/node.py b/torch/fx/node.py index d9af26c9207f..2e400158b551 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -11,6 +11,7 @@ import warnings from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair from .._ops import ops as _ops +from torch._C import _NodeBase if TYPE_CHECKING: from .graph import Graph @@ -60,7 +61,7 @@ @compatibility(is_backward_compatible=False) -def has_side_effect(fn: Callable) -> None: +def has_side_effect(fn: Callable) -> Callable: _side_effectful_functions.add(fn) return fn @@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str: return str(arg) @compatibility(is_backward_compatible=True) -class Node: +class Node(_NodeBase): """ ``Node`` is the data structure that represents individual operations within a ``Graph``. For the most part, Nodes represent callsites to various entities, @@ -197,6 +198,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. 
""" + super().__init__() self.graph = graph self.name = name # unique name of value being created assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] @@ -235,10 +237,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. self.type : Optional[Any] = return_type - self._prev = self - self._next = self - self._erased = False - self._sort_key = () + self._sort_key: Any = () # If set, use this fn to print this node self._repr_fn : Optional[Callable[[Node], str]] = None @@ -247,6 +246,22 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # transformations. This metadata is preserved across node copies self.meta : Dict[str, Any] = {} + def __getstate__(self): + state = self.__dict__.copy() + state["_erased"] = self._erased + state["_prev"] = self._prev + state["_next"] = self._next + return state + + def __setstate__(self, state): + _erased = state.pop("_erased") + _prev = state.pop("_prev") + _next = state.pop("_next") + self.__dict__.update(state) + self._erased = _erased + self._prev = _prev + self._next = _next + @property def next(self) -> 'Node': """ @@ -295,6 +310,7 @@ def prepend(self, x: 'Node') -> None: psk = x._prev._sort_key nsk = x._next._sort_key if len(psk) > len(nsk): + idx: int *prefix, idx = psk[:len(nsk) + 1] x._sort_key = (*prefix, idx + 1) elif len(psk) < len(nsk): @@ -421,7 +437,7 @@ def insert_arg(self, idx : int, arg : Argument) -> None: self._args = args_left + (arg,) + args_right - _new_input_nodes = {} + _new_input_nodes: Dict[Node, None] = {} map_arg(arg, _new_input_nodes.setdefault) for new_use in _new_input_nodes.keys(): diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 142740a322bc..04be7d139da4 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import inspect import numbers @@ -183,6 +184,17 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): @compatibility(is_backward_compatible=False) def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. 
+ """ try: if isinstance(x, (list, tuple)): # todo(chilli): Figure out the right way for mypy to handle this diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index d423de930dc7..0f48165b7dab 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index dc95a70a22a7..577f445e7b31 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Tuple, Any import torch diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 58ee61f10089..04aadbbdc9b9 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch.fx diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 7256c41dcdec..ec2336dbdeab 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import hashlib import torch diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index f6e53f0e969a..36c59cb31af0 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, NamedTuple, Optional import torch diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py new file mode 100644 index 000000000000..503844a97aa9 --- /dev/null +++ b/torch/fx/passes/graph_transform_observer.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +import os +from typing import Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + +from .graph_drawer import FxGraphDrawer + +__all__ = ["GraphTransformObserver"] + + +@compatibility(is_backward_compatible=False) +class GraphTransformObserver: + __pass_count = 0 + + def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None): + # If log_url is None, we don't log anything + self.log_url = log_url + if self.log_url is None: + return + GraphTransformObserver.__pass_count += 1 + self.gm = gm + self.passname = passname + + self.input_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + + @classmethod + def get_current_pass_count(cls): + return cls.__pass_count + + def __enter__(self): + if self.log_url is None or self.gm is None: + return self + + self.erased_nodes = set() + self.created_nodes = set() + self.gm._register_create_node_hook(self.on_node_creation) + self.gm._register_erase_node_hook(self.on_node_erase) + + return self + + def __exit__(self, type, value, tb): + if self.log_url is None or self.gm is None: + return + + self.gm._unregister_create_node_hook(self.on_node_creation) + self.gm._unregister_erase_node_hook(self.on_node_erase) + + if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: + for e in self.input_dot_graph.get_node_list(): + if e.get_name() in self.erased_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + 
else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + self.input_dot_graph.write_svg( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.svg", + ) + ) + + output_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + for e in output_dot_graph.get_node_list(): + if e.get_name() in self.created_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + output_dot_graph.write_svg( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.svg", + ) + ) + + def on_node_creation(self, node): + self.created_nodes.add(node.name) + + def on_node_erase(self, node): + self.erased_nodes.add(node.name) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 3952bb652517..58e4e9dd09e8 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.passes.utils.fuser_utils import fuse_by_partitions import collections import itertools @@ -17,16 +18,16 @@ class Partition: def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): self.id = id - self.nodes: Set[Node] = set(nodes) if nodes is not None else set() + self.nodes = {node: None for node in nodes} if nodes is not None else dict() def __repr__(self) -> str: return str(self.nodes) def add_node(self, node: Node): - self.nodes.add(node) + self.nodes.update({node: None}) def remove_node(self, node: Node): - self.nodes.remove(node) + del self.nodes[node] def size(self): return len(self.nodes) @@ -320,12 +321,13 @@ def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: remove_node: Set[Node] = set() for node in partition.nodes: if is_non_compute_node(node) and \ - (is_transparent_input_node(node, partition.nodes, remove_node) or - is_transparent_output_node(node, partition.nodes, remove_node)): + (is_transparent_input_node(node, set(partition.nodes), remove_node) or + is_transparent_output_node(node, set(partition.nodes), remove_node)): remove_node.add(node) if len(remove_node) != 0: - partition.nodes = partition.nodes - remove_node + for node in remove_node: + partition.nodes.pop(node, None) def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: partitions = self.propose_partitions() diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index dd699ea86cde..488450ab24ec 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc from collections import namedtuple from typing import Optional diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 44de7fcc0b1b..fcf0499b9dd1 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import logging from queue import Queue diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 6d050c78f754..e250dd09a121 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple diff --git a/torch/fx/passes/operator_support.py 
b/torch/fx/passes/operator_support.py index ce050f046eea..8edd3c746dbb 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import typing as t diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 55d5ea0af54d..b90f338f303d 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import wraps from inspect import unwrap from typing import Callable, List, Optional diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 6f6014b1c2af..535c63aa1bad 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.fx import Node from torch.fx._compatibility import compatibility diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index e32b5a13fb78..66b8fbe29d9f 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import operator from typing import Any, Dict, Optional, Set, TYPE_CHECKING @@ -51,6 +52,9 @@ def insert_deferred_runtime_asserts( # We hash (node_name, min_val, max_val) nodes_that_already_have_sym_constraint_range = set() + + # We hash only the node name here because size constraints don't take min/max + nodes_that_already_have_sym_constraint_size = set() # TODO this only works for top-level nodes today, also # we should potentially use it not create duplicate # assert_async nodes @@ -63,6 +67,12 @@ def insert_deferred_runtime_asserts( nodes_that_already_have_sym_constraint_range.add( (node.args[0], node.kwargs["min"], node.kwargs["max"]) ) + if ( + node.op == "call_function" + and node.target == torch.ops.aten.sym_constrain_range_for_size.default + ): + assert len(node.args) == 1 + nodes_that_already_have_sym_constraint_size.add(node.args[0]) # Import sympy locally import sympy @@ -90,16 +100,30 @@ def insert_deferred_runtime_asserts( lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm), ) + # deduplicate unassociated runtime assertions + # we could do better, some guards might be redundant, + # e.g. Eq(s0, 4) & Eq(2*s0, 8) + # but unclear how to handle all of that right now. + # TODO(pianpwk): better way of doing this + new_ras = [] + ras_exprs: Set[sympy.Expr] = set() + for ras in ras_by_symbol.pop(None, []): # type: ignore[call-overload] + if ras.expr not in ras_exprs: + new_ras.append(ras) + ras_exprs.add(ras.expr) + ras_by_symbol[None] = new_ras # type: ignore[index] + # We are going to mutate the dict symbol_to_proxy: Dict[sympy.Symbol, fx.Proxy] = {} placeholders = set() last_placeholder = None for node in graph.nodes: if node.op != "placeholder": - last_placeholder = node break + last_placeholder = node placeholders.add(node) - assert last_placeholder is not None + if last_placeholder is None: # no placeholders, just insert before first node + last_placeholder = next(iter(graph.nodes)) # Identify what symbols we need to reify.
This isn't strictly needed # but helps reduce churn on the graph @@ -137,6 +161,7 @@ def add_runtime_asserts(ras): ), ) + inserted_sym_nodes = 0 # for inserting unassociated runtime asserts nodes = list(graph.nodes) for i, node in enumerate(nodes[:-1]): # Placeholders can match symbols, but when we destructure them @@ -164,6 +189,8 @@ def match_symbol(symint, cb): ): symbol_to_proxy[s] = fx.Proxy(cb()) log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s]) + nonlocal inserted_sym_nodes + inserted_sym_nodes += 1 match_symbol(example_value, lambda: node) if isinstance(t := example_value, torch.Tensor): @@ -191,8 +218,13 @@ def match_symbol(symint, cb): # Handle asserts that aren't associated with any symbol. This # doesn't really have to be in the loop as it will only run once, # it just needs to happen right after the placeholders. + # insert this after placeholders & added sym nodes, and before non-placeholders. if node not in placeholders: - add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] + last_sym_node = last_placeholder + for _ in range(inserted_sym_nodes): + last_sym_node = last_sym_node.next + with graph.inserting_before(last_sym_node.next): + add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] defs = [] @@ -315,10 +347,14 @@ def go(node, keypath): if i0 in shape_env.size_like: if export: - graph.call_function( - torch.ops.aten.sym_constrain_range_for_size.default, - (symbol_to_proxy[i0].node,), - ) + if ( + symbol_to_proxy[i0].node + not in nodes_that_already_have_sym_constraint_size + ): + graph.call_function( + torch.ops.aten.sym_constrain_range_for_size.default, + (symbol_to_proxy[i0].node,), + ) else: graph.call_function( torch._check_is_size, (symbol_to_proxy[i0].node,) diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 977741cfe62d..093d7e4071d0 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Any, Callable, Dict, List, Optional, Set from collections import OrderedDict diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 1282081af67b..38aa56064db6 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type, Union diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index b37f8ecf1d0c..f4aa439b409d 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import copy from collections import defaultdict diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index 7dc757a9c0e5..aac071ace8c2 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional import collections from dataclasses import dataclass diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index 3bd030337df4..ba2ae45aabf5 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, Tuple from torch.fx._compatibility import compatibility diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py 
index 3423ea3dad5a..cc26dea3cc44 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from queue import SimpleQueue from typing import List, Dict, Tuple diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index 00415d10fee7..a69806829875 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from collections import defaultdict import copy diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 2830f60d5eab..0f2650ea8d49 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from dataclasses import dataclass, field from torch.fx.graph import Graph from torch.fx.node import Node diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index d0bb4b55a403..3106daca0b18 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -70,8 +70,8 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: # CASE 3: The target doesn't exist as an attribute in `gm` # or `replacement` else: - raise RuntimeError("Attempted to create a \"", node.op, - "\" node during subgraph rewriting " + raise RuntimeError('Attempted to create a "', node.op, + '" node during subgraph rewriting ' f"with target {node.target}, but " "the referenced attribute does not " "exist in the replacement GraphModule") diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index c822a38ec78e..f59ed2d45baa 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.fx.experimental.unification import Var # type: ignore[attr-defined] from ._compatibility import compatibility diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a582e03979c4..4e72a8011f63 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import traceback from contextlib import contextmanager from typing import List, Any, Dict diff --git a/torch/hub.py b/torch/hub.py index 4ea92ed6be82..213a1290bebd 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import errno import hashlib @@ -234,7 +235,7 @@ def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=T try: url = _git_archive_link(repo_owner, repo_name, ref) - sys.stderr.write(f'Downloading: \"{url}\" to {cached_file}\n') + sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, progress=False) except HTTPError as err: if err.code == 300: diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index a5b9f5627ea7..6d1760fb9f4f 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from contextlib import contextmanager diff --git a/torch/jit/_async.py b/torch/jit/_async.py index 2134975bb953..bdde55adf14f 100644 --- a/torch/jit/_async.py +++ b/torch/jit/_async.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Async API. 
This module contains the API for parallelism in TorchScript, notably: diff --git a/torch/jit/_await.py b/torch/jit/_await.py index a79952bf3e2d..e86493512e59 100644 --- a/torch/jit/_await.py +++ b/torch/jit/_await.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._jit_internal import _Await from torch.jit._builtins import _register_builtin diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index f50e1bbfedb5..ecf0223cebe6 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import cmath import math import warnings diff --git a/torch/jit/_check.py b/torch/jit/_check.py index 0dc2cb6d37ba..8db5bb82ce3d 100644 --- a/torch/jit/_check.py +++ b/torch/jit/_check.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import inspect import textwrap diff --git a/torch/jit/_dataclass_impls.py b/torch/jit/_dataclass_impls.py index 52056ce46bea..2dc1dfba076f 100644 --- a/torch/jit/_dataclass_impls.py +++ b/torch/jit/_dataclass_impls.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Functions for synthesizing magic methods for JIT-compiled dataclasses import ast import dataclasses diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index fb4448e2b900..795f9da8e073 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch._ops import OpOverload, OpOverloadPacket diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index babb70eaf7cb..8ac456be482b 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch import Tensor diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index 731f28305628..8f35fc471e68 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Freezing. This is not intended to be imported directly; please use the exposed diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py index 253682736034..7466800402d2 100644 --- a/torch/jit/_fuser.py +++ b/torch/jit/_fuser.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import List, Tuple diff --git a/torch/jit/_ir_utils.py b/torch/jit/_ir_utils.py index 028247f54011..52b953624a3a 100644 --- a/torch/jit/_ir_utils.py +++ b/torch/jit/_ir_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Union import torch diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 3b19e8438d4e..4662869e3683 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import pathlib import sys diff --git a/torch/jit/_passes/_property_propagation.py b/torch/jit/_passes/_property_propagation.py index 8ebd21e4bc10..1537f7bc4147 100644 --- a/torch/jit/_passes/_property_propagation.py +++ b/torch/jit/_passes/_property_propagation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Tools to help with tensor property propagation. 
diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py index 1cb4a0a93efd..5517499e9260 100644 --- a/torch/jit/_pickle.py +++ b/torch/jit/_pickle.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # These functions are referenced from the pickle archives produced by # ScriptModule.save() diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index a76a0c4a2cb0..fc37237edd30 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import functools import inspect diff --git a/torch/jit/_script.py b/torch/jit/_script.py index b77b0d2ea45f..7327a204fccc 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1097,6 +1097,7 @@ def _script_impl( "`optimize` is deprecated and has no effect. " "Use `with torch.jit.optimized_execution()` instead", FutureWarning, + stacklevel=3, ) # No-op for modules, functions, class instances that are already scripted diff --git a/torch/jit/_script.pyi b/torch/jit/_script.pyi index b43a8bc7089e..b1f39b2bc706 100644 --- a/torch/jit/_script.pyi +++ b/torch/jit/_script.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" from typing import ( Any, diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 514f23cb76d3..b9b9691401d3 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Serialization. This module contains functionality for serializing TorchScript modules, notably: diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index bef34e28239b..18b69acddc09 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/jit/_state.py b/torch/jit/_state.py index 1d75415ef80e..63df2acfdf09 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """JIT-related state. This module stores various pieces of Python-global state relating to the JIT. diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 2713a66a4499..7db856024287 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Tracing. This module contains functionality to support the JIT's tracing frontend, notably: @@ -646,10 +647,17 @@ def analyze_ts_result_with_export_result(export, trace): flat_trace = pytree.tree_leaves(trace) for orig, loaded in zip(flat_export, flat_trace): + if orig.layout != loaded.layout: + return False + # mkldnn is not supported for torch.allclose + if orig.layout == torch._mkldnn: # type: ignore[attr-defined] + return True if type(orig) != type(loaded): return False if isinstance(orig, torch.Tensor): + if orig.dtype != loaded.dtype: + return False if not torch.allclose(orig, loaded): return False else: @@ -981,6 +989,7 @@ def forward(self, x): "`optimize` is deprecated and has no effect. 
" "Use `with torch.jit.optimized_execution()` instead", FutureWarning, + stacklevel=2, ) from torch._utils_internal import ( @@ -1012,6 +1021,21 @@ def forward(self, x): _process_jit_trace_inputs_for_export, ) + traced_func_for_export = _trace_impl( + func, + example_inputs=example_inputs, + optimize=optimize, + check_trace=False, + check_inputs=check_inputs, + check_tolerance=check_tolerance, + strict=strict, + _force_outplace=_force_outplace, + _module_class=_module_class, + _compilation_unit=_compilation_unit, + example_kwarg_inputs=example_kwarg_inputs, + _store_inputs=_store_inputs, + ) + export_args, _ = _process_jit_trace_inputs_for_export( example_inputs, example_kwarg_inputs ) @@ -1037,7 +1061,7 @@ def _log_exportability(func_to_export, export_func, export_args, export_type): return try: - traced_result = traced_func(*export_args) + traced_result = func_to_export(*export_args) except Exception as e: _ = e log_torch_jit_trace_exportability( @@ -1065,22 +1089,22 @@ def _convert_ts_to_export_source_to_source(func, export_args): return TS2EPConverter(func, export_args).convert().module() # torch.jit.trace is noop when the original module is torch.jit.ScriptModule - if not isinstance(traced_func, torch.jit.ScriptModule): + if not isinstance(traced_func_for_export, torch.jit.ScriptModule): _log_exportability( - traced_func, + traced_func_for_export, _direct_export_and_lower, export_args, _ExportType.DIRECT_EXPORT, ) _log_exportability( - traced_func, + traced_func_for_export, _convert_ts_to_export_experimental, export_args, _ExportType.TRACE_AND_EXPORT, ) _log_exportability( - traced_func, + traced_func_for_export, _convert_ts_to_export_source_to_source, export_args, _ExportType.SOURCE_TO_SOURCE, @@ -1190,6 +1214,7 @@ def weighted_kernel_sum(self, weight): "`optimize` is deprecated and has no effect. 
" "Use `with torch.jit.optimized_execution()` instead", FutureWarning, + stacklevel=2, ) var_lookup_fn = _create_interpreter_name_lookup_fn(0) diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index a24fad838353..76d5ce5805b6 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import builtins import dis diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index ea834f664f4f..775120a67ccb 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast import dataclasses import inspect diff --git a/torch/jit/generate_bytecode.py b/torch/jit/generate_bytecode.py index 8e56c7665d1c..f66bf7bfc4c1 100644 --- a/torch/jit/generate_bytecode.py +++ b/torch/jit/generate_bytecode.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index 63632de23d3f..ba29b31bccc5 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import torch diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index c7c679c79456..a2500c1f1b9f 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index c06664a6cff2..3bfec99feb17 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import textwrap diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index 4e553757eab4..f8c9be4f5b06 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from textwrap import dedent from typing import Any, Dict diff --git a/torch/library.py b/torch/library.py index f771141ec436..d0a4cf24f088 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from ._ops import OpOverload from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence from typing_extensions import deprecated @@ -364,8 +365,8 @@ def define(qualname, schema, *, lib=None, tags=()): if not NAMELESS_SCHEMA.fullmatch(schema): raise ValueError( f"define(qualname, schema, ...): expected schema " - f"to look like e.g. \"(Tensor x) -> Tensor\" but " - f"got \"{schema}\"") + f'to look like e.g. "(Tensor x) -> Tensor" but ' + f'got "{schema}"') lib.define(name + schema, alias_analysis="", tags=tags) @@ -570,7 +571,7 @@ def register_fake( This API may be used as a decorator (see examples). For a detailed guide on custom ops, please see - https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit + https://pytorch.org/docs/main/notes/custom_operators.html Examples: >>> import torch @@ -796,7 +797,7 @@ def inner(*args, **kwargs): raise RuntimeError( f"Operator '{qualname}' was defined in C++ and has a Python " f"fake impl. In this situation, we require there to also be a " - f"companion C++ `m.set_python_module(\"{actual_module_name}\")` " + f'companion C++ `m.set_python_module("{actual_module_name}")` ' f"call, but we could not find one. Please add that to " f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) 
block the " f"operator was registered in ({cpp_filename})") diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 29df838bab54..0637f3f7b83c 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -457,6 +457,8 @@ Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. +The returned eigenvalues are not guaranteed to be in any specific order. + .. note:: The eigenvalues and eigenvectors of a real matrix may be complex. """ + fr""" @@ -559,6 +561,8 @@ Also supports batches of matrices, and if :attr:`A` is a batch of matrices then the output has the same batch dimensions. +The returned eigenvalues are not guaranteed to be in any specific order. + .. note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. diff --git a/torch/masked/__init__.py b/torch/masked/__init__.py index e0193416ed2f..18d1b9f9e283 100644 --- a/torch/masked/__init__.py +++ b/torch/masked/__init__.py @@ -1,33 +1,34 @@ -from .maskedtensor.core import is_masked_tensor, MaskedTensor -from .maskedtensor.creation import as_masked_tensor, masked_tensor -from ._ops import ( +from torch.masked._ops import ( _canonical_dim, + _combine_input_and_mask, _generate_docstring, - _reduction_identity, - _where, _input_mask, _output_mask, - _combine_input_and_mask, - sum, - prod, - cumsum, - cumprod, + _reduction_identity, + _where, amax, amin, argmax, argmin, + cumprod, + cumsum, + log_softmax, + logaddexp, + logsumexp, mean, median, - logsumexp, - logaddexp, norm, - var, - std, + normalize, + prod, softmax, - log_softmax, softmin, - normalize, + std, + sum, + var, ) +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor + __all__ = [ "as_masked_tensor", diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index b7872a6d4cf4..26094459c171 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,15 +1,13 @@ - +# mypy: allow-untyped-defs import warnings - -# A workaround to support both TorchScript and MyPy: from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch import Tensor -from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor -from . 
import _docs +from torch import sym_float, Tensor from torch._prims_common import corresponding_real_dtype -from torch import sym_float +from torch.masked import _docs +from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor +from torch.masked.maskedtensor.creation import as_masked_tensor if TYPE_CHECKING: from torch.types import _dtype as DType @@ -469,7 +467,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]: raise RuntimeError(f"dim={d} appears multiple times in the list of dims") if d >= ndim or d < -ndim: raise IndexError( - f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})" + f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})" ) dims.append(d % ndim) return tuple(sorted(dims)) @@ -1420,7 +1418,6 @@ def median( dtype: Optional[DType] = None, mask: Optional[Tensor] = None, ) -> Tensor: - """\ {reduction_signature} {reduction_descr} @@ -1487,46 +1484,45 @@ def logaddexp( ) -> Tensor: """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor -Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` -tensor. The :attr:`input` elements are masked out according to the boolean tensor -:attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor -:attr:`other_mask`. - -The shapes of a mask tensor and the tensor to be masked -don't need to match, but they must be :ref:`broadcastable -` and the dimensionality of the mask -tensor must not be greater than of the tensor to be masked. - -Args: - input (Tensor): the input tensor - other (Tensor): the second input tensor - -Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type - of returned tensor. If specified, the output tensor is - casted to :attr:`dtype` after the operation is - performed. Default: None. - input_mask (:class:`torch.Tensor`, optional): the boolean tensor - containing the binary mask of validity of :attr:`input` tensor elements. - Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. - other_mask (:class:`torch.Tensor`, optional): the boolean tensor - containing the binary mask of validity of :attr:`other` tensor elements. - Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. - -Example:: - - >>> input = torch.tensor([-100.0, -200, -300]) - >>> input - tensor([-100., -200., -300.]) - >>> other = torch.tensor([-1.0, -2, -3]) - >>> other - tensor([-1., -2., -3.]) - >>> mask = torch.tensor([True, False, True]) - >>> mask - tensor([ True, False, True]) - >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) - tensor([-1., -inf, -3.]) -""" + Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other` + tensor. The :attr:`input` elements are masked out according to the boolean tensor + :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor + :attr:`other_mask`. + + The shapes of a mask tensor and the tensor to be masked + don't need to match, but they must be :ref:`broadcastable + ` and the dimensionality of the mask + tensor must not be greater than of the tensor to be masked. + + Args: + input (Tensor): the input tensor + other (Tensor): the second input tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the output tensor is + casted to :attr:`dtype` after the operation is + performed. 
Default: None. + input_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`input` tensor elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``. + other_mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of :attr:`other` tensor elements. + Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``. + + Example:: + + >>> input = torch.tensor([-100.0, -200, -300]) + >>> input + tensor([-100., -200., -300.]) + >>> other = torch.tensor([-1.0, -2, -3]) + >>> other + tensor([-1., -2., -3.]) + >>> mask = torch.tensor([True, False, True]) + >>> mask + tensor([ True, False, True]) + >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask) + tensor([-1., -inf, -3.])""" if dtype is None: dtype = input.dtype if input.layout == torch.strided and other.layout == torch.strided: @@ -1586,7 +1582,9 @@ def _std_var( mask: Optional[Tensor], take_sqrt: Optional[bool], ) -> Tensor: - assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given" + assert ( + unbiased is None or correction_opt is None + ), "Only one of unbiased and correction may be given" correction = 1.0 if unbiased is not None: correction = 1.0 if unbiased else 0.0 @@ -1632,8 +1630,11 @@ def _std_var( if not keepdim: count = count.reshape(total.shape) if correction != 0: - real_dtype = (corresponding_real_dtype(compute_dtype) - if compute_dtype.is_complex else compute_dtype) + real_dtype = ( + corresponding_real_dtype(compute_dtype) + if compute_dtype.is_complex + else compute_dtype + ) count = count.to(real_dtype) count = torch.subtract(count, correction) count = torch.maximum(count, count.new_zeros([])) diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 81a890af5d65..802c52aecafd 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -1,43 +1,46 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from functools import partial -from typing import Callable, Any, Dict, TYPE_CHECKING -import torch - -if TYPE_CHECKING: - import torch._ops +from typing import Any, Callable, Dict, TYPE_CHECKING -from .binary import ( - _apply_native_binary, - NATIVE_BINARY_FNS, - NATIVE_INPLACE_BINARY_FNS, -) -from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask -from .passthrough import ( - _apply_pass_through_fn, - PASSTHROUGH_FNS +import torch +from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS +from .core import ( + _get_data, + _masks_match, + _maybe_get_mask, + is_masked_tensor, + MaskedTensor, ) +from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS from .reductions import ( _apply_reduction, NATIVE_REDUCE_FNS, - TORCH_REDUCE_FNS, TENSOR_REDUCE_FNS, + TORCH_REDUCE_FNS, ) -from .unary import ( - _apply_native_unary, - NATIVE_UNARY_FNS, - NATIVE_INPLACE_UNARY_FNS, -) +from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS + + +if TYPE_CHECKING: + from torch._ops import OpOverload __all__ = [] # type: ignore[var-annotated] -def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None): +def _check_args_kwargs_length( + args, kwargs, error_prefix, len_args=None, len_kwargs=None +): if len_args is not None and len_args != len(args): - raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}") + raise ValueError( + f"{error_prefix}: len(args) must be {len_args} but got {len(args)}" + ) if len_kwargs is not None and len_kwargs != len(kwargs): - raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}") + raise ValueError( + f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}" + ) class _MaskedContiguous(torch.autograd.Function): @@ -116,7 +119,9 @@ def forward(ctx, input): raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.") if input._masked_data.ndim != 2: - raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}") + raise ValueError( + f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}" + ) if input.layout == torch.sparse_csr: return input @@ -157,7 +162,11 @@ def masked_out_like(mt): _MASKEDTENSOR_FUNCTION_TABLE = {} _function_fn_apply_map = { - (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction, + ( + tuple(NATIVE_REDUCE_FNS), + tuple(TORCH_REDUCE_FNS), + tuple(TENSOR_REDUCE_FNS), + ): _apply_reduction, } for fn_map_list, apply_fn in _function_fn_apply_map.items(): @@ -177,9 +186,11 @@ def register_function_func(ops): def foo(func, *args, **kwargs): """ + def wrapper(func): for op in ops: _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op) + return wrapper @@ -190,7 +201,9 @@ def _general_function_reductions(func, *args, **kwargs): @register_function_func([torch.Tensor.where, torch.where]) def _function_where(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0 + ) return _MaskedWhere.apply(*args) @@ -214,7 +227,8 @@ def _function_to_sparse_csr(func, *args, **kwargs): return _MaskedToSparseCsr.apply(args[0]) -_MASKEDTENSOR_DISPATCH_TABLE: Dict["torch._ops.OpOverload", Callable[..., Any]] = {} 
+_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {} + def register_dispatch_func(aten_ops): """ @@ -227,9 +241,11 @@ def register_dispatch_func(aten_ops): def foo(func, *args, **kwargs): """ + def wrapper(func): for aten_op in aten_ops: _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op) + return wrapper @@ -272,9 +288,7 @@ def layout(func, *args, **kwargs): def is_contiguous(func, *args, **kwargs): data = _get_data(args[0]) if data.is_sparse: - raise ValueError( - "MaskedTensors with sparse data do not have is_contiguous" - ) + raise ValueError("MaskedTensors with sparse data do not have is_contiguous") return func(data, *args[1:], **kwargs) @@ -301,9 +315,7 @@ def is_non_overlapping_and_dense(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten.contiguous]) def contiguous(func, *args, **kwargs): if _get_data(args[0]).is_sparse: - raise ValueError( - "MaskedTensors with sparse data do not have contiguous" - ) + raise ValueError("MaskedTensors with sparse data do not have contiguous") return _MaskedContiguous.apply(args[0]) @@ -313,9 +325,13 @@ def new_empty_strided(func, *args, **kwargs): data = _get_data(args[0]) mask = _maybe_get_mask(args[0]) if tuple(args[1]) != tuple(data.size()): - raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()") + raise ValueError( + f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()" + ) if tuple(args[2]) != tuple(data.stride()): - raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()") + raise ValueError( + f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()" + ) return MaskedTensor(func(data, args[1], args[2], **kwargs), mask) @@ -339,7 +355,9 @@ def _to_copy(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._softmax]) def _softmax(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) data = _get_data(args[0]) mask = _maybe_get_mask(args[0]) result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2) @@ -359,7 +377,9 @@ def _softmax_backward_data(func, *args, **kwargs): grad, output, dim, input_dtype = args if is_masked_tensor(grad) and is_masked_tensor(output): if not _masks_match(grad, output): - raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match") + raise ValueError( + "__torch_dispatch__, {func}: expected the masks of grad and output to match" + ) grad_data = _get_data(grad) new_grad_data = torch.ops.aten._masked_softmax_backward( grad_data, @@ -370,7 +390,9 @@ def _softmax_backward_data(func, *args, **kwargs): res = MaskedTensor(new_grad_data, _maybe_get_mask(grad)) return res else: - raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors") + raise ValueError( + f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors" + ) @register_dispatch_func([torch.ops.aten.copy_]) @@ -384,7 +406,9 @@ def copy_(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten.where]) def where(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: 
expected args[0] to be a tensor") mx = args[1] @@ -400,7 +424,9 @@ def where(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_sparse]) def _to_sparse(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -415,7 +441,9 @@ def _to_sparse(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_sparse_csr]) def _to_sparse_csr(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -430,7 +458,9 @@ def _to_sparse_csr(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_dense]) def _to_dense(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) if not torch.is_tensor(args[0]): raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor") mt = args[0] @@ -444,14 +474,18 @@ def _to_dense(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._indices]) def _indices(func, *args, **kwargs): # Assumes data is sparse - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) data = _get_data(args[0]).indices() return MaskedTensor(data, torch.ones_like(data).bool()) @register_dispatch_func([torch.ops.aten._values]) def _values(func, *args, **kwargs): - _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0) + _check_args_kwargs_length( + args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0 + ) data = _get_data(args[0]).values() return MaskedTensor(data, torch.ones_like(data).bool()) diff --git a/torch/masked/maskedtensor/binary.py b/torch/masked/maskedtensor/binary.py index 087ea95916e5..7b64cfa0fbd9 100644 --- a/torch/masked/maskedtensor/binary.py +++ b/torch/masked/maskedtensor/binary.py @@ -1,8 +1,16 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch -from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor +from .core import ( + _map_mt_args_kwargs, + _masks_match, + _tensors_match, + _wrap_result, + is_masked_tensor, +) + __all__ = [] # type: ignore[var-annotated] @@ -79,25 +87,22 @@ def _binary_helper(fn, args, kwargs, inplace): raise ValueError("len(kwargs) must equal 0") for a in args[2:]: if torch.is_tensor(a): - raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs") + raise TypeError( + "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs" + ) if not _masks_match(*args[:2]): raise ValueError( "Input masks must match. If you need support for this, please open an issue on Github." 
) - data_args, data_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x.get_data() - ) - mask_args, mask_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x.get_mask() - ) + data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) + mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) args0_layout = data_args[0].layout same_layout = ( - (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and - (args0_layout == data_args[1].layout) - ) + torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]) + ) and (args0_layout == data_args[1].layout) if args0_layout == torch.sparse_coo: if same_layout: @@ -106,7 +111,9 @@ def _binary_helper(fn, args, kwargs, inplace): "sparse_coo indices must match. If you need support for this, please open an issue on Github." ) if data_args[0].size() != data_args[1].size(): - raise ValueError("input1 and input2 must have the same size for binary functions.") + raise ValueError( + "input1 and input2 must have the same size for binary functions." + ) data_args[1] = data_args[1].values() diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index d2002048edd9..0933a804fcc7 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import warnings @@ -13,7 +14,7 @@ def is_masked_tensor(a): - r""" Returns True if the input is a MaskedTensor, else False + r"""Returns True if the input is a MaskedTensor, else False Args: a: any input @@ -35,7 +36,9 @@ def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08): if is_masked_tensor(a) or is_masked_tensor(b): raise ValueError("Neither `a` nor `b` can be a MaskedTensor.") if a.layout != b.layout: - raise ValueError(f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}") + raise ValueError( + f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}" + ) if a.dtype != b.dtype: b = b.type(a.dtype) @@ -108,9 +111,7 @@ def _masked_tensor_str(data, mask, formatter): formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item()) for d in data ] - max_len = max( - 8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask) - ) + max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask)) return ( "[" + ", ".join( @@ -153,13 +154,21 @@ def __new__(cls, data, mask, requires_grad=False): kwargs["requires_grad"] = requires_grad kwargs["dispatch_sizes_strides_policy"] = "strides" kwargs["dispatch_layout"] = True - warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage " - "and will change in the near future. Please open a Github issue " - "for features requests and see our documentation on the torch.masked " - "module for further information about the project."), UserWarning) + warnings.warn( + ( + "The PyTorch API of MaskedTensors is in prototype stage " + "and will change in the near future. Please open a Github issue " + "for features requests and see our documentation on the torch.masked " + "module for further information about the project." + ), + UserWarning, + ) if data.requires_grad: - warnings.warn("It is not recommended to create a MaskedTensor with a tensor that requires_grad. " - "To avoid this, you can use data.clone().detach()", UserWarning) + warnings.warn( + "It is not recommended to create a MaskedTensor with a tensor that requires_grad. 
" + "To avoid this, you can use data.clone().detach()", + UserWarning, + ) return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined] def _preprocess_data(self, data, mask): @@ -184,17 +193,23 @@ def _validate_members(self): data = self._masked_data mask = self.get_mask() if type(data) != type(mask): - raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}") + raise TypeError( + f"data and mask must have the same type. Got {type(data)} and {type(mask)}" + ) if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}: raise TypeError(f"data layout of {data.layout} is not supported.") if data.layout == torch.sparse_coo: if not _tensors_match(data.indices(), mask.indices(), exact=True): - raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.") + raise ValueError( + "data and mask are both sparse COO tensors but do not have the same indices." + ) elif data.layout == torch.sparse_csr: if not _tensors_match( data.crow_indices(), mask.crow_indices(), exact=True ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True): - raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.") + raise ValueError( + "data and mask are both sparse CSR tensors but do not share either crow or col indices." + ) if mask.dtype != torch.bool: raise TypeError("mask must have dtype bool.") if not ( @@ -219,7 +234,8 @@ def __init__(self, data, mask, requires_grad=False): @staticmethod def _from_values(data, mask): - """ Differentiable constructor for MaskedTensor """ + """Differentiable constructor for MaskedTensor""" + class Constructor(torch.autograd.Function): @staticmethod def forward(ctx, data, mask): @@ -265,6 +281,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE + if func in _MASKEDTENSOR_FUNCTION_TABLE: return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs) @@ -286,6 +303,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func = func.overloadpacket from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE + if func in _MASKEDTENSOR_DISPATCH_TABLE: return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs) diff --git a/torch/masked/maskedtensor/creation.py b/torch/masked/maskedtensor/creation.py index 861984a21e1c..a013ef1beb66 100644 --- a/torch/masked/maskedtensor/creation.py +++ b/torch/masked/maskedtensor/creation.py @@ -1,21 +1,23 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. 
and affiliates from .core import MaskedTensor + __all__ = [ "as_masked_tensor", "masked_tensor", ] -"""" -These two factory functions are intended to mirror - torch.tensor - guaranteed to be a leaf node - torch.as_tensor - differentiable constructor that preserves the autograd history -""" +# These two factory functions are intended to mirror +# torch.tensor - guaranteed to be a leaf node +# torch.as_tensor - differentiable constructor that preserves the autograd history + def masked_tensor(data, mask, requires_grad=False): return MaskedTensor(data, mask, requires_grad) + def as_masked_tensor(data, mask): return MaskedTensor._from_values(data, mask) diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index 91c9e5f81830..4a2e79456c86 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates """ These are functions that should simply be applied to both mask and data. @@ -10,6 +11,7 @@ from .core import _map_mt_args_kwargs, _wrap_result + __all__ = [] # type: ignore[var-annotated] diff --git a/torch/masked/maskedtensor/reductions.py b/torch/masked/maskedtensor/reductions.py index 737f4b240beb..fedab1c12a63 100644 --- a/torch/masked/maskedtensor/reductions.py +++ b/torch/masked/maskedtensor/reductions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import warnings @@ -7,6 +8,7 @@ from .core import is_masked_tensor from .creation import as_masked_tensor, masked_tensor + __all__ = [] # type: ignore[var-annotated] @@ -159,6 +161,7 @@ def grad_reduce(*args, **kwargs): TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys()) TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys()) + def _is_reduction(fn): return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP diff --git a/torch/masked/maskedtensor/unary.py b/torch/masked/maskedtensor/unary.py index b3d5c136bfd4..790d86ef92e4 100644 --- a/torch/masked/maskedtensor/unary.py +++ b/torch/masked/maskedtensor/unary.py @@ -1,9 +1,11 @@ +# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import torch from .core import _map_mt_args_kwargs, _wrap_result + __all__ = [] # type: ignore[var-annotated] @@ -108,18 +110,18 @@ def _unary_helper(fn, args, kwargs, inplace): if len(kwargs) != 0: - raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. " - "If you need support for this, please open an issue on Github.") + raise ValueError( + "MaskedTensor unary ops require that len(kwargs) == 0. " + "If you need support for this, please open an issue on Github." 
+ ) for a in args[1:]: if torch.is_tensor(a): - raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments") - - mask_args, mask_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x._masked_mask - ) - data_args, data_kwargs = _map_mt_args_kwargs( - args, kwargs, lambda x: x._masked_data - ) + raise TypeError( + "MaskedTensor unary ops do not support additional Tensor arguments" + ) + + mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask) + data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data) if args[0].layout == torch.sparse_coo: data_args[0] = data_args[0].coalesce() diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index 6118c2b05686..5c61eaf91bd0 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased @@ -128,6 +129,16 @@ def driver_allocated_memory() -> int: return torch._C._mps_driverAllocatedMemory() +def recommended_max_memory() -> int: + r"""Returns recommended max Working set size for GPU memory in bytes. + + .. note:: + Recommended max working set size for Metal. + returned from device.recommendedMaxWorkingSetSize. + """ + return torch._C._mps_recommendedMaxMemory() + + from . import profiler from .event import Event @@ -144,4 +155,5 @@ def driver_allocated_memory() -> int: "driver_allocated_memory", "Event", "profiler", + "recommended_max_memory", ] diff --git a/torch/mps/event.py b/torch/mps/event.py index a206b640ef4a..d619c027480c 100644 --- a/torch/mps/event.py +++ b/torch/mps/event.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index 9094a275136c..d9ca3f55c5e6 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 4007f0e584f2..f9554a9bcb27 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package enables an interface for accessing MTIA backend in python """ @@ -159,6 +160,18 @@ def set_stream(stream: Stream): torch._C._mtia_setCurrentStream(stream) +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + class device: r"""Context-manager that changes the selected device. @@ -256,6 +269,7 @@ def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: "current_device", "current_stream", "default_stream", + "set_device", "set_stream", "stream", "device", diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 8cbb1fb07ff8..5d69bc7daa1a 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module. 
It registers custom reducers, that use shared memory to provide shared diff --git a/torch/multiprocessing/_atfork.py b/torch/multiprocessing/_atfork.py index 92a3280fee78..37ebe377838d 100644 --- a/torch/multiprocessing/_atfork.py +++ b/torch/multiprocessing/_atfork.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys __all__ = ["register_after_fork"] diff --git a/torch/multiprocessing/queue.py b/torch/multiprocessing/queue.py index 99da145e75f1..876bf8d0e745 100644 --- a/torch/multiprocessing/queue.py +++ b/torch/multiprocessing/queue.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import multiprocessing.queues import pickle diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index f5eb0a6abd86..9de36c39d7b5 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import multiprocessing import os import threading diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 7c5a0896b436..408a3908cf45 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import multiprocessing import multiprocessing.connection @@ -277,5 +278,5 @@ def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"): "To use a different start_method use:\n\t\t" " torch.multiprocessing.start_processes(...)" ) - warnings.warn(msg, FutureWarning) + warnings.warn(msg, FutureWarning, stacklevel=2) return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index ea1cce595011..0a12e14e1aff 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5cc6b1c75d7a..66d25eacc7ad 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Tuple import torch @@ -118,8 +119,8 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): self._metadata_cache = kwargs.get("_metadata_cache") or {} # collapsed ragged dim must always be dynamic - torch._dynamo.mark_dynamic(self, self._ragged_idx) - torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1) + torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) + torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) def values(self): # dispatch to get proper view relationship diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index d448628b7cad..f900a9a9ab01 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import operator @@ -616,16 +617,23 @@ def unbind_int(func, *args, **kwargs): values = inp.values() offsets = inp.offsets() lengths = inp.lengths() + ragged_idx = inp._ragged_idx - if inp._ragged_idx != 1: + if lengths is None: + return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1)) + + if ragged_idx <= 0: raise RuntimeError( - "unbind(): only supported for NestedTensor when jagged dimension is 1" + "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" ) - - if lengths is None: - return torch.split(values, offsets.diff().tolist()) + for i in 
range(lengths.shape[0]): + if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]: + raise RuntimeError( + "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension" + ) return [ - values[offsets[i] : (offsets[i] + lengths[i])] for i in range(lengths.shape[0]) + torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i]) + for i in range(lengths.shape[0]) ] diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index c393fb1bf357..b7c69c905e9a 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging from typing import Optional, Tuple diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 3d317b7c09f2..23447d484409 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .modules import * # noqa: F403 from .parameter import ( Parameter as Parameter, diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 039d76a32f4b..6bf1ffb68e69 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib from typing import List, Union diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index bd999ec39118..06ddd7c3dc2f 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module implements the user facing API for flex_attention in PyTorch.""" import functools from typing import Callable @@ -101,11 +102,6 @@ def score_mod( # Some basic input validation _validate_sdpa_input(query, key, value) - # This will restriction will be removed in newer version of the kernel - if query.size(-2) != key.size(-2): - raise ValueError( - "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." 
- ) if query.size(-2) % 128 != 0: raise ValueError("NYI: S and L must be a multiple of 128") diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 6662eb58f361..9785f74c6683 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Defines utilities for interacting with scaled_dot_product_attention""" import math from typing import List, Optional diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index d54ed8915789..773ed38f82e8 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Defines bias subclasses that work with scaled_dot_product_attention""" from enum import auto, IntEnum from typing import Optional @@ -249,7 +250,6 @@ def _dispatch( custom_mask_type=int(attn_mask.variant), compute_log_sumexp=compute_log_sumexp, scale=scale, - causal_diagonal=None, seqlen_k=None, )[0].transpose(1, 2) else: diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 5250b4bff167..3cb0f3ff57e2 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # this is for historical pickle deserialization, it is not used otherwise def _get_thnn_function_backend(): diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index a08c7b314100..98a61bfb7c42 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functionality for Python <-> C++ frontend inter-op.""" from torch import nn diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 805e0b40cdd2..f67e2ddee04a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3827,6 +3827,7 @@ def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners= warnings.warn( "`nn.functional.upsample` is deprecated. " "Use `nn.functional.interpolate` instead.", + stacklevel=2, ) return interpolate(input, size, scale_factor, mode, align_corners) @@ -4150,6 +4151,7 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 warnings.warn( "`nn.functional.upsample_nearest` is deprecated. " "Use `nn.functional.interpolate` instead.", + stacklevel=2, ) return interpolate(input, size, scale_factor, mode="nearest") @@ -4209,6 +4211,7 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 warnings.warn( "`nn.functional.upsample_bilinear` is deprecated. " "Use `nn.functional.interpolate` instead.", + stacklevel=2, ) return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) @@ -5088,8 +5091,10 @@ def forward(self, ...): A boolean mask where a value of True indicates that the element *should* take part in attention. A float mask of the same type as query, key, value that is added to the attention score. dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied - is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal - are set. + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set to :math:`\frac{1}{\sqrt{E}}`. 
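As a quick illustration of the two new accelerator helpers introduced in the torch/mps and torch/mtia hunks above (`recommended_max_memory` and `set_device`), here is a minimal usage sketch. It assumes builds with the respective backends; the availability guards and the device index 0 are illustrative choices, not part of the patch.

```python
import torch

# Sketch only: exercises the two helpers added in the hunks above.
if torch.backends.mps.is_available():
    # Metal's recommended maximum working-set size for the GPU, in bytes.
    budget = torch.mps.recommended_max_memory()
    print(f"MPS recommended max working set: {budget / 2**30:.2f} GiB")

if hasattr(torch, "mtia") and torch.mtia.is_available():
    torch.mtia.set_device(0)  # negative indices are a no-op by design
    print("current MTIA device:", torch.mtia.current_device())
```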
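The `unbind_int` rewrite above drops the "jagged dimension must be 1" restriction by slicing with `torch.narrow` along `ragged_idx - 1`, and validates that each `offsets[i] + lengths[i]` stays inside the packed values. A minimal sketch of the same indexing on a hand-built values/offsets/lengths triple (hypothetical data, not the NestedTensor internals):

```python
import torch

# Three logical rows of lengths 2, 3 and 1, packed along dim 0 of `values`.
values = torch.arange(12.0).reshape(6, 2)
offsets = torch.tensor([0, 2, 5, 6])   # row i starts at offsets[i]
lengths = offsets.diff()               # tensor([2, 3, 1])
ragged_dim = 0                         # plays the role of ragged_idx - 1

# Without explicit lengths, unbinding reduces to a split over offset gaps.
rows_split = torch.split(values, lengths.tolist(), dim=ragged_dim)

# With lengths, each row is a narrow() view starting at its offset.
rows_narrow = [
    torch.narrow(values, dim=ragged_dim, start=int(offsets[i]), length=int(lengths[i]))
    for i in range(lengths.shape[0])
]

assert all(torch.equal(a, b) for a, b in zip(rows_split, rows_narrow))
```

When `lengths` is shorter than the gap between consecutive offsets, only the narrow-based path applies, which is why the hunk keeps the bounds check against `values.shape[ragged_idx - 1]`.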
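The docstring change just above pins down what `is_causal` means: a lower-triangular mask when the attention matrix is square, and an upper-left-aligned causal bias (see `torch.nn.attention.bias.CausalBias`) when it is not. A small square-case sketch, assuming CPU-sized tensors with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

B, H, L, S, E = 2, 4, 8, 8, 16        # square case: L == S
q = torch.randn(B, H, L, E)
k = torch.randn(B, H, S, E)
v = torch.randn(B, H, S, E)

# is_causal=True in the square case...
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# ...matches an explicit lower-triangular boolean mask (True = attend).
mask = torch.ones(L, S).tril().bool()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(out_causal, out_masked, atol=1e-5))  # True, up to kernel numerics
```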
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 5bb847a0a727..9dec24809e24 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( Any, Callable, diff --git a/torch/nn/grad.py b/torch/nn/grad.py index 660c87fb4133..dbd38fcdd38c 100644 --- a/torch/nn/grad.py +++ b/torch/nn/grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Gradient interface.""" import torch diff --git a/torch/nn/init.py b/torch/nn/init.py index f5be081e7dd0..b3179abb4937 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file contains utilities for initializing neural network parameters.""" import math import warnings diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 669448ce4fda..0e19faa99e5c 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.distributed as dist diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 0d8911893011..3d8b65175956 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from typing import Optional, Tuple @@ -222,12 +223,14 @@ def __init__( warnings.warn( "keyword argument `min_value` is deprecated and rename to `min_val`", FutureWarning, + stacklevel=2, ) min_val = min_value if max_value is not None: warnings.warn( "keyword argument `max_value` is deprecated and rename to `max_val`", FutureWarning, + stacklevel=2, ) max_val = max_value diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 83b37696c8a7..a6c2da5f596f 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import namedtuple diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 3c48e56d5e6e..75c8b5504d46 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Any import torch diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 775a826d69cc..c82d8d7d3037 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict, abc as container_abcs from itertools import chain, islice import operator diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 4ab4c8bff9fc..fb6a1557aa71 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index eaf62d5bbeea..f1c44fd350d1 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from typing import Tuple, Union diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index ae187e98b7e6..e6a3e1c0a3a1 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from torch import Tensor diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index c4b7459c4acd..f4be1b7db706 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,3 +1,4 @@ +# mypy: 
allow-untyped-defs import itertools from typing import Protocol, Optional, Type, Any diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 720c1ca01c15..be2739462399 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math from typing import Any @@ -47,7 +48,7 @@ def forward(self, input: Tensor) -> Tensor: class Linear(Module): - r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`. This module supports :ref:`TensorFloat32`. diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index fb7172e9ae54..497da8218506 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .distance import PairwiseDistance from .module import Module from .. import functional as F @@ -167,8 +168,8 @@ class NLLLoss(_WeightedLoss): the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - Shape: - - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or + Shape:: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - Target: :math:`(N)` or :math:`()`, where each value is @@ -181,27 +182,29 @@ class NLLLoss(_WeightedLoss): Examples:: - >>> m = nn.LogSoftmax(dim=1) - >>> loss = nn.NLLLoss() - >>> # input is of size N x C = 3 x 5 + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> loss_fn = nn.NLLLoss() + >>> # input to NLLLoss is of size N x C = 3 x 5 >>> input = torch.randn(3, 5, requires_grad=True) - >>> # each element in target has to have 0 <= value < C + >>> # each element in target must have 0 <= value < C >>> target = torch.tensor([1, 0, 4]) - >>> output = loss(m(input), target) - >>> output.backward() + >>> loss = loss_fn(log_softmax(input), target) + >>> loss.backward() >>> >>> >>> # 2D loss example (used, for example, with image inputs) >>> N, C = 5, 4 - >>> loss = nn.NLLLoss() - >>> # input is of size N x C x height x width + >>> loss_fn = nn.NLLLoss() >>> data = torch.randn(N, 16, 10, 10) >>> conv = nn.Conv2d(16, C, (3, 3)) - >>> m = nn.LogSoftmax(dim=1) - >>> # each element in target has to have 0 <= value < C + >>> log_softmax = nn.LogSoftmax(dim=1) + >>> # output of conv forward is of shape [N, C, 8, 8] + >>> output = log_softmax(conv(data)) + >>> # each element in target must have 0 <= value < C >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) - >>> output = loss(m(conv(data)), target) - >>> output.backward() + >>> # input to NLLLoss is of size N x C x height (8) x width (8) + >>> loss = loss_fn(output, target) + >>> loss.backward() """ __constants__ = ['ignore_index', 'reduction'] ignore_index: int diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 58129acd48a3..f803d3f02a17 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict, namedtuple import itertools import warnings @@ -348,7 +349,7 @@ def _forward_unimplemented(self, *input: Any) -> None: instead of this since the former takes care of running the registered hooks while the latter silently ignores them. 
""" - raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function") + raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "forward" function') class Module: @@ -1335,22 +1336,28 @@ def _get_backward_pre_hooks(self): def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): if not isinstance(result, torch.Tensor): if not (isinstance(result, tuple) and all(isinstance(r, torch.Tensor) for r in result)): - warnings.warn("Using non-full backward hooks on a Module that does not return a " - "single Tensor or a tuple of Tensors is deprecated and will be removed " - "in future versions. This hook will be missing some of the grad_output. " - "Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + warnings.warn( + "Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) return else: result = (result,) if not isinstance(inputs, torch.Tensor): if not (isinstance(inputs, tuple) and all(isinstance(i, torch.Tensor) for i in inputs)): - warnings.warn("Using non-full backward hooks on a Module that does not take as input a " - "single Tensor or a tuple of Tensors is deprecated and will be removed " - "in future versions. This hook will be missing some of the grad_input. " - "Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + warnings.warn( + "Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) return else: inputs = (inputs,) @@ -1358,15 +1365,21 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): # At this point we are sure that inputs and result are tuple of Tensors out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): - warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output.", - FutureWarning) + warnings.warn( + "Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.", + FutureWarning, + stacklevel=2, + ) elif len(out_grad_fn) > 1: - warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_output. Please use register_full_backward_hook to get the documented behavior.", - FutureWarning) + warnings.warn( + "Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. 
Please use register_full_backward_hook to get the documented behavior.", + FutureWarning, + stacklevel=2, + ) else: # At this point the grad_output part of the hook will most likely be correct inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} @@ -1374,11 +1387,14 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): next_functions = {n[0] for n in grad_fn.next_functions} if inputs_grad_fn != next_functions: - warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " - "is deprecated and will be removed in future versions. This hook will be missing " - "some grad_input. Please use register_full_backward_hook to get the documented " - "behavior.", - FutureWarning) + warnings.warn( + "Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.", + FutureWarning, + stacklevel=2, + ) def register_forward_pre_hook( self, @@ -1892,19 +1908,20 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): """ # TODO: Remove `args` and the parsing logic when BC allows. if len(args) > 0: - if destination is None: - destination = args[0] - if len(args) > 1 and prefix == '': - prefix = args[1] - if len(args) > 2 and keep_vars is False: - keep_vars = args[2] # DeprecationWarning is ignored by default warnings.warn( "Positional args are being deprecated, use kwargs instead. Refer to " "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" " for details.", FutureWarning, + stacklevel=2, ) + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == '': + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] if destination is None: destination = OrderedDict() diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 97c9c307c5d9..d503409d53a1 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import numbers from torch.nn.parameter import Parameter diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 0aecca58c305..4b29fbf1c8f4 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from .utils import _pair, _quadruple, _ntuple from .. import functional as F diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 6050b7eaea60..e6136350b3a4 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -16,7 +16,7 @@ class PixelShuffle(Module): See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ - by Shi et. al (2016) for more details. + by Shi et al. (2016) for more details. Args: upscale_factor (int): factor to increase spatial resolution by @@ -69,7 +69,7 @@ class PixelUnshuffle(Module): See the paper: `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ - by Shi et. al (2016) for more details. + by Shi et al. (2016) for more details. 
Args: downscale_factor (int): factor to decrease spatial resolution by diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 38acd9fb430a..61ce56390981 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -380,10 +380,10 @@ class MaxUnpool2d(_MaxUnpoolNd): [ 0., 0., 0., 0.], [ 0., 14., 0., 16.]]]]) >>> # Now using output_size to resolve an ambiguous size for the inverse - >>> input = torch.torch.tensor([[[[ 1., 2., 3., 4., 5.], - [ 6., 7., 8., 9., 10.], - [11., 12., 13., 14., 15.], - [16., 17., 18., 19., 20.]]]]) + >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.], + [ 6., 7., 8., 9., 10.], + [11., 12., 13., 14., 15.], + [16., 17., 18., 19., 20.]]]]) >>> output, indices = pool(input) >>> # This call will not work without specifying output_size >>> unpool(output, indices, output_size=input.size()) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index b4bdd7824474..8ba4f9f08319 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import warnings import numbers diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index f053a0c8f3c2..512b17d03222 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 3c9a8547df32..f5980cd6b1e8 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy from typing import Optional, Any, Union, Callable diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index da9b23add18d..7d674da0d5c3 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .module import Module from .. import functional as F diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 019dabe3e533..4a051ed1eba5 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections from itertools import repeat from typing import List, Dict, Any diff --git a/torch/nn/parallel/__init__.py b/torch/nn/parallel/__init__.py index adcd6bd838eb..8f08e5099d8b 100644 --- a/torch/nn/parallel/__init__.py +++ b/torch/nn/parallel/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing_extensions import deprecated from .parallel_apply import parallel_apply diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 2e090f123c34..b907de4004b1 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings import torch from torch.cuda import nccl @@ -228,6 +229,7 @@ def gather(tensors, dim=0, destination=None, *, out=None): 'Using -1 to represent CPU tensor is deprecated. 
Please use a ' 'device object or string instead, e.g., "cpu".', FutureWarning, + stacklevel=2, ) destination = _get_device_index(destination, allow_cpu=True, optional=True) return torch._C._gather(tensors, dim, destination) diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 4471cee6f379..3980706a932a 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import operator import torch import warnings diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 5f2013664f56..80ed52d9a0b6 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import inspect @@ -548,7 +549,8 @@ class DistributedDataParallel(Module, Joinable): multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. :attr:`bucket_cap_mb` controls the bucket size in - MegaBytes (MB). (default: 25) + MebiBytes (MiB). If ``None``, a default size of 25 MiB + will be used. (default: ``None``) find_unused_parameters (bool): Traverse the autograd graph from all tensors contained in the return value of the wrapped module's ``forward`` function. Parameters @@ -631,7 +633,7 @@ def __init__( dim=0, broadcast_buffers=True, process_group=None, - bucket_cap_mb=25, + bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, @@ -773,6 +775,7 @@ def __init__( "The `check_reduction` argument in `DistributedDataParallel` " "module is deprecated. Please avoid using it.", FutureWarning, + stacklevel=2, ) # Check that a module does not have Uninitialized parameters @@ -787,14 +790,21 @@ def __init__( self.broadcast_bucket_size = int(250 * 1024 * 1024) # reduction bucket size + if bucket_cap_mb is None: + # default case (bucket cap is 25 MiB) + bucket_cap_mb = 25 + self.bucket_bytes_cap_default = True + else: + self.bucket_bytes_cap_default = False self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + # Whether to perform input tensor CPU to GPU copies on a side-stream self.use_side_stream_for_tensor_copies = ( os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1" ) # Initialize gradient buffers and register all reduce hook - self._delay_grad_buffer = None + self._delay_grad_buffer: Optional[torch.Tensor] = None self._delay_grad_views: List[torch.Tensor] = [] self._delay_all_reduce_all_params = False if len(self._delay_all_reduce_params) != 0: @@ -1155,10 +1165,13 @@ def _ddp_init_helper( if static_graph is True or self.find_unused_parameters is False: bucket_size_limits = [sys.maxsize] else: - bucket_size_limits = [ - dist._DEFAULT_FIRST_BUCKET_BYTES, - self.bucket_bytes_cap, - ] + if self.bucket_bytes_cap_default: + bucket_size_limits = [ + dist._DEFAULT_FIRST_BUCKET_BYTES, + self.bucket_bytes_cap, + ] + else: + bucket_size_limits = [self.bucket_bytes_cap] ( bucket_indices, per_bucket_size_limits, @@ -1194,7 +1207,9 @@ def _ddp_init_helper( param_to_name_mapping, # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first # bucket. 
- dist._DEFAULT_FIRST_BUCKET_BYTES, + dist._DEFAULT_FIRST_BUCKET_BYTES + if self.bucket_bytes_cap_default + else self.bucket_bytes_cap, ) self.logger = dist.Logger(self.reducer) @@ -1467,7 +1482,7 @@ def _lazy_init(self): def _should_disable_cpp_reducer(self) -> bool: return self._use_python_reducer and ( - torch.compiler.is_compiling() or self._force_to_disable_cpp_reducer + torch._utils.is_compiling() or self._force_to_disable_cpp_reducer ) def _pre_forward(self, *inputs, **kwargs): @@ -1480,7 +1495,7 @@ def _pre_forward(self, *inputs, **kwargs): h.remove() self._accum_grad_hooks.clear() - if not self._lazy_init_ran and not torch.compiler.is_compiling(): + if not self._lazy_init_ran and not torch._utils.is_compiling(): self._lazy_init() if self._delay_all_reduce_all_params: diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index 016a6fbd0c40..fbe12d23ee8b 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -7,8 +7,8 @@ from collections import OrderedDict if TYPE_CHECKING: - import torch.jit - import torch.jit._state + from torch.jit import ScriptModule + from torch.jit._state import EnabledProxy __all__ = ['replicate'] @@ -22,12 +22,12 @@ def _is_script_method(module: Module) -> bool: return isinstance(module, torch._C.ScriptMethod) -def _init_script_module() -> "torch.jit.ScriptModule": +def _init_script_module() -> "ScriptModule": import torch.jit return torch.jit.ScriptModule() -def _is_jit_enabled() -> "torch.jit._state.EnabledProxy": +def _is_jit_enabled() -> "EnabledProxy": import torch.jit._state return torch.jit._state._enabled diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index f6fb9d47ecbf..73e753760e72 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload from typing_extensions import deprecated diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 219bb6d4efa2..221ffacc3520 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins from typing import Optional, Tuple diff --git a/torch/nn/utils/_deprecation_utils.py b/torch/nn/utils/_deprecation_utils.py index 1b2a9b6e29f2..9910db96e66c 100644 --- a/torch/nn/utils/_deprecation_utils.py +++ b/torch/nn/utils/_deprecation_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Callable import importlib import warnings diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index c10ccb90ae92..147346796d1f 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index b675e3b892bd..2836809d40be 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py index c7956a3a1b1f..593fa9e5eed7 100644 --- 
a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from .expanded_weights_impl import implements_per_sample_grads diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 94e6041c6de5..664e65cc7d90 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager import torch diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index 249dbe591204..840be6a163f5 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py index fe29b1eafbe2..6e2919803e4f 100644 --- a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import reduce import operator import torch diff --git a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py index f3e68b940660..1d0f40c54081 100644 --- a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from functools import partial import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py index f2ead2d4c08f..b18c284cd7cf 100644 --- a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F diff --git a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py index c2cbae63f336..6a80c1dc9219 100644 --- a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch import torch.nn.functional as F from .expanded_weights_impl import implements_per_sample_grads diff --git a/torch/nn/utils/_per_sample_grad.py b/torch/nn/utils/_per_sample_grad.py index 0644ab5d2535..a64942083f0c 100644 --- a/torch/nn/utils/_per_sample_grad.py +++ b/torch/nn/utils/_per_sample_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import torch diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 4ac8a4e7445b..cc83353909f9 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from typing import Union, Iterable, List, Dict, Tuple, Optional, cast from typing_extensions import deprecated diff --git a/torch/nn/utils/init.py b/torch/nn/utils/init.py index 416ad0db8ef7..4768d3009005 100644 --- 
a/torch/nn/utils/init.py +++ b/torch/nn/utils/init.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import torch diff --git a/torch/nn/utils/memory_format.py b/torch/nn/utils/memory_format.py index c8fc22bea51c..aaa2b6bfb198 100644 --- a/torch/nn/utils/memory_format.py +++ b/torch/nn/utils/memory_format.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index f9b25bcac0cb..cf686504072f 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from enum import Enum, auto import torch diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index f512b7c3b22a..b828c1d230f1 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.__future__ import get_swap_module_params_on_conversion from torch.nn.modules.container import ModuleList, ModuleDict, Module diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 0375106d69e0..b0e1f99a6c1f 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Pruning methods.""" import numbers from abc import ABC, abstractmethod diff --git a/torch/nn/utils/rnn.pyi b/torch/nn/utils/rnn.pyi index fed87febe2a6..9ffc650714ff 100644 --- a/torch/nn/utils/rnn.pyi +++ b/torch/nn/utils/rnn.pyi @@ -1,14 +1,6 @@ -from typing import ( - Any, - Iterable, - NamedTuple, - Optional, - overload, - Sequence, - Tuple, - TypeVar, - Union, -) +# mypy: allow-untyped-defs +from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union + from typing_extensions import Self from torch import Tensor @@ -24,8 +16,6 @@ class PackedSequence_(NamedTuple): def bind(optional: Any, fn: Any): ... -_T = TypeVar("_T") - class PackedSequence(PackedSequence_): def __new__( cls, @@ -34,39 +24,39 @@ class PackedSequence(PackedSequence_): sorted_indices: Optional[Tensor] = ..., unsorted_indices: Optional[Tensor] = ..., ) -> Self: ... - def pin_memory(self: _T) -> _T: ... - def cuda(self: _T, *args: Any, **kwargs: Any) -> _T: ... - def cpu(self: _T) -> _T: ... - def double(self: _T) -> _T: ... - def float(self: _T) -> _T: ... - def half(self: _T) -> _T: ... - def long(self: _T) -> _T: ... - def int(self: _T) -> _T: ... - def short(self: _T) -> _T: ... - def char(self: _T) -> _T: ... - def byte(self: _T) -> _T: ... + def pin_memory(self: Self) -> Self: ... + def cuda(self: Self, *args: Any, **kwargs: Any) -> Self: ... + def cpu(self: Self) -> Self: ... + def double(self: Self) -> Self: ... + def float(self: Self) -> Self: ... + def half(self: Self) -> Self: ... + def long(self: Self) -> Self: ... + def int(self: Self) -> Self: ... + def short(self: Self) -> Self: ... + def char(self: Self) -> Self: ... + def byte(self: Self) -> Self: ... @overload def to( - self: _T, + self: Self, dtype: _dtype, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @overload def to( - self: _T, + self: Self, device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @overload def to( - self: _T, + self: Self, other: Tensor, non_blocking: bool = False, copy: bool = False, - ) -> _T: ... + ) -> Self: ... @property def is_cuda(self) -> bool: ... def is_pinned(self) -> bool: ... 
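To make the DistributedDataParallel change above concrete: `bucket_cap_mb=None` now distinguishes "user kept the default (25 MiB)" from "user explicitly passed 25", and only the default case keeps the small `dist._DEFAULT_FIRST_BUCKET_BYTES` first bucket. The following standalone sketch mirrors that selection logic; the 1 MiB first-bucket constant is an assumption standing in for `dist._DEFAULT_FIRST_BUCKET_BYTES`, not a value taken from the patch.

```python
import sys
from typing import List, Optional

FIRST_BUCKET_BYTES = 1024 * 1024  # assumed stand-in for dist._DEFAULT_FIRST_BUCKET_BYTES


def bucket_size_limits(
    bucket_cap_mb: Optional[float],
    find_unused_parameters: bool = False,
    static_graph: bool = False,
) -> List[int]:
    """Sketch of the limit selection in the hunk above (not the real reducer)."""
    default_cap = bucket_cap_mb is None
    cap_bytes = int((25 if default_cap else bucket_cap_mb) * 1024 * 1024)

    # Static graphs, or graphs with no unused parameters, rebuild buckets after
    # the first iteration, so the initial pass uses one effectively unbounded bucket.
    if static_graph or not find_unused_parameters:
        return [sys.maxsize]

    if default_cap:
        # Default cap: keep a small first bucket so allreduce can start early.
        return [FIRST_BUCKET_BYTES, cap_bytes]
    # Explicit cap: honor the user's value for every bucket, including the first.
    return [cap_bytes]


print(bucket_size_limits(None, find_unused_parameters=True))   # [1048576, 26214400]
print(bucket_size_limits(40.0, find_unused_parameters=True))   # [41943040]
```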
diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index bda54b9a1222..fcc4bbf5fe29 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Spectral Normalization from https://arxiv.org/abs/1802.05957.""" import torch from torch.nn.functional import normalize diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 660a1a484ebb..07b03c04a120 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from collections import defaultdict from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union diff --git a/torch/nn/utils/weight_norm.py b/torch/nn/utils/weight_norm.py index 6cfe4b3e526d..abb21a7b4672 100644 --- a/torch/nn/utils/weight_norm.py +++ b/torch/nn/utils/weight_norm.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Weight Normalization from https://arxiv.org/abs/1602.07868.""" from torch.nn.parameter import Parameter, UninitializedParameter from torch import _weight_norm, norm_except_dim diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 4d16ef09c8b3..2b2f2bdae0de 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch import _C from torch._C import _onnx as _C_onnx from torch._C._onnx import ( diff --git a/torch/onnx/_deprecation.py b/torch/onnx/_deprecation.py index 0fd2cd764fc9..1f78dd55bd5d 100644 --- a/torch/onnx/_deprecation.py +++ b/torch/onnx/_deprecation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utility for deprecating functions.""" import functools diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py index f827d12be7fb..22c05075dba8 100644 --- a/torch/onnx/_globals.py +++ b/torch/onnx/_globals.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Globals used internally by the ONNX exporter. Do not use this module outside of `torch.onnx` and its tests. diff --git a/torch/onnx/_internal/_beartype.py b/torch/onnx/_internal/_beartype.py index 25e1c1cb7299..1e5006fb56c1 100644 --- a/torch/onnx/_internal/_beartype.py +++ b/torch/onnx/_internal/_beartype.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """An internal wrapper for the beartype library. The module returns a no-op decorator when the beartype library is not installed. diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 09079d5e9c4a..e5b22b07539c 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`.""" from __future__ import annotations diff --git a/torch/onnx/_internal/diagnostics/_rules.py b/torch/onnx/_internal/diagnostics/_rules.py index 0bfda96c5bce..3b2ca727d0d1 100644 --- a/torch/onnx/_internal/diagnostics/_rules.py +++ b/torch/onnx/_internal/diagnostics/_rules.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ GENERATED CODE - DO NOT EDIT DIRECTLY This file is generated by gen_diagnostics.py. 
diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index c118f3e5ae14..e51c99a3151b 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file defines an additional layer of abstraction on top of the SARIF OM.""" from __future__ import annotations diff --git a/torch/onnx/_internal/diagnostics/infra/context.py b/torch/onnx/_internal/diagnostics/infra/context.py index 22370850df86..f670adc2cae2 100644 --- a/torch/onnx/_internal/diagnostics/infra/context.py +++ b/torch/onnx/_internal/diagnostics/infra/context.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """A diagnostic context based on SARIF.""" from __future__ import annotations @@ -21,6 +22,8 @@ TypeVar, ) +from typing_extensions import Self + from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version @@ -92,24 +95,24 @@ def sarif(self) -> sarif.Result: ) return sarif_result - def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic: + def with_location(self: Self, location: infra.Location) -> Self: """Adds a location to the diagnostic.""" self.locations.append(location) return self def with_thread_flow_location( - self: _Diagnostic, location: infra.ThreadFlowLocation - ) -> _Diagnostic: + self: Self, location: infra.ThreadFlowLocation + ) -> Self: """Adds a thread flow location to the diagnostic.""" self.thread_flow_locations.append(location) return self - def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic: + def with_stack(self: Self, stack: infra.Stack) -> Self: """Adds a stack to the diagnostic.""" self.stacks.append(stack) return self - def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic: + def with_graph(self: Self, graph: infra.Graph) -> Self: """Adds a graph to the diagnostic.""" self.graphs.append(graph) return self diff --git a/torch/onnx/_internal/diagnostics/infra/decorator.py b/torch/onnx/_internal/diagnostics/infra/decorator.py index 0ac803815703..67066f5da500 100644 --- a/torch/onnx/_internal/diagnostics/infra/decorator.py +++ b/torch/onnx/_internal/diagnostics/infra/decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index 7eefc5a917b0..ac62de0214dd 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions) annotations, ) @@ -1276,6 +1277,12 @@ def export(self) -> ONNXProgram: "ONNXScript optimizer is not available. Skipping optimization. " "Please `pip install onnxscript -U` to enable post-export optimization." ) + except Exception as e: + warnings.warn( + "ONNXScript optimizer failed. Skipping optimization. 
" + "\n\nPLEASE REPORT A BUG AT https://github.com/microsoft/onnxscript/issues " + f"\n\nDetail:\n{e}" + ) return torch.onnx.ONNXProgram( onnx_model, diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index 69fa023b9add..cef8e045f7fb 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import abc diff --git a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py index 5da0dbed3d91..deec2a85e1da 100644 --- a/torch/onnx/_internal/fx/analysis/unsupported_nodes.py +++ b/torch/onnx/_internal/fx/analysis/unsupported_nodes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py index 7fb971a3307a..646e0765f190 100644 --- a/torch/onnx/_internal/fx/decomposition_skip.py +++ b/torch/onnx/_internal/fx/decomposition_skip.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """A context manager that disables the decomposition of certain ops during dynamo tracing. The approach is to temporarily hijack the operator callable with PT2 custom operator. diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py index 4f3f705ca867..027d580717af 100644 --- a/torch/onnx/_internal/fx/decomposition_table.py +++ b/torch/onnx/_internal/fx/decomposition_table.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Dispatcher for AtenLib functions from onnx-script.""" from __future__ import annotations @@ -111,4 +112,12 @@ def create_onnx_friendly_decomposition_table( ): continue decomposition_table[op_overload] = decomp_fn + + # NOTE: There are ops in core ATen and under torch._refs, + # that are not decomposed to prim::ops. We need to pick them + # back + for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): + if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: + continue + decomposition_table[op_overload] = decomp_fn return decomposition_table diff --git a/torch/onnx/_internal/fx/diagnostics.py b/torch/onnx/_internal/fx/diagnostics.py index 11e4c79f2e1a..0be358751c11 100644 --- a/torch/onnx/_internal/fx/diagnostics.py +++ b/torch/onnx/_internal/fx/diagnostics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index fbc7d92e043f..1379a0613895 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: This file is referenced by name at # /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. # introduced by https://github.com/pytorch/pytorch/pull/98894. 
diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index 50ead7556f37..a0be86e11d6b 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py index 18dc84e19585..1d7d191cbd25 100644 --- a/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py +++ b/torch/onnx/_internal/fx/fx_symbolic_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py index 2986ac279ec3..3886733093a3 100644 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Dispatcher for AtenLib functions from onnx-script.""" from __future__ import annotations diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py index b306bc2141de..01161aee25ea 100644 --- a/torch/onnx/_internal/fx/op_validation.py +++ b/torch/onnx/_internal/fx/op_validation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Module for handling op-level validation during exporting.""" from __future__ import annotations diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py index 92a883469a52..6e49bccfcfaf 100644 --- a/torch/onnx/_internal/fx/passes/_utils.py +++ b/torch/onnx/_internal/fx/passes/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Common utility functions for FX passes. These functions should NOT be directly invoked outside of `passes` package. 
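Context for the `torch/onnx/_internal/diagnostics/infra/context.py` hunk a few files above: the chainable `with_*` methods now return `typing_extensions.Self` instead of the hand-rolled `_Diagnostic` TypeVar, so subclasses keep their own static type through the chain. A generic sketch of the pattern with hypothetical classes (not the real SARIF diagnostic types):

```python
from dataclasses import dataclass, field
from typing import List

from typing_extensions import Self  # typing.Self on Python >= 3.11


@dataclass
class Diagnostic:
    locations: List[str] = field(default_factory=list)

    def with_location(self, location: str) -> Self:
        """Chainable: the return type is the receiver's own (sub)class."""
        self.locations.append(location)
        return self


@dataclass
class ExportDiagnostic(Diagnostic):
    graph_name: str = ""


# A type checker now infers ExportDiagnostic here, not the base class,
# which previously required threading a TypeVar through every method.
diag = ExportDiagnostic(graph_name="main").with_location("model.py:42")
print(type(diag).__name__, diag.locations)
```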
diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py index b9a131b97466..5185b1152485 100644 --- a/torch/onnx/_internal/fx/passes/decomp.py +++ b/torch/onnx/_internal/fx/passes/decomp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py index 21f2691cbb8e..dfdee6e88c85 100644 --- a/torch/onnx/_internal/fx/passes/functionalization.py +++ b/torch/onnx/_internal/fx/passes/functionalization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import contextlib diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py index b7c3b90cab66..6e1352f73046 100644 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import abc diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py index 64887ad2ee6e..2b3518b79ea6 100644 --- a/torch/onnx/_internal/fx/passes/readability.py +++ b/torch/onnx/_internal/fx/passes/readability.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import Dict, List, Sequence, Tuple, Union diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 944cad4acf1c..bc584ff32925 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: onnx"] from __future__ import annotations diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py index 66ca69d7a70f..cd77b6eec18b 100644 --- a/torch/onnx/_internal/fx/passes/virtualization.py +++ b/torch/onnx/_internal/fx/passes/virtualization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations from typing import List, Optional, Tuple diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py index ee919eae00d1..dbd8fb591126 100644 --- a/torch/onnx/_internal/fx/patcher.py +++ b/torch/onnx/_internal/fx/patcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import functools import io diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py index 726bf4219330..5739442163a3 100644 --- a/torch/onnx/_internal/fx/serialization.py +++ b/torch/onnx/_internal/fx/serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import io diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py index fb3f0e99a6d6..a825e466f1aa 100644 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ b/torch/onnx/_internal/fx/torch_export_graph_extractor.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE: This file is referenced by name at # /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. # introduced by https://github.com/pytorch/pytorch/pull/98894. 
diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index b7f3d6cea642..3aac02a51214 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for converting and operating on ONNX, JIT and torch types.""" from __future__ import annotations @@ -22,7 +23,7 @@ from torch._subclasses import fake_tensor if TYPE_CHECKING: - import onnx.defs.OpSchema.AttrType # type: ignore[import] + import onnx.defs.OpSchema.AttrType # type: ignore[import] # noqa: TCH004 # Enable both TorchScriptTensor and torch.Tensor to be tested diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 2f8c9202d7bb..12100d0f489c 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import inspect diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py index 719f4b0c16e8..13ae4209da5d 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/jit_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for manipulating the torch.Graph object and the torchscript.""" from __future__ import annotations diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 278af3feacc6..40eb1bd8d64e 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" from __future__ import annotations diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index aa3495ee5ac5..d8a7e55e8f9e 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import dataclasses import importlib import logging diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/registration.py index 017a2fb7dadf..f051708f864d 100644 --- a/torch/onnx/_internal/registration.py +++ b/torch/onnx/_internal/registration.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Module for handling symbolic function registration.""" import warnings @@ -61,7 +62,7 @@ def _dispatch_opset_version( _V = TypeVar("_V") -class OverrideDict(Generic[_K, _V], Collection[_K]): +class OverrideDict(Collection[_K], Generic[_K, _V]): """A dictionary that merges built-in and custom symbolic functions. 
It supports overriding and un-overriding built-in symbolic functions with custom diff --git a/torch/onnx/_onnx_supported_ops.py b/torch/onnx/_onnx_supported_ops.py index 2611b0d81e9b..e2707298d6d9 100644 --- a/torch/onnx/_onnx_supported_ops.py +++ b/torch/onnx/_onnx_supported_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from typing import Dict, List, Union diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py index d13232507317..d9b647c807f3 100644 --- a/torch/onnx/_type_utils.py +++ b/torch/onnx/_type_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Utilities for converting and operating on ONNX, JIT and torch types.""" from __future__ import annotations diff --git a/torch/onnx/operators.py b/torch/onnx/operators.py index e5f12444c355..88ac6779f91c 100644 --- a/torch/onnx/operators.py +++ b/torch/onnx/operators.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""This file provides a location for operators that help exporting models via onnx. E.g. `shape_as_tensor` and `reshape_from_tensor_shape` @@ -13,8 +14,34 @@ def shape_as_tensor(x): + """Get the shape of a tensor as a tensor. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x. + + Example: + >>> x = torch.randn(2, 3) + >>> shape_as_tensor(x) + tensor([2, 3]) + + """ return torch._shape_as_tensor(x) def reshape_from_tensor_shape(x, shape): + """Reshape a tensor to the given shape. + + This function is used to make dynamic size operations traceable when exporting models via ONNX. + This function is kept for backward-compatibility. It is implemented directly in ATen. + + Parameters: + x (Tensor): the tensor to be reshaped. + shape (Tensor): the target shape. + + Returns: + Tensor: the reshaped tensor. 
+ """ return torch._reshape_from_tensor(x, shape) diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py index 3398fcd2fe10..ed2dc6cd9fdb 100644 --- a/torch/onnx/symbolic_caffe2.py +++ b/torch/onnx/symbolic_caffe2.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import importlib import inspect diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 4430babaef00..676c3d68048b 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 6fd576822e2c..e9ba8b4015f2 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 99d5064ad7a0..e562d5a47567 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 11.""" from __future__ import annotations diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 130b02a889b0..5a6bf720df36 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import functools diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 5bba817bbce0..bb7045c0f58b 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # EDITING THIS FILE? READ THIS FIRST! # see Note [Edit Symbolic Files] in README.md diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 1b4b8ee7917c..62e05910dd72 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 14. Note [ONNX operators that are added/updated in opset 14] diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index 4f316a77f62e..793c1cad8fb9 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 15. Note [ONNX operators that are added/updated in opset 15] diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index 24306b475366..cd5829ada850 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 16. Note [ONNX Operators that are added/updated in opset 16] diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index 3aad249a1126..44c789017d75 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 17. Note [ONNX Operators that are added/updated in opset 17] @@ -26,7 +27,7 @@ # EDITING THIS FILE? READ THIS FIRST! 
# see Note [Edit Symbolic Files] in README.md -__all__ = ["layer_norm", "stft"] +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) @@ -67,6 +68,24 @@ def layer_norm( ) +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + def _compute_edge_sizes(n_fft, window_size): """Helper function to compute the sizes of the edges (left and right) of a given window centered within an FFT size.""" diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py index d80361dd417f..68e14c987731 100644 --- a/torch/onnx/symbolic_opset18.py +++ b/torch/onnx/symbolic_opset18.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 18. Note [ONNX Operators that are added/updated in opset 18] diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index 9c81bc3e3c49..9557b5f2828e 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 20. Note [ONNX Operators that are added/updated in opset 20] diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index 0537e8a92888..c647ead4e297 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Note [ONNX operators that are added/updated from opset 7 to opset 8] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index b2fbee3b9784..87b4be230e78 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Note [ONNX operators that are added/updated from opset 8 to opset 9] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 95e8fcef391f..b4c937ed3f66 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1,8 +1,10 @@ +# mypy: allow-untyped-defs """This file exports ONNX ops for opset 9. Opset 9 is supported by ONNX release 1.4.1 release on 01/23/19 """ + from __future__ import annotations import builtins @@ -343,6 +345,20 @@ def reshape_as(g: jit_utils.GraphContext, self, other): @_onnx_symbolic("aten::add") @_beartype.beartype def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. 
+ """ if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): return symbolic_helper._onnx_opset_unsupported_detailed( "Add", 9, 11, "Add between list of tensors not supported", self @@ -355,6 +371,21 @@ def add(g: jit_utils.GraphContext, self, other, alpha=None): @_onnx_symbolic("aten::sub") @_beartype.beartype def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. + + Returns: + ONNX operator + """ if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: other = g.op("Mul", other, alpha) return g.op("Sub", self, other) @@ -521,6 +552,16 @@ def reciprocal(g: jit_utils.GraphContext, self): @symbolic_helper.parse_args("v", "i") @_beartype.beartype def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ tensors = symbolic_helper._unpack_list(tensor_list) # torch.cat ignores empty tensors such as `torch.Tensor([])` # These needs to be removed as input from ONNX's concat too, otherwise shape inference @@ -746,6 +787,16 @@ def atan2(g: jit_utils.GraphContext, self, other): @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) @_beartype.beartype def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ return g.op("Sigmoid", self) @@ -849,6 +900,7 @@ def numpy_T(g: jit_utils.GraphContext, input): @symbolic_helper.quantized_args(True) @_beartype.beartype def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" size = symbolic_helper._maybe_get_const(size, "is") if not symbolic_helper._is_value(size): size = g.op("Constant", value_t=torch.LongTensor(size)) @@ -1132,6 +1184,10 @@ def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): @symbolic_helper.parse_args("v", "i", "v") @_beartype.beartype def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. 
+ """ index = symbolic_helper._maybe_get_scalar(index) if (not symbolic_helper._is_value(index)) and (index < 0): if index == -1: @@ -1417,29 +1473,39 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding): ] # ensure last pooling starts inside ceiled_output_dim = [ - ceiled_output_dim[i] - 1 - if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) - else ceiled_output_dim[i] + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) for i in range(0, len(ceiled_output_dim)) ] padding_ceil = [ - 0 - if (stride[i] == 1) - else ( - kernel_size[i] - - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)) + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) ) for i in range(0, len(padding)) ] # ensure padding is not > kernel_size padding_ceil = [ ( - int(padding_ceil[i]) - if padding_ceil[i] < kernel_size[i] - 1 - else int(kernel_size[i] - 1) + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) ) - if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) - else int(padding_ceil[i]) for i in range(0, len(padding_ceil)) ] return padding_ceil @@ -4081,6 +4147,7 @@ def alias(g: jit_utils.GraphContext, self): @symbolic_helper.parse_args("v", "i") @_beartype.beartype def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" # Handle negative dim if dim < 0: rank = symbolic_helper._get_tensor_rank(self) @@ -5580,6 +5647,10 @@ def lift(g: jit_utils.GraphContext, self): @_onnx_symbolic("aten::masked_fill") @_beartype.beartype def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) value = symbolic_helper._maybe_get_scalar(value) return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f5206d425b4d..94a57786a4bd 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functions to export models into the ONNX IR format. These models can be loaded with the ONNX library and then @@ -186,11 +187,10 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx) -@_beartype.beartype def export( model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], args: Union[Tuple[Any, ...], torch.Tensor], - f: Union[str, io.BytesIO], + f: Optional[Union[str, io.BytesIO]] = None, export_params: bool = True, verbose: bool = False, training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, @@ -206,7 +206,8 @@ def export( custom_opsets: Optional[Mapping[str, int]] = None, export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False, autograd_inlining: Optional[bool] = True, -) -> None: + dynamo: bool = False, +) -> Optional[torch.onnx.ONNXProgram]: r"""Exports a model into ONNX format. 
If ``model`` is not a :class:`torch.jit.ScriptModule` nor a @@ -500,6 +501,8 @@ def forward(self, x): autograd_inlining (bool, default True): Flag used to control whether to inline autograd functions. Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + dynamo (bool, default False): Whether to export the model with Dynamo instead of TorchScript. + Raises: :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it @@ -508,6 +511,43 @@ def forward(self, x): All errors are subclasses of :class:`errors.OnnxExporterError`. """ + if dynamo: + # Unsupported parameters for dynamo export + # TODO: These are not supported at the moment + warnings.warn( + "f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, " + "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and " + "autograd_inlining are not supported for dynamo export at the moment." + ) + # TODO: check args normalization + args = _decide_input_format(model, args) + kwargs = {} + if args is not None and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + # TODO: refactor this when we have migrated ExportedProgram and + # need users to specify dynamic_axes + if dynamic_axes is None or not isinstance(dynamic_axes, dict): + dynamic_shapes = False + else: + dynamic_shapes = True + warnings.warn( + "Specifying dynamic axes is not supported for dynamo export at the moment." + ) + # TODO: expose more ExportOptions? + export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_shapes) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + if f is not None: + onnx_program.save(f) + return onnx_program + + if f is None: + raise ValueError( + "Export destination must be specified for torchscript-onnx export." + ) + _export( model, args, @@ -527,6 +567,8 @@ def forward(self, x): autograd_inlining=autograd_inlining, ) + return None + @_beartype.beartype def _is_constant_tensor_list(node): @@ -870,7 +912,6 @@ def _decide_input_format(model, args): warnings.warn("No input args, skipping _decide_input_format") except Exception as e: warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") - return args diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 6b49e7fc72b9..95ed873bf633 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Functions to verify exported ONNX model is functionally equivalent to original PyTorch model. ONNX Runtime is required, and is used as the ONNX backend for export verification. diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 58d9c948416b..341d07b1a2e8 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -6,21 +6,22 @@ future. """ -from .
import lr_scheduler, swa_utils -from .adadelta import Adadelta -from .adagrad import Adagrad -from .adam import Adam -from .adamax import Adamax -from .adamw import AdamW -from .asgd import ASGD -from .lbfgs import LBFGS -from .nadam import NAdam -from .optimizer import Optimizer -from .radam import RAdam -from .rmsprop import RMSprop -from .rprop import Rprop -from .sgd import SGD -from .sparse_adam import SparseAdam +from torch.optim import lr_scheduler, swa_utils +from torch.optim.adadelta import Adadelta +from torch.optim.adagrad import Adagrad +from torch.optim.adam import Adam +from torch.optim.adamax import Adamax +from torch.optim.adamw import AdamW +from torch.optim.asgd import ASGD +from torch.optim.lbfgs import LBFGS +from torch.optim.nadam import NAdam +from torch.optim.optimizer import Optimizer +from torch.optim.radam import RAdam +from torch.optim.rmsprop import RMSprop +from torch.optim.rprop import Rprop +from torch.optim.sgd import SGD +from torch.optim.sparse_adam import SparseAdam + del adadelta # type: ignore[name-defined] # noqa: F821 del adagrad # type: ignore[name-defined] # noqa: F821 @@ -36,3 +37,6 @@ del optimizer # type: ignore[name-defined] # noqa: F821 del nadam # type: ignore[name-defined] # noqa: F821 del lbfgs # type: ignore[name-defined] # noqa: F821 + + +import torch.optim._multi_tensor diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 4a6198956fb8..a307cc76846d 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Functional interface.""" import math from typing import List diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4d1a4e25319c..d6f19fb069ae 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional import torch @@ -254,7 +255,7 @@ def _single_tensor_adadelta( has_complex: bool, ): # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -310,7 +311,7 @@ def _multi_tensor_adadelta( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -413,7 +414,7 @@ def adadelta( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index a95e985b49eb..0b6dfe852d08 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 1c625682fc34..86785be4ed17 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch @@ -353,7 +354,7 @@ def _single_tensor_adam( step_t = 
state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -466,7 +467,7 @@ def _multi_tensor_adam( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -743,7 +744,7 @@ def adam( # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 005327d8bb88..27caa5f9d81c 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch @@ -243,7 +244,7 @@ def _single_tensor_adamax( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -315,7 +316,7 @@ def _multi_tensor_adamax( return # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -424,7 +425,7 @@ def adamax( See :class:`~torch.optim.Adamax` for details. """ - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 707ac17c361c..00931bed0227 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch @@ -354,7 +355,7 @@ def _single_tensor_adamw( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -467,7 +468,7 @@ def _multi_tensor_adamw( ) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -728,7 +729,7 @@ def adamw( See :class:`~torch.optim.AdamW` for details. 
""" - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 633a14832282..84c7602912d0 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple, Union import torch @@ -214,7 +215,7 @@ def _single_tensor_asgd( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type @@ -287,7 +288,7 @@ def _multi_tensor_asgd( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index e8818cca538c..480b45c84d72 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 77bdb6b46aac..4a5f162a0b20 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import types import warnings @@ -65,6 +66,24 @@ def _check_verbose_deprecated_warning(verbose): return False +def _format_param(name: str, optimizer: Optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + + def _copy(_param): + return _param.clone() if isinstance(_param, Tensor) else _param + + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError( + f"{name} must have the same length as optimizer.param_groups. " + f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}." + ) + else: + param = [param] * len(optimizer.param_groups) + + return list(map(_copy, param)) + + class LRScheduler: _get_lr_called_within_step: bool = False @@ -77,7 +96,10 @@ def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"): # Initialize epoch and base learning rates if last_epoch == -1: for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) else: for i, group in enumerate(optimizer.param_groups): if "initial_lr" not in group: @@ -265,7 +287,7 @@ class LambdaLR(LRScheduler): factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -367,7 +389,7 @@ class MultiplicativeLR(LRScheduler): factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups. last_epoch (int): The index of last epoch. Default: -1. 
- verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -466,7 +488,7 @@ class StepLR(LRScheduler): gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -525,7 +547,7 @@ class MultiStepLR(LRScheduler): gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -588,7 +610,7 @@ class ConstantLR(LRScheduler): total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor. Default: 5. last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -664,7 +686,7 @@ class LinearLR(LRScheduler): total_iters (int): The number of iterations that multiplicative factor reaches to 1. Default: 5. last_epoch (int): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -755,7 +777,7 @@ class ExponentialLR(LRScheduler): optimizer (Optimizer): Wrapped optimizer. gamma (float): Multiplicative factor of learning rate decay. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -790,7 +812,7 @@ class SequentialLR(LRScheduler): schedulers (list): List of chained schedulers. milestones (list): List of integers that reflects milestone points. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): Does nothing. + verbose (bool | str): Does nothing. .. deprecated:: 2.2 ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the @@ -924,7 +946,7 @@ class PolynomialLR(LRScheduler): optimizer (Optimizer): Wrapped optimizer. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. power (float): The power of the polynomial. Default: 1.0. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1014,7 +1036,7 @@ class CosineAnnealingLR(LRScheduler): T_max (int): Maximum number of iterations. eta_min (float): Minimum learning rate. Default: 0. last_epoch (int): The index of last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1217,7 +1239,7 @@ class ReduceLROnPlateau(LRScheduler): eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 
- verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1447,7 +1469,7 @@ class CyclicLR(LRScheduler): number of *batches* computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning. Default: -1 - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1491,16 +1513,16 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer - base_lrs = self._format_param("base_lr", optimizer, base_lr) + base_lrs = _format_param("base_lr", optimizer, base_lr) if last_epoch == -1: for lr, group in zip(base_lrs, optimizer.param_groups): if isinstance(group["lr"], Tensor): lr_val = lr.item() if isinstance(lr, Tensor) else lr - group["lr"].fill_(lr) + group["lr"].fill_(lr_val) else: group["lr"] = lr - self.max_lrs = self._format_param("max_lr", optimizer, max_lr) + self.max_lrs = _format_param("max_lr", optimizer, max_lr) step_size_up = float(step_size_up) step_size_down = ( @@ -1531,12 +1553,10 @@ def __init__( ) self.use_beta1 = "betas" in self.optimizer.defaults - self.base_momentums = self._format_param( + self.base_momentums = _format_param( "base_momentum", optimizer, base_momentum ) - self.max_momentums = self._format_param( - "max_momentum", optimizer, max_momentum - ) + self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( self.max_momentums, self.base_momentums, optimizer.param_groups @@ -1564,17 +1584,6 @@ def _init_scale_fn(self): self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma) self.scale_mode = "iterations" - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError( - f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}" - ) - return param - else: - return [param] * len(optimizer.param_groups) - def scale_fn(self, x) -> float: if self._scale_fn_custom is not None: return self._scale_fn_custom(x) @@ -1684,7 +1693,7 @@ class CosineAnnealingWarmRestarts(LRScheduler): T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1. eta_min (float, optional): Minimum learning rate. Default: 0. last_epoch (int, optional): The index of the last epoch. Default: -1. - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. deprecated:: 2.2 @@ -1888,7 +1897,7 @@ class OneCycleLR(LRScheduler): number of *batches* computed, not the total number of epochs computed. When last_epoch=-1, the schedule is started from the beginning. Default: -1 - verbose (bool): If ``True``, prints a message to stdout for + verbose (bool | str): If ``True``, prints a message to stdout for each update. Default: ``False``. .. 
deprecated:: 2.2 @@ -2012,7 +2021,7 @@ def __init__( self._anneal_func_type = anneal_strategy # Initialize learning rate variables - max_lrs = self._format_param("max_lr", self.optimizer, max_lr) + max_lrs = _format_param("max_lr", self.optimizer, max_lr) if last_epoch == -1: for idx, group in enumerate(self.optimizer.param_groups): group["initial_lr"] = max_lrs[idx] / div_factor @@ -2030,10 +2039,8 @@ def __init__( "optimizer must support momentum or beta1 with `cycle_momentum` option enabled" ) self.use_beta1 = "betas" in self.optimizer.defaults - max_momentums = self._format_param("max_momentum", optimizer, max_momentum) - base_momentums = self._format_param( - "base_momentum", optimizer, base_momentum - ) + max_momentums = _format_param("max_momentum", optimizer, max_momentum) + base_momentums = _format_param("base_momentum", optimizer, base_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( max_momentums, base_momentums, optimizer.param_groups @@ -2047,17 +2054,6 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) - def _format_param(self, name, optimizer, param): - """Return correctly formatted lr/momentum for each param group.""" - if isinstance(param, (list, tuple)): - if len(param) != len(optimizer.param_groups): - raise ValueError( - f"expected {len(optimizer.param_groups)} values for {name}, got {len(param)}" - ) - return param - else: - return [param] * len(optimizer.param_groups) - def _anneal_func(self, *args, **kwargs): if hasattr(self, "_anneal_func_type"): if self._anneal_func_type == "cos": diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 75a6f49be262..cd2eeff92c05 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch @@ -12,6 +13,7 @@ _get_capturable_supported_devices, _get_scalar_dtype, _get_value, + _maximize_doc, _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, @@ -34,6 +36,7 @@ def __init__( decoupled_weight_decay: bool = False, *, foreach: Optional[bool] = None, + maximize: bool = False, capturable: bool = False, differentiable: bool = False, ): @@ -56,6 +59,7 @@ def __init__( weight_decay=weight_decay, momentum_decay=momentum_decay, decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, foreach=foreach, capturable=capturable, differentiable=differentiable, @@ -65,6 +69,7 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: + group.setdefault("maximize", False) group.setdefault("foreach", None) group.setdefault("capturable", False) group.setdefault("differentiable", False) @@ -188,6 +193,7 @@ def step(self, closure=None): weight_decay=group["weight_decay"], momentum_decay=group["momentum_decay"], eps=group["eps"], + maximize=group["maximize"], decoupled_weight_decay=group["decoupled_weight_decay"], foreach=group["foreach"], capturable=group["capturable"], @@ -207,12 +213,15 @@ def step(self, closure=None): &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ - &\hspace{13mm} \: \textit{decoupled\_weight\_decay} \\ + &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: 
t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\ &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\ &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\ @@ -249,6 +258,7 @@ def step(self, closure=None): decoupled_weight_decay (bool, optional): whether to use decoupled weight decay as in AdamW to obtain NAdamW (default: False) {_foreach_doc} + {_maximize_doc} {_capturable_doc} {_differentiable_doc} @@ -276,12 +286,13 @@ def _single_tensor_nadam( momentum_decay: float, eps: float, decoupled_weight_decay: bool, + maximize: bool, capturable: bool, differentiable: bool, has_complex: bool, ): for i, param in enumerate(params): - grad = grads[i] + grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] mu_product = mu_products[i] @@ -294,7 +305,7 @@ def _single_tensor_nadam( exp_avg_sq = torch.view_as_real(exp_avg_sq) # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == mu_product.device.type == step_t.device.type @@ -369,6 +380,7 @@ def _multi_tensor_nadam( momentum_decay: float, eps: float, decoupled_weight_decay: bool, + maximize: bool, capturable: bool, differentiable: bool, has_complex: bool, @@ -379,7 +391,7 @@ def _multi_tensor_nadam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) @@ -406,6 +418,9 @@ def _multi_tensor_nadam( grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs ) + if maximize: + grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment] + # Update steps # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # and over. 
1 will then be wrapped into a Tensor over and over again, which is slower than if we just @@ -422,9 +437,15 @@ def _multi_tensor_nadam( # Perform stepweight decay torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: - grouped_grads = torch._foreach_add( # type: ignore[assignment] - grouped_grads, grouped_params, alpha=weight_decay - ) + # Re-use the intermediate memory (grouped_grads) already allocated for maximize + if maximize: + torch._foreach_add_( + grouped_grads, grouped_params, alpha=weight_decay + ) + else: + grouped_grads = torch._foreach_add( # type: ignore[assignment] + grouped_grads, grouped_params, alpha=weight_decay + ) # Decay the first and second moment running average coefficient torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1) @@ -560,6 +581,7 @@ def nadam( capturable: bool = False, differentiable: bool = False, has_complex: bool = False, + maximize: bool = False, *, beta1: float, beta2: float, @@ -608,6 +630,7 @@ def nadam( lr=lr, weight_decay=weight_decay, momentum_decay=momentum_decay, + maximize=maximize, decoupled_weight_decay=decoupled_weight_decay, eps=eps, capturable=capturable, diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index fc091e273c36..582dc2105a5a 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import math import warnings @@ -24,6 +25,7 @@ import torch import torch.utils.hooks as hooks +from torch._utils import is_compiling from torch.utils._foreach_utils import ( _get_foreach_kernels_supported_devices, _get_fused_kernels_supported_devices, @@ -96,14 +98,14 @@ def _use_grad(self, *args, **kwargs): def _get_value(x): # item is significantly faster than a cpu tensor in eager mode - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return x else: return x.item() if isinstance(x, torch.Tensor) else x def _stack_if_compiling(x): - if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and is_compiling(): return torch.stack(x) else: return x @@ -144,7 +146,7 @@ def wrapper(func): # the capturable flag. If capturable=True, this is not a problem. @functools.wraps(func) def maybe_fallback(*args, **kwargs): - if torch.compiler.is_compiling() and ( + if is_compiling() and ( not kwargs.get("capturable", False) and has_state_steps and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) @@ -417,7 +419,7 @@ def _cuda_graph_capture_health_check(self) -> None: # Thus, when compiling, inductor will determine if cudagraphs # can be enabled based on whether there is input mutation or CPU tensors. if ( - not torch.compiler.is_compiling() + not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available() ): @@ -504,7 +506,7 @@ def _group_tensors_by_device_and_dtype( """Groups a list of lists of tensors by device and dtype. Skips this step if we are compiling since this will occur during inductor lowering. 
""" - if torch.compiler.is_compiling(): + if is_compiling(): return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} else: return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 619f10493587..1ecf20ffde86 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import cast, List, Optional, Tuple, Union import torch @@ -271,7 +272,7 @@ def _single_tensor_radam( step_t = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step_t.device.type @@ -369,7 +370,7 @@ def _multi_tensor_radam( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( supports_xla=False ) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index bdc3ec0b8b3f..5311aa2fd6b8 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch @@ -276,7 +277,7 @@ def _single_tensor_rmsprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -349,7 +350,7 @@ def _multi_tensor_rmsprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -467,7 +468,7 @@ def rmsprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index af1854cc518a..ae34865f1c15 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional, Tuple import torch @@ -236,7 +237,7 @@ def _single_tensor_rprop( step = state_steps[i] # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert ( param.device.type == step.device.type @@ -302,7 +303,7 @@ def _multi_tensor_rprop( assert not differentiable, "_foreach ops don't support autograd" # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - 
if not torch.compiler.is_compiling() and capturable: + if not torch._utils.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices() assert all( p.device.type == step.device.type @@ -414,7 +415,7 @@ def rprop( """ # this check is slow during compilation, so we skip it # if it's strictly needed we can add this check back in dynamo - if not torch.compiler.is_compiling() and not all( + if not torch._utils.is_compiling() and not all( isinstance(t, torch.Tensor) for t in state_steps ): raise RuntimeError( diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index aa3062095c6a..8cf26cfcf95c 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Optional import torch @@ -208,7 +209,7 @@ def step(self, closure=None): .. note:: The implementation of SGD with Momentum/Nesterov subtly differs from - Sutskever et. al. and implementations in some other frameworks. + Sutskever et al. and implementations in some other frameworks. Considering the specific case of Momentum, the update can be written as @@ -221,7 +222,7 @@ def step(self, closure=None): where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the parameters, gradient, velocity, and momentum respectively. - This is in contrast to Sutskever et. al. and + This is in contrast to Sutskever et al. and other frameworks which employ an update of the form .. math:: @@ -429,7 +430,7 @@ def _multi_tensor_sgd( if not device_has_sparse_grad: # handle internal item() call if lr is a tensor - if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling(): + if isinstance(lr, torch.Tensor) and torch._utils.is_compiling(): grads_x_lr = torch._foreach_mul(device_grads, -lr) torch._foreach_add_(device_params, grads_x_lr) else: diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 88643d1a5646..adb7c17629c2 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import List, Tuple import torch diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 7c2c9cdaf6f9..440897e6041e 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import itertools import math import warnings @@ -7,7 +8,7 @@ import torch from torch import Tensor from torch.nn import Module -from torch.optim.lr_scheduler import LRScheduler +from torch.optim.lr_scheduler import _format_param, LRScheduler from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices from .optimizer import Optimizer @@ -390,7 +391,7 @@ def __init__( anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, ): - swa_lrs = self._format_param(optimizer, swa_lr) + swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups): group["swa_lr"] = swa_lr if anneal_strategy not in ["cos", "linear"]: @@ -409,22 +410,6 @@ def __init__( self.anneal_epochs = anneal_epochs super().__init__(optimizer, last_epoch) - @staticmethod - def _format_param( - optimizer: Optimizer, - swa_lrs: Union[float, List[float], Tuple[float, ...]], - ) -> Union[List[float], Tuple[float, ...]]: - if isinstance(swa_lrs, (list, tuple)): - if len(swa_lrs) != len(optimizer.param_groups): - raise ValueError( - "swa_lr must have the same length as " - f"optimizer.param_groups: swa_lr has {len(swa_lrs)}, " - f"optimizer.param_groups has {len(optimizer.param_groups)}" - ) - return swa_lrs - else: - return 
[swa_lrs] * len(optimizer.param_groups) - @staticmethod def _linear_anneal(t): return t diff --git a/torch/overrides.py b/torch/overrides.py index 6c521bc7003b..509568900983 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -357,6 +357,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor._is_any_true, Tensor._addmm_activation, Tensor.to_padded_tensor, + Tensor._use_count, } diff --git a/torch/package/_digraph.py b/torch/package/_digraph.py index f84a51398f00..8b753f7ebdc4 100644 --- a/torch/package/_digraph.py +++ b/torch/package/_digraph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import deque from typing import List, Set diff --git a/torch/package/_directory_reader.py b/torch/package/_directory_reader.py index cec5333c3e3f..77d629cccce2 100644 --- a/torch/package/_directory_reader.py +++ b/torch/package/_directory_reader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os.path from glob import glob from typing import cast diff --git a/torch/package/_importlib.py b/torch/package/_importlib.py index fd303b6141e7..9741925315e5 100644 --- a/torch/package/_importlib.py +++ b/torch/package/_importlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _warnings import os.path diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 0876d64664a2..7dcf3538631f 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Import mangling. See mangling.md for details. """ diff --git a/torch/package/_mock.py b/torch/package/_mock.py index b0bdb95cc48c..44876b1a1d3f 100644 --- a/torch/package/_mock.py +++ b/torch/package/_mock.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs _magic_methods = [ "__subclasscheck__", "__hex__", diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index cabc6a82164f..2ac59395b73b 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """isort:skip_file""" from pickle import ( # type: ignore[attr-defined] _compat_pickle, diff --git a/torch/package/_package_unpickler.py b/torch/package/_package_unpickler.py index b00210e3c191..890e6b4e03ba 100644 --- a/torch/package/_package_unpickler.py +++ b/torch/package/_package_unpickler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import _compat_pickle import pickle diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index a810d50661cb..2d5145b40aa7 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """List of Python standard library modules. 
Sadly, there is no reliable way to tell whether a module is part of the diff --git a/torch/package/analyze/trace_dependencies.py b/torch/package/analyze/trace_dependencies.py index 9f882fb33481..405fcf2f9bc2 100644 --- a/torch/package/analyze/trace_dependencies.py +++ b/torch/package/analyze/trace_dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys from typing import Any, Callable, Iterable, List, Tuple diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 1453ad3a5ded..44e07978640f 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Dict, List from .glob_group import GlobGroup, GlobPattern diff --git a/torch/package/find_file_dependencies.py b/torch/package/find_file_dependencies.py index af8cd9fec84d..80cfccbec50a 100644 --- a/torch/package/find_file_dependencies.py +++ b/torch/package/find_file_dependencies.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ast from typing import List, Optional, Tuple diff --git a/torch/package/glob_group.py b/torch/package/glob_group.py index a8434788d016..974364400502 100644 --- a/torch/package/glob_group.py +++ b/torch/package/glob_group.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import re from typing import Iterable, Union diff --git a/torch/package/importer.py b/torch/package/importer.py index dd01d09209a8..513847513910 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import importlib from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined] diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 493c017ccf99..bfa00278fa4b 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import importlib.machinery import io diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 9e2f74354db5..1a103ab6c5c9 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins import importlib import importlib.machinery diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index e3c4145fd91f..4a681daf788e 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. 
Profiler's context manager API can be used to better understand what model operators are the most expensive, diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index b719df2a56ee..1834f0494e02 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import dataclasses import enum diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 02e9b014d308..a7ec5d05dd68 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import json import math import os diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 35f6e71de558..d69fa4630595 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import operator import re diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py index 4d072957d6fe..4666bba515a3 100644 --- a/torch/profiler/itt.py +++ b/torch/profiler/itt.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager try: diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index a9f65104a99e..f43dcc06de20 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gzip import json import os @@ -601,6 +602,7 @@ def __init__( warn( "`use_cuda` is deprecated, use `activities` argument instead", FutureWarning, + stacklevel=2, ) if use_cuda: activities_set.add(ProfilerActivity.CUDA) @@ -849,6 +851,7 @@ def start(self): if self._registered and not self._execution_trace_running: _enable_execution_trace_observer() self._execution_trace_running = True + self._record_pg_config() def stop(self): """ @@ -875,3 +878,16 @@ def get_output_file_path(self) -> str: "A callback to the ET profiler needs to be registered " "first before getting the output file path" ) + + def _record_pg_config(self) -> None: + # Records the PG config info to the trace as node: + # ## process_group:init ## + if ( + self.is_registered + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info + torch.autograd._record_function_with_args_enter( + "## process_group:init ##", json.dumps(pg_config_info) + ) diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index fd83d88a3e3e..a82518db6084 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from .quantize import * # noqa: F403 from .observer import * # noqa: F403 from .qconfig import * # noqa: F403 diff --git a/torch/quantization/_quantized_conversions.py b/torch/quantization/_quantized_conversions.py index 2b7670ea4802..8d930c366c0d 100644 --- a/torch/quantization/_quantized_conversions.py +++ b/torch/quantization/_quantized_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/quasirandom.py b/torch/quasirandom.py index 884d1d17e77c..a1218012ceb6 100644 --- a/torch/quasirandom.py +++ b/torch/quasirandom.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import Optional diff --git a/torch/random.py b/torch/random.py index 74d448488042..0916fe115a92 100644 --- a/torch/random.py +++ b/torch/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Generator import 
warnings diff --git a/torch/serialization.py b/torch/serialization.py index e4ad1f7e9c6e..311aac28c8c5 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import difflib import functools import os @@ -390,6 +391,25 @@ def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.Untyp def default_restore_location(storage, location): + """ + Restores `storage` using a deserializer function registered for the `location`. + + This function looks in the registry for deserializer functions that match the `location`. + If any are found, they are tried in priority order until one returns a non-`None` result. + If no deserializer matches the `location`, or all matching deserializers return `None`, + it raises a `RuntimeError`. + + Args: + storage (STORAGE): the storage object to restore + location (str): the location tag associated with the storage object + + Returns: + storage: Optional[STORAGE] + + Raises: + RuntimeError: If no deserializer matching `location` is found in the registry or if + all matching ones return `None`. + """ for _, _, fn in _package_registry: result = fn(storage, location) if result is not None: @@ -921,7 +941,8 @@ def load( pickle_module: module used for unpickling metadata and objects (has to match the :attr:`pickle_module` used to serialize file) weights_only: Indicates whether unpickler should be restricted to - loading only tensors, primitive types and dictionaries + loading only tensors, primitive types, dictionaries + and any types added via :func:`torch.serialization.add_safe_globals`. mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they are moved to the location that they were tagged with when saving, or specified by ``map_location``. This @@ -1266,7 +1287,7 @@ def persistent_load(saved_id): if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2): raise RuntimeError( "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. " - f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this " + f'Received object of type "{type(f)}". 
Please update to Python 3.8.2 or newer to restore this ' "functionality.") magic_number = pickle_module.load(f, **pickle_load_args) diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index d86a1245dc27..f9f73b2dca07 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Iterable import torch diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 8ca4aed7d71a..5b86e068096f 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # The Tensor classes are added to this module by python_tensor.cpp from typing import Optional, Tuple, List, Union, Any diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index 5203ad245b28..141464f7dc76 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 551111b429a5..bcaa889ba1ee 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import torch diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a22b5c8077e3..e11bdf59c882 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import os import torch diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index e6fc1329e812..eedfa03b756a 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Provides optimal triton kernel parameters. Aim diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index d592e5ef6a62..6105038e4df7 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from collections import namedtuple from typing import Any, Optional, Tuple, List, Callable, Dict diff --git a/torch/storage.py b/torch/storage.py index 32070783f494..c094ba5ac3e9 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import torch @@ -39,7 +40,7 @@ def size(self) -> int: def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 - def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 """Returns a copy of this object in CUDA memory. If this object is already in CUDA memory and on the correct device, then @@ -54,7 +55,7 @@ def cuda(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] device2 = torch.device('cuda', device) if device else torch.device('cuda') return self.to(device=device2, non_blocking=non_blocking) - def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var] # noqa: E704 + def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[type-var, misc] # noqa: E704 """Returns a copy of this object in HPU memory. 
If this object is already in HPU memory and on the correct device, then @@ -182,7 +183,7 @@ def _to(self, dtype): storage = storage.clone() return storage - def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] # noqa: E704 + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc] # noqa: E704 return _to(self, device, non_blocking) def double(self): @@ -856,7 +857,7 @@ def hpu(self, device=None, non_blocking=False) -> T: # type: ignore[misc, type- hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) - def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var] + def to(self, *, device: torch.device, non_blocking: bool = False) -> T: # type: ignore[type-var, misc] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError(f"Cannot create {device.type.upper()} storage with quantized dtype") diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 58b8f828e354..352ce67e074a 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -1,3 +1,4 @@ from torch._C import FileCheck as FileCheck +from . import _utils from ._comparison import assert_allclose, assert_close as assert_close from ._creation import make_tensor as make_tensor diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 85d5adb0cd3a..9815cc2a8807 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import abc import cmath import collections.abc diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 2433552a0873..d8fb2ef18b1d 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -152,6 +152,7 @@ def clamp(a: float, l: float, h: float) -> float: "is deprecated since 2.1 and will be removed in 2.3. 
" "Use `torch.full(...)` instead.", FutureWarning, + stacklevel=3, ) elif low >= high: raise ValueError(f"`low` must be less than `high`, but got {low} >= {high}") diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 054f1a135740..02b38bf9351a 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -50,13 +50,24 @@ def evaluate_platform_supports_flash_attention(): return not IS_WINDOWS and SM80OrLater return False +def evaluate_platform_supports_cudnn_attention(): + return (not TEST_WITH_ROCM) and (not IS_WINDOWS) and TEST_CUDA and SM80OrLater + +def evaluate_platform_supports_efficient_attention(): + if TEST_WITH_ROCM: + return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') + if TEST_CUDA: + return True + return False + PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) -PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM) -# TODO(eqy): gate this against a cuDNN version -PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and - torch.backends.cuda.cudnn_sdp_enabled()) +PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) +PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention()) + # This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate -PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) +PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or + PLATFORM_SUPPORTS_CUDNN_ATTENTION or + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 07caa0ac3eee..2e2a379a501e 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -15,7 +15,7 @@ import torch from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ - IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, \ + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \ _TestParametrizer, compose_parametrize_fns, dtype_name, \ TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \ get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \ @@ -590,6 +590,18 @@ def setUpClass(cls): def _should_stop_test_suite(self): return False +class HPUTestBase(DeviceTypeTestBase): + device_type = 'hpu' + primary_device: ClassVar[str] + + @classmethod + def get_primary_device(cls): + return cls.primary_device + + @classmethod + def setUpClass(cls): + cls.primary_device = 'hpu:0' + class PrivateUse1TestBase(DeviceTypeTestBase): primary_device: ClassVar[str] device_mod = None @@ -701,6 +713,8 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l test_bases.append(MPSTestBase) if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases: test_bases.append(XPUTestBase) + if TEST_HPU and HPUTestBase not in test_bases: + test_bases.append(HPUTestBase) # Filter out the device types 
based on user inputs desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for) if include_lazy: @@ -1060,6 +1074,10 @@ class skipMPSIf(skipIf): def __init__(self, dep, reason): super().__init__(dep, reason, device_type='mps') +class skipHPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type='hpu') + # Skips a test on XLA if the condition is true. class skipXLAIf(skipIf): @@ -1343,6 +1361,9 @@ def onlyMPS(fn): def onlyXPU(fn): return onlyOn('xpu')(fn) +def onlyHPU(fn): + return onlyOn('hpu')(fn) + def onlyPRIVATEUSE1(fn): device_type = torch._C._get_privateuse1_backend_name() device_mod = getattr(torch, device_type, None) @@ -1401,6 +1422,9 @@ def expectedFailureMeta(fn): def expectedFailureXLA(fn): return expectedFailure('xla')(fn) +def expectedFailureHPU(fn): + return expectedFailure('hpu')(fn) + # Skips a test on CPU if LAPACK is not available. def skipCPUIfNoLapack(fn): return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) @@ -1578,6 +1602,9 @@ def skipXLA(fn): def skipMPS(fn): return skipMPSIf(True, "test doesn't work on MPS backend")(fn) +def skipHPU(fn): + return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + def skipPRIVATEUSE1(fn): return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index b325a9601e25..473e5c35e07a 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -357,7 +357,7 @@ def create_tcp_store( timeout=timedelta(minutes=5), wait_for_workers=True, jit_class=False, - use_libuv=False + use_libuv=True, ): """ Creates a TCP store. Retries if the chosen port is already in use. @@ -544,7 +544,11 @@ def wrapper(self): # Constructor patches current instance test method to # assume the role of the main process and join its subprocesses, # or run the underlying test function. - def __init__(self, method_name: str = "runTest") -> None: + def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. 
+ if methodName != "runTest": + method_name = methodName super().__init__(method_name) fn = getattr(self, method_name) setattr(self, method_name, self.join_or_run(fn)) @@ -867,7 +871,9 @@ def run_subtests( # Map keyword to chosen value subtest_kwargs = dict(zip(subtest_config_keys, values)) with cls_inst.subTest(**subtest_kwargs): + torch._dynamo.reset() test_fn(*test_args, **test_kwargs, **subtest_kwargs) + torch._dynamo.reset() c10d.barrier() diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 4e266117c13b..2b5fdc613c2e 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["oncall: distributed"] import contextlib @@ -9,7 +10,7 @@ from contextlib import nullcontext from copy import deepcopy from enum import auto, Enum -from functools import partial, wraps +from functools import wraps from typing import ( Any, Callable, @@ -1086,6 +1087,12 @@ def setUp(self): def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) + def perThreadSetUp(self): + torch._dynamo.reset() + + def perThreadTearDown(self): + torch._dynamo.reset() + class FSDPTest(MultiProcessTestCase): def setUp(self): @@ -1156,7 +1163,9 @@ def _run(cls, rank, test_name, file_name, pipe): # immediately exiting due to a skip doesn't cause flakiness. dist.barrier(device_ids=device_ids) + torch._dynamo.reset() self.run_test(test_name, pipe) + torch._dynamo.reset() dist.barrier(device_ids=device_ids) @@ -1416,45 +1425,49 @@ def _test_fsdp_parity( def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None): def fully_shard_with_compiled_compute(*args, **kwargs): - # compile ``module._call_impl`` - # to showcase how to include user-registered hooks + torch.distributed._composable.fsdp.fully_shard(*args, **kwargs) # type: ignore[operator] if compile_compute_on_module is None or isinstance( args[0], compile_compute_on_module ): args[0].compile() - return torch.distributed._composable.fsdp.fully_shard(*args, **kwargs) # type: ignore[operator] - class FullyShardPatch(Enum): - # apply ``partial`` in order to use ``Enum.value`` - EAGER = partial(torch.distributed._composable.fsdp.fully_shard) # type: ignore[var-annotated, arg-type] - COMPILED_COMPUTE = partial(fully_shard_with_compiled_compute) # type: ignore[arg-type] - # add FULL for tracing FSDP + class FullyShardMode(Enum): + EAGER = auto() + COMPILED_COMPUTE = auto() def decorator(func): @wraps(func) def wrapper(*args, **kwargs): original_fully_shard = torch.distributed._composable.fsdp.fully_shard - for fully_shard_patch in FullyShardPatch: - if fully_shard_patch != FullyShardPatch.EAGER and not has_triton(): + for mode in FullyShardMode: + if mode != FullyShardMode.EAGER and not has_triton(): warnings.warn("Inductor on GPU needs Triton and recent GPU arch") continue - imported_fully_shard = ( - f"{func.__module__}.{original_fully_shard.__name__}" - ) - with mock.patch( - imported_fully_shard, - fully_shard_patch.value, - ): - func(*args, **kwargs) - torch.distributed.barrier() - # mock.patch.__exit__ does not work with multi-thread - # thread 1 set {func.__module__}.fully_shard - # thread 2 read {func.__module__}.fully_shard and thought it is original - # hence we manually reset them after __exit__ - import_path, _ = mock._get_target(imported_fully_shard) # type: ignore[attr-defined] - setattr( - import_path(), original_fully_shard.__name__, original_fully_shard - ) + # barrier to ensure thread reading 
the same value + original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks + original_compile_threads = torch._inductor.config.compile_threads + torch.distributed.barrier() + + if mode == FullyShardMode.EAGER: + fully_shard_patch = original_fully_shard + elif mode == FullyShardMode.COMPILED_COMPUTE: + torch._dynamo.config.skip_fsdp_hooks = True + torch._inductor.config.compile_threads = 1 + fully_shard_patch = fully_shard_with_compiled_compute # type: ignore[assignment] + else: + raise NotImplementedError( + f"Need to implement FullyShardMode={mode}" + ) + + # fully_shard is imported as a global + # through `from ... import fully_shard` + func.__globals__[original_fully_shard.__name__] = fully_shard_patch + func(*args, **kwargs) + # other threads use patched func before this thread restores + torch.distributed.barrier() + func.__globals__[original_fully_shard.__name__] = original_fully_shard + torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks + torch._inductor.config.compile_threads = original_compile_threads return wrapper diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f551beb759cf..5c32d1a11aff 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -36,9 +36,10 @@ make_fullrank_matrices_with_distinct_singular_values, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, - GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW, + GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, TEST_WITH_TORCHINDUCTOR ) +from torch.testing._utils import wrapper_set_seed import torch._refs as refs # noqa: F401 import torch._refs.nn.functional @@ -2100,7 +2101,7 @@ def error_inputs_T(self, device, has_ndims_error=False): r'to reverse their shape is not supported\.')) -def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False, **kwargs): +def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False): """ This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n). Their matrix product could be used to generate tensor of shape (*, m, n) of rank k. @@ -2114,13 +2115,18 @@ def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad= for k in range(min(3, m, n)): a = make_arg((*batch, m, k)) b = make_arg((*batch, n, k)) - yield SampleInput(a, b, **kwargs) + yield a, b def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs): - for sample in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad, **kwargs): - *batch, m, k = sample.input.shape - *_, n, _ = sample.args[0].shape + # Function that's well defined on the outputs for complex inputs + def fn(usv): + U, S, V = usv + return U @ V.mH, S + + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] # NOTE: since svd_lowrank relies on non rank-revealing SVD, # it inherits the problem of unstable behavior with repeated @@ -2129,20 +2135,13 @@ def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwa # we can only use k for q. # This issues could be resolved with using a rank-revealing SVD # which does not include "zero" singular values. 
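For orientation, a minimal sketch of the torch.svd_lowrank call pattern that the reworked samples above construct (illustrative only, not part of this patch; when M is supplied it acts as a mean that is subtracted from the input before the low-rank SVD):

import torch

a = torch.randn(5, 2)
b = torch.randn(4, 2)
A = a @ b.mT                                    # shape (5, 4), rank <= 2
U, S, V = torch.svd_lowrank(A, q=2, M=None)     # non rank-revealing low-rank SVD with q=k
print(torch.dist(A, U @ torch.diag(S) @ V.mT))  # reconstruction error is small when q >= rank
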
- op_kwargs = { - 'q': k, - 'M': None - } + yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn) - # without M specified - yield clone_sample(sample, **op_kwargs) - - # now with M - # TODO: fix bug in the documentation for svd_lowrank: - # M has to be (*, m, n), and not (*, 1, n) as written - # in the documentation - op_kwargs['M'] = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) - yield clone_sample(sample, **op_kwargs) + for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad): + *batch, m, k = a.shape + n = b.shape[-2] + M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad) + yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn) def chunk_iter(iterable, size): it = iter(iterable) @@ -2624,6 +2623,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): make_arg((S,)), 0, torch.tensor([], dtype=torch.uint8, device=device)) + yield SampleInput( + make_arg((S,)), + 0, + torch.tensor([[], []], dtype=torch.uint8, device=device)) # 0D tensor case yield SampleInput( make_arg(()), @@ -8689,7 +8692,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=mask_type, compute_log_sumexp=requires_grad, scale=scale, - causal_diagonal=None, seqlen_k=None )) @@ -8707,7 +8709,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None ) @@ -8726,7 +8727,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None ) ) @@ -8749,7 +8749,6 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g custom_mask_type=0, # No Mask compute_log_sumexp=requires_grad, scale=None, - causal_diagonal=None, seqlen_k=None, ) ) @@ -9165,6 +9164,7 @@ def __init__( self._set_rightmost_arg_types( rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, ) + self._intersperse_empty = (True, False) def _set_rightmost_arg_types( self, @@ -9329,7 +9329,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): # add empty tensor interspersion to test fully fixing #100701 for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( - num_input_tensors, self._rightmost_arg_types, (True, False)): + num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy continue @@ -9363,6 +9363,24 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): ) +class foreach_max_sample_func(foreach_inputs_sample_func): + def __init__( + self, + arity: int, + rightmost_supports_scalar: bool, + rightmost_supports_scalarlist: bool, + rightmost_supports_tensor: bool = False, + ) -> None: + super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) + self._intersperse_empty = (False,) + + def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): + return [] + + def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): + return False + + class 
foreach_norm_sample_func(foreach_inputs_sample_func): def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): assert "num_input_tensors" not in kwargs @@ -11019,14 +11037,30 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.half,), device_type="cpu"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=(torch.half,), device_type="cpu"), DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), DecorateInfo( unittest.expectedFailure, @@ -11057,10 +11091,14 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + 
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), @@ -11089,15 +11127,59 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), # fails with div_cpu is not implemented with ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), ), ), ] foreach_reduce_op_db: List[ForeachFuncInfo] = [ + ForeachFuncInfo( + "max", + sample_inputs_func=foreach_max_sample_func(1, False, False), + supports_autograd=True, + supports_inplace_autograd=True, + supports_forward_ad=True, + decorators=( + # no complex support for ordering ops like max + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_autodiff", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestForeach", + "test_foreach_reduce_large_input", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_meta_outplace", + dtypes=(torch.complex128, torch.complex64), + ), + ), + ), ForeachFuncInfo( "norm", sample_inputs_func=foreach_norm_sample_func(1, False, False), @@ -11299,22 +11381,6 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def wrapper_set_seed(op, *args, **kwargs): - """Wrapper to set seed manually for some functions like dropout - See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. 
- """ - with freeze_rng_state(): - torch.manual_seed(42) - output = op(*args, **kwargs) - - if isinstance(output, torch.Tensor) and output.device.type == "lazy": - # We need to call mark step inside freeze_rng_state so that numerics - # match eager execution - torch._lazy.mark_step() - - return output - - def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] @@ -11522,6 +11588,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] return np.reshape(input, out_shape) + +def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): + yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) + yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -13025,6 +13097,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_diagonal_scatter), + OpInfo('alias_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), + sample_inputs_func=sample_inputs_alias_copy, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True), BinaryUfuncInfo('eq', ref=np.equal, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), @@ -14925,8 +15002,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # RuntimeError: UNSUPPORTED DTYPE: complex DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), - # RuntimeError: Conv3D is not supported on MPS - DecorateInfo(unittest.expectedFailure, 'TestConsistency'), # AssertionError: Tensor-likes are not close! # break slow tests DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), @@ -15863,15 +15938,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', device_type='cpu'), - # TODO: Do not work even on MI200 because of stride mismatching. 
- DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_amp', - device_type='cuda', active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cpu'), @@ -15891,6 +15957,19 @@ def reference_flatten(input, start_dim=0, end_dim=-1): device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), + # FIXME + DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), + 'TestCompositeCompliance', 'test_cow_input', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), + DecorateInfo(unittest.skip('test_fake_crossref_backward_amp does not work with efficient attention on ROCM'), + 'TestFakeTensor', 'test_fake_crossref_backward_amp', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), + DecorateInfo(unittest.skip('test_fake_crossref_backward_no_amp does not work with efficient attention on ROCM'), + 'TestFakeTensor', 'test_fake_crossref_backward_no_amp', + device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), # registered in fake_impls.py instead of _meta_registrations.py, so meta kernels will fail. # However, for implementations that fall back to the constituent ops, the meta kernels may not # fail. Fused kernels will fail, whereas unfused kernels will not fail. @@ -15898,6 +15977,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # mem_eff_attention also supports fp32 - so if it is supported the test will fail. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), + # TODO: float32 support in ROCM efficient attention DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", @@ -15937,13 +16017,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), # None Mismatch Tensor DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), - # TODO: Do not work on MI200 because of stride mismatching. 
- DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=[torch.float16, torch.bfloat16], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), ) ), OpInfo( @@ -15960,7 +16033,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): check_batched_forward_grad=False, # TODO: Skip because it produces a CUDA illegal memory access for some reason skip_cow_input_backward=True, - decorators=[skipCUDAIf(TEST_WITH_ROCM, "ROCm doesn't support efficient attention")], + # FIXME: mask_type == 2 (LowerRight) + decorators=[ + skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), + skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skips=( # Device mismatch due to philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), @@ -17656,10 +17732,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1): lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), *args, **kwargs ), - dtypes=floating_types(), + dtypes=floating_and_complex_types(), # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, supports_out=False, + # Due to the use of randomness check_batched_grad=False, check_batched_gradgrad=False, check_batched_forward_grad=False, @@ -17667,14 +17744,29 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, sample_inputs_func=sample_inputs_svd_lowrank, decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}), - 'TestCommon', 'test_noncontiguous_samples', - device_type='cuda')], + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=1e-02, rtol=1e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! 
+ # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + DecorateInfo(unittest.skip("See comment above"), + 'TestBwdGradientsCUDA', + 'test_fn_gradgrad', + dtypes=[torch.complex128]), + ], skips=( # test does not work with passing lambda for op DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), )), OpInfo('pca_lowrank', @@ -17682,7 +17774,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), *args, **kwargs ), - dtypes=floating_types(), + dtypes=floating_and_complex_types(), # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, supports_out=False, @@ -17693,13 +17785,25 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_pca_lowrank, decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}), - 'TestCommon', 'test_noncontiguous_samples', - device_type='cuda')], + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), + torch.complex64: tol(atol=4e-02, rtol=4e-02)}), + 'TestCommon', 'test_noncontiguous_samples'), + # FIXME This should be the following, but the toleranceOverride does not seem to do anything! 
+ # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), + # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), + DecorateInfo(unittest.skip("See comment above"), + 'TestFwdGradients', + 'test_fn_fwgrad_bwgrad', + dtypes=[torch.complex128]), + + ], skips=( # test does not work with passing lambda for op DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 + DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', + dtypes=(torch.complex64, torch.complex128)), DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), )), BinaryUfuncInfo('polar', @@ -23135,6 +23239,10 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # # View & Shape OpInfos # + PythonRefInfo( + "_refs.alias_copy", + torch_opinfo_name="alias_copy", + ), PythonRefInfo( "_refs.atleast_1d", torch_opinfo_name="atleast_1d", diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index c11314721f27..0505c749a7f9 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -3016,7 +3016,7 @@ def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mea return output -# this directly follows Graves et al's paper, in contrast to the production implementation, it does not use log-space +# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'): input_lengths = torch.as_tensor(input_lengths, dtype=torch.long) target_lengths = torch.as_tensor(target_lengths, dtype=torch.long) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index c7122c8666d4..ac4a7f920cc2 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -641,7 +641,6 @@ def optim_error_inputs_func_lbfgs(device, dtype): return error_inputs -# Weird story bro, NAdam and RAdam do not have maximize. 
def optim_inputs_func_nadam(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), @@ -694,6 +693,11 @@ def optim_inputs_func_nadam(device, dtype=None): }, desc="decoupled_weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), ] + (cuda_supported_configs if "cuda" in str(device) else []) @@ -1253,6 +1257,17 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_fused_matches_forloop", ), + DecorateInfo( + # Note on tolerances: + # Tracking through #127000 + toleranceOverride( + { + torch.float32: tol(atol=3e-5, rtol=1.3e-06), + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), ), skips=( DecorateInfo( @@ -1369,6 +1384,20 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_fused_matches_forloop", ), + # Note on tolerances: + # Tracking through #127000 + DecorateInfo( + toleranceOverride( + { + torch.float32: tol( + atol=3e-5, + rtol=1.3e-06, + ) + } + ), + "TestCudaOptims", + "test_grad_scaling_autocast_fused_optimizers", + ), ), skips=( DecorateInfo( @@ -1549,13 +1578,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_load_nontensor_step", ), - DecorateInfo( - skipIfTorchDynamo( - "Errors, see https://github.com/pytorch/pytorch/issues/117150" - ), - "TestOptimRenewed", - "test_state_dict_with_cuda_params", - ), DecorateInfo( skipIfTorchDynamo( "This test uses mocks, which dynamo does not support" diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 5f9ef602d518..2097e25bdaa8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -96,7 +96,6 @@ from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree -from .composite_compliance import no_dispatch try: import pytest has_pytest = True @@ -104,6 +103,10 @@ has_pytest = False +def freeze_rng_state(*args, **kwargs): + return torch.testing._utils.freeze_rng_state(*args, **kwargs) + + # Class to keep track of test flags configurable by environment variables. # Flags set here are intended to be read-only and should not be modified after # definition. 
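A minimal usage sketch for the freeze_rng_state helper re-exported above (illustrative only, not part of this patch): the helper saves the global RNG state on entry and restores it on exit, so seeding inside the block does not leak into the rest of the test process.

import torch
from torch.testing._internal.common_utils import freeze_rng_state

state_before = torch.get_rng_state()
with freeze_rng_state():
    torch.manual_seed(42)   # deterministic draws inside the block
    _ = torch.rand(3)
# the saved RNG state is restored on exit
assert torch.equal(state_before, torch.get_rng_state())
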
@@ -1233,6 +1236,7 @@ def TemporaryDirectoryName(suffix=None): TEST_MKL = torch.backends.mkl.is_available() TEST_MPS = torch.backends.mps.is_available() TEST_XPU = torch.xpu.is_available() +TEST_HPU = True if (hasattr(torch, "hpu") and torch.hpu.is_available()) else False TEST_CUDA = torch.cuda.is_available() custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available() @@ -1546,6 +1550,31 @@ def has_corresponding_torch_dtype(np_dtype): torch.complex32: np.complex64 }) +def skipIfNNModuleInlined( + msg="test doesn't currently work with nn module inlining", + condition=torch._dynamo.config.inline_inbuilt_nn_modules, +): # noqa: F821 + def decorator(fn): + if not isinstance(fn, type): + + @wraps(fn) + def wrapper(*args, **kwargs): + if condition: + raise unittest.SkipTest(msg) + else: + fn(*args, **kwargs) + + return wrapper + + assert isinstance(fn, type) + if condition: + fn.__unittest_skip__ = True + fn.__unittest_skip_why__ = msg + + return fn + + return decorator + def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): def dec_fn(fn): reason = f"skipIfRocm: {msg}" @@ -1594,6 +1623,15 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +def skipIfHpu(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_HPU: + raise unittest.SkipTest("test doesn't currently work with HPU") + else: + fn(*args, **kwargs) + return wrapper + # Skips a test on CUDA if ROCm is available and its version is lower than requested. def skipIfRocmVersionLessThan(version=None): def dec_fn(fn): @@ -1949,35 +1987,6 @@ def set_rng_seed(seed): np.random.seed(seed) -disable_functorch = torch._C._DisableFuncTorch - - -@contextlib.contextmanager -def freeze_rng_state(): - # no_dispatch needed for test_composite_compliance - # Some OpInfos use freeze_rng_state for rng determinism, but - # test_composite_compliance overrides dispatch for all torch functions - # which we need to disable to get and set rng state - with no_dispatch(), disable_functorch(): - rng_state = torch.get_rng_state() - if torch.cuda.is_available(): - cuda_rng_state = torch.cuda.get_rng_state() - try: - yield - finally: - # Modes are not happy with torch.cuda.set_rng_state - # because it clones the state (which could produce a Tensor Subclass) - # and then grabs the new tensor's data pointer in generator.set_state. - # - # In the long run torch.cuda.set_rng_state should probably be - # an operator. - # - # NB: Mode disable is to avoid running cross-ref tests on thes seeding - with no_dispatch(), disable_functorch(): - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) - torch.set_rng_state(rng_state) - @contextlib.contextmanager def set_default_dtype(dtype): saved_dtype = torch.get_default_dtype() @@ -2310,7 +2319,7 @@ def matches_test(target: str): print(f"Test {disabled_test} is disabled for some unrecognized ", f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ", - "assigned to this flaky test, changing \"Platforms: ...\" to a comma separated ", + 'assigned to this flaky test, changing "Platforms: ..." to a comma separated ', f"subset of the following (or leave it blank to match all platforms): {valid_plats}") # Sanitize the platforms list so that we continue to disable the test for any valid platforms given @@ -2622,7 +2631,11 @@ def rel_tol(self, prec: float) -> None: # the test, skip it instead. 
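A hypothetical usage sketch for the skip helpers added earlier in this file's changes, skipIfHpu and skipIfNNModuleInlined (illustrative only, not part of this patch; the test class and assertions are made up):

from torch.testing._internal.common_utils import TestCase, run_tests, skipIfHpu, skipIfNNModuleInlined

class ExampleSkips(TestCase):
    @skipIfHpu
    def test_not_yet_supported_on_hpu(self):
        self.assertEqual(1 + 1, 2)

    @skipIfNNModuleInlined("relies on non-inlined nn module frames")
    def test_sensitive_to_nn_module_inlining(self):
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()
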
_ignore_not_implemented_error = False - def __init__(self, method_name='runTest'): + def __init__(self, method_name='runTest', methodName='runTest'): + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName super().__init__(method_name) test_method = getattr(self, method_name, None) @@ -4427,8 +4440,8 @@ def check_test_defined_in_running_script(test_case): if running_script_path is None: return test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__))) - assert test_case_class_file == running_script_path, f"Class of loaded TestCase \"{test_case.id()}\" " \ - f"is not defined in the running script \"{running_script_path}\", but in \"{test_case_class_file}\". Did you " \ + assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \ + f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \ "accidentally import a unittest.TestCase from another file?" def load_tests(loader, tests, pattern): @@ -5010,6 +5023,7 @@ def repl_frame(m): return m.group(0) s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s) + s = re.sub(r'( len(osd[PG])) - new_pg = copy.deepcopy(dist_osd[PG][0]) + old_dist_osd_pg = dist_osd[_PG] + if len(osd[_PG]) != len(dist_osd[_PG]): + self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG])) + new_pg = copy.deepcopy(dist_osd[_PG][0]) new_pg["params"] = [] - for dist_group in dist_osd[PG]: + for dist_group in dist_osd[_PG]: new_pg["params"].extend(dist_group["params"]) - dist_osd[PG] = [new_pg] + dist_osd[_PG] = [new_pg] - self.assertEqual(len(osd[PG]), len(dist_osd[PG])) - for group, dist_group in zip(osd[PG], dist_osd[PG]): + self.assertEqual(len(osd[_PG]), len(dist_osd[_PG])) + for group, dist_group in zip(osd[_PG], dist_osd[_PG]): self.assertEqual(len(group), len(dist_group)) for key, value in group.items(): # Below doesn't work because param_groups can have None @@ -99,7 +99,7 @@ def _verify_osd( self.assertEqual(sorted(fqns), sorted(dist_value)) else: self.assertEqual(value, dist_value) - dist_osd[PG] = old_dist_osd_pg + dist_osd[_PG] = old_dist_osd_pg def _verify_osd_by_load( self, diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index b9873b9950fa..0ec5dd222444 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -6995,7 +6995,8 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: """ with open(et_file) as f: et = json.load(f) - + pg_cfg_node = [n for n in et["nodes"] if n["name"] == "## process_group:init ##"] + self.assertGreaterEqual(len(pg_cfg_node), 1) nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"] self.assertEqual(len(nccl_meta_nodes), 3) per_coll_meta = defaultdict(list) @@ -7052,7 +7053,6 @@ def test_ddp_profiling_execution_trace(self): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() et_file = fp.name - et = ExecutionTraceObserver().register_callback(et_file) # first profiler context need not have ET @@ -9888,6 +9888,71 @@ def test_ddp_update_process_group_new_group(self): def test_ddp_update_process_group_default_group(self): self._run_ddp_update_process_group(new_pg=False) + @skip_if_lt_x_gpu(4) + 
@require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_grad_undefined(self): + class SimulateError(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + raise RuntimeError + + class MyModel(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10).cuda(device) + self.fc2 = torch.nn.Linear(10, 10).cuda(device) + self.fc3 = torch.nn.Linear(10, 10).cuda(device) + + def forward(self, inp, error): + if error: + return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp)))) + else: + return self.fc2(self.fc1(inp)) + + + input = torch.rand(10, 10, requires_grad=True).cuda(self.rank) + ddp = torch.nn.parallel.DistributedDataParallel( + MyModel(self.rank), + device_ids=[self.rank], + find_unused_parameters=True, + bucket_cap_mb=1, + ) + + try: + ddp(input, True).sum().backward() + except RuntimeError: + ddp._update_process_group(_get_default_group()) + + # Reset grads. + for param in ddp.parameters(): + param.grad = None + + # Run ddp again. + ddp(input, False).sum().backward() + + @skip_if_lt_x_gpu(4) + @require_world_size(4) + @skip_but_pass_in_sandcastle_if( + BACKEND not in DistTestCases.backend_feature["ddp"], + f"The {BACKEND} backend does not support DistributedDataParallel", + ) + def test_ddp_update_process_group_no_find_unused(self): + ddp = torch.nn.parallel.DistributedDataParallel( + torch.nn.Linear(10, 10).cuda(self.rank), + device_ids=[self.rank], + find_unused_parameters=False, + ) + ddp._update_process_group(_get_default_group()) + + @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["ddp"], diff --git a/torch/testing/_internal/distributed/pipe_with_ddp_test.py b/torch/testing/_internal/distributed/pipe_with_ddp_test.py deleted file mode 100644 index 1ed9f3cc96df..000000000000 --- a/torch/testing/_internal/distributed/pipe_with_ddp_test.py +++ /dev/null @@ -1,149 +0,0 @@ -# mypy: ignore-errors - -import torch -import torch.distributed as dist - -from torch import nn -from torch.nn.parallel import DistributedDataParallel -from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init -from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( - RpcAgentTestFixture, -) -from torch.testing._internal.common_distributed import ( - requires_gloo, - requires_nccl, - skip_if_lt_x_gpu, - skip_if_rocm, -) -from torch.distributed.pipeline.sync import Pipe - -class PipeWithDDPTest(RpcAgentTestFixture): - @property - def world_size(self) -> int: - return 2 - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never(self): - self._run_basic_test("nccl", "never") - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_never_find_unused(self): - self._run_basic_test("nccl", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_always(self): - self._run_basic_test("nccl", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_nccl() - @dist_init - @skip_if_rocm - def test_basic_nccl_ckpt_except_last(self): - self._run_basic_test("nccl", "except_last", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - 
def test_basic_gloo_ckpt_never(self): - self._run_basic_test("gloo", "never") - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_never_find_unused(self): - self._run_basic_test("gloo", "never", find_unused_parameters=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_always(self): - self._run_basic_test("gloo", "always", static_graph=True) - - @skip_if_lt_x_gpu(4) - @requires_gloo() - @dist_init - @skip_if_rocm - def test_basic_gloo_ckpt_except_last(self): - self._run_basic_test("gloo", "except_last", static_graph=True) - - def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): - dist.init_process_group( - backend=backend, - init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), - world_size=self.world_size, - rank=self.rank, - ) - - # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another - # pipe between GPU 2 and 3. Both replicas are replicated via DDP. - fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) - - class MyModule(nn.Module): - def __init__(self, device): - super().__init__() - self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) - self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) - - def forward(self, inp): - if find_unused_parameters: - return self.fc2(inp) - else: - return self.fc3(self.fc2(inp)) - - layer2 = MyModule(2 * self.rank + 1) - model = nn.Sequential( - fc1, - layer2 - ) - model = Pipe(model, chunks=2, checkpoint=checkpoint) - model = DistributedDataParallel( - model, - find_unused_parameters=find_unused_parameters, - static_graph=static_graph, - ) - - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Run forward again for find_unused_parameters to trigger any potential errors. - if find_unused_parameters: - # Ensure inputs are different across ranks to verify that gradient - # sync indeed occurs. - unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - model(unused_param_input).local_value().sum().backward() - - # Run a few more iterations of fwd + bwd to ensure gradient synchronization - # occurs properly across iterations via delay_all_reduce/bucketized allreduce. 
- for _ in range(3): - model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1) - out = model(model_input).local_value() - out.sum().backward() - - # Check grads - output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] - dist.all_gather(output, fc1.weight.grad) - self.assertEqual(output[0], output[1]) - - output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] - dist.all_gather(output, layer2.fc2.weight.grad) - self.assertEqual(output[0], output[1]) - - if not find_unused_parameters: - output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] - dist.all_gather(output, layer2.fc3.weight.grad) - self.assertEqual(output[0], output[1]) diff --git a/torch/testing/_internal/distributed/pipeline/__init__.py b/torch/testing/_internal/distributed/pipeline/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 5d2a67cd473a..764198338636 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -1232,7 +1232,7 @@ def test_self_remote_rref_as_self_remote_arg(self): def test_rref_proxy_non_exist(self): dst = worker_name((self.rank + 1) % self.world_size) rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3)) - msg = "has no attribute \'non_exist\'" + msg = "has no attribute 'non_exist'" with self.assertRaisesRegex(AttributeError, msg): rref.rpc_sync().non_exist() diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index cdbbdcfd0681..5b6e2c90770f 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -16,9 +16,6 @@ DdpComparisonTest, DdpUnderDistAutogradTest, ) -from torch.testing._internal.distributed.pipe_with_ddp_test import ( - PipeWithDDPTest, -) from torch.testing._internal.distributed.nn.api.remote_module_test import ( CudaRemoteModuleTest, RemoteModuleTest, @@ -121,7 +118,6 @@ def tearDown(self): CudaDistAutogradTest, CudaRemoteModuleTest, CudaDdpComparisonTest, - PipeWithDDPTest, ] diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py index eb626b552ce6..3b5c291bc41f 100644 --- a/torch/testing/_internal/dynamo_test_failures.py +++ b/torch/testing/_internal/dynamo_test_failures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import logging import os import sys diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index e8db1e394b96..1078a189f69c 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -5,7 +5,7 @@ import unittest import functools from subprocess import CalledProcessError - +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton from torch.testing._internal.common_utils import ( diff --git a/torch/testing/_internal/optests/aot_autograd.py b/torch/testing/_internal/optests/aot_autograd.py index 13ce9e883789..4f281c777175 100644 --- a/torch/testing/_internal/optests/aot_autograd.py +++ b/torch/testing/_internal/optests/aot_autograd.py @@ -2,7 +2,7 @@ import torch import torch.utils._pytree as pytree -from torch.testing._internal.common_methods_invocations 
import wrapper_set_seed +from torch.testing._utils import wrapper_set_seed from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop from .make_fx import randomize import re diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index d01f91563c92..70ee48274800 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -569,7 +569,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if ( torch.jit.is_tracing() or torch.jit.is_scripting() - or torch.compiler.is_compiling() + or torch._dynamo.is_compiling() ): return func(*args, **kwargs) # Pre-existing code may not use the .default overload. If we see an diff --git a/torch/testing/_internal/optests/make_fx.py b/torch/testing/_internal/optests/make_fx.py index 95f746a31af3..83cefd18bc05 100644 --- a/torch/testing/_internal/optests/make_fx.py +++ b/torch/testing/_internal/optests/make_fx.py @@ -2,7 +2,7 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx -from torch.testing._internal.common_methods_invocations import wrapper_set_seed +from torch.testing._utils import wrapper_set_seed import torch.utils._pytree as pytree diff --git a/torch/testing/_internal/static_module.py b/torch/testing/_internal/static_module.py index b39daa380d9d..0a031b0d8f6e 100644 --- a/torch/testing/_internal/static_module.py +++ b/torch/testing/_internal/static_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Owner(s): ["module: unknown"] import torch diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index 4ae765c206f7..5d127a9a50c4 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib from typing import Optional diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py new file mode 100644 index 000000000000..50d077cb1649 --- /dev/null +++ b/torch/testing/_utils.py @@ -0,0 +1,51 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + +# Common testing utilities for use in public testing APIs. +# NB: these should all be importable without optional dependencies +# (like numpy and expecttest). + + +def wrapper_set_seed(op, *args, **kwargs): + """Wrapper to set seed manually for some functions like dropout + See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details. 
+ """ + with freeze_rng_state(): + torch.manual_seed(42) + output = op(*args, **kwargs) + + if isinstance(output, torch.Tensor) and output.device.type == "lazy": + # We need to call mark step inside freeze_rng_state so that numerics + # match eager execution + torch._lazy.mark_step() # type: ignore[attr-defined] + + return output + + +@contextlib.contextmanager +def freeze_rng_state(): + # no_dispatch needed for test_composite_compliance + # Some OpInfos use freeze_rng_state for rng determinism, but + # test_composite_compliance overrides dispatch for all torch functions + # which we need to disable to get and set rng state + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + try: + yield + finally: + # Modes are not happy with torch.cuda.set_rng_state + # because it clones the state (which could produce a Tensor Subclass) + # and then grabs the new tensor's data pointer in generator.set_state. + # + # In the long run torch.cuda.set_rng_state should probably be + # an operator. + # + # NB: Mode disable is to avoid running cross-ref tests on thes seeding + with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + torch.set_rng_state(rng_state) diff --git a/torch/types.py b/torch/types.py index 10f091a4b24e..a522d622bcc7 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import builtins from typing import Any, List, Optional, Sequence, Tuple, Union diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index ccdad48eca97..24e426a46187 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os.path as _osp import torch @@ -46,6 +47,32 @@ def swap_attr(name): setattr(t1, name, (getattr(t2, name))) setattr(t2, name, tmp) + def error_pre_hook(grad_outputs): + raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors " + "this can happen when you try to run backward on a tensor that was swapped. " + "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` " + "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) " + "between running forward and backward. 
To resolve this, please only change the " + "device/dtype before running forward (or after both forward and backward).") + + def check_use_count(t, name='t1'): + use_count = t._use_count() + error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} " + f"make sure you are not holding references to the tensor in other places.") + if use_count > 1: + if use_count == 2 and t.is_leaf: + accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node + # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge + if t._use_count() == 2: + accum_grad_node.register_prehook(error_pre_hook) + else: + raise RuntimeError(error_str) + else: + raise RuntimeError(error_str) + + check_use_count(t1, 't1') + check_use_count(t2, 't2') + # Swap the types # Note that this will fail if there are mismatched slots swap_attr("__class__") diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 6b38645e486b..0e548aa7f741 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import copy diff --git a/torch/utils/_config_typing.pyi b/torch/utils/_config_typing.pyi index b2d99e67fabb..2ebb4c09e33e 100644 --- a/torch/utils/_config_typing.pyi +++ b/torch/utils/_config_typing.pyi @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, Optional, TYPE_CHECKING, Union """ diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py index f36837ed674e..dec70d90b7d3 100644 --- a/torch/utils/_content_store.py +++ b/torch/utils/_content_store.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This module provides a FAST (on GPU) content addressable store for storages # (and tensors on top of them) with VERY WEAK portability guarantees (e.g., # don't expect CPU/CUDA to address to the same hash, don't expect it to be diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 59b7d368af26..4f1b991438c0 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Extra utilities for working with context managers that should have been # in the standard library but are not diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index 0c09a82413fe..0686e826007d 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections diff --git a/torch/utils/_device.py b/torch/utils/_device.py index d4909e54c267..c852cd30c775 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional import torch from torch.overrides import TorchFunctionMode diff --git a/torch/utils/_exposed_in.py b/torch/utils/_exposed_in.py index ddd845349916..54faf279ecfc 100644 --- a/torch/utils/_exposed_in.py +++ b/torch/utils/_exposed_in.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Allows one to expose an API in a private submodule publicly as per the definition # in PyTorch's public api policy. 
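The `check_use_count` guard added to `torch/utils/__init__.py` above only fires when a tensor passed to `torch.utils.swap_tensors` is still referenced elsewhere, for example by an existing AccumulateGrad node. As a hedged sketch of the path the guard permits, freshly created parameters swap cleanly because nothing else holds a reference to them yet:

```python
# Hedged sketch: swapping freshly created parameters passes the new use-count
# check because no AccumulateGrad node references either tensor yet.
import torch
from torch.utils import swap_tensors

p1 = torch.nn.Parameter(torch.randn(3))
p2 = torch.nn.Parameter(torch.ones(3))
swap_tensors(p1, p2)  # allowed: both use counts are low
print(p1)             # now holds the storage that belonged to p2
```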
# diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 6f8a9b5b7e23..bcc274579ad0 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -34,12 +34,7 @@ def _group_tensors_by_device_and_dtype( tensorlistlist: TensorListList, with_indices: bool = False, ) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]: - return { - (device, getattr(torch, str_dtype)): value - for (device, str_dtype), value in - torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items() - } - + return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) def _device_has_foreach_support(device: torch.device) -> bool: return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py index c7be90a4baee..f813ca28b81c 100644 --- a/torch/utils/_freeze.py +++ b/torch/utils/_freeze.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ Freeze Python packages. diff --git a/torch/utils/_get_clean_triton.py b/torch/utils/_get_clean_triton.py index ea0e27cf7d5c..70faa6a8e79d 100644 --- a/torch/utils/_get_clean_triton.py +++ b/torch/utils/_get_clean_triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import os import re diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index b7756a6fa62f..1102fa8a019d 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import importlib.util diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py index c6e3cbb5e940..91c0e07b3d93 100644 --- a/torch/utils/_mode_utils.py +++ b/torch/utils/_mode_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from typing import TypeVar diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index c417f1d9d72a..36a4ff65af6f 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import warnings diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 989be9b2d617..b4a0db5db730 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -298,6 +298,7 @@ def _register_pytree_node( "`to_str_fn` and `maybe_from_str_fn` is deprecated. " "Please use `to_dumpable_context` and `from_dumpable_context` instead.", FutureWarning, + stacklevel=2, ) _private_register_pytree_node( diff --git a/torch/utils/_stats.py b/torch/utils/_stats.py index 5b33f7b8cb02..c11cbd5df270 100644 --- a/torch/utils/_stats.py +++ b/torch/utils/_stats.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE. # IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils # AND SCRUB AWAY TORCH NOTIONS THERE. 
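With the `_foreach_utils.py` change above, `_group_tensors_by_device_and_dtype` now returns the dictionary produced by `torch._C._group_tensors_by_device_and_dtype` directly, whose keys are already real `(device, dtype)` pairs; the string-to-dtype fixup is gone. A hedged sketch of what a caller sees, based on the type annotation shown in the hunk (this is a private helper, so treat the exact shapes as assumptions):

```python
# Hedged sketch: bucketing a mixed-dtype tensor list by (device, dtype).
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

tensors = [torch.randn(2), torch.randn(2, dtype=torch.float64)]
grouped = _group_tensors_by_device_and_dtype([tensors], with_indices=True)
for (device, dtype), (tensor_lists, indices) in grouped.items():
    # one bucket per (device, dtype); indices map back into the input lists
    print(device, dtype, indices)
```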
diff --git a/torch/utils/_strobelight/examples/cli_function_profiler_example.py b/torch/utils/_strobelight/examples/cli_function_profiler_example.py index d97f339ba081..222a70c9fe2d 100644 --- a/torch/utils/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/utils/_strobelight/examples/cli_function_profiler_example.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils._strobelight.cli_function_profiler import ( diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1384261b4512..7b8387303336 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,43 +1,80 @@ +# mypy: allow-untyped-defs +import functools import math +import operator +import sys import sympy from sympy import S -from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "Pow", - "TrueDiv", + "IntTrueDiv", + "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "Round", + "RoundToInt", "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", ] +def _keep_float(f): + @functools.wraps(f) + def inner(*args): + r = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + def fuzzy_eq(x, y): if None in (x, y): return None return x == y +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - # Default return type for SymPy assumptions. - # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers - is_real = True + is_integer = True @property def base(self): @@ -52,29 +89,14 @@ def _sympystr(self, printer): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" - # SymPy assumptions based on argument types. - def _eval_is_real(self): - return fuzzy_or([self.base.is_real, self.divisor.is_real]) - - def _eval_is_integer(self): - return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) - # Automatic evaluation. 
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - def check_supported_type(x): - if ( - x.is_integer is False and x.is_real is False and x.is_complex - ) or x.is_Boolean: - raise TypeError( - f"unsupported operand type(s) for //: " - f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real" - ) - - check_supported_type(base) - check_supported_type(divisor) + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -85,26 +107,22 @@ def check_supported_type(x): return sympy.S.Zero if base.is_integer and divisor == 1: return base - if base.is_real and divisor == 1: - return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( - divisor, (sympy.Integer, sympy.Float) - ): - return sympy.floor(base / divisor) + return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - if isinstance(divisor, sympy.Rational) and divisor.p == 1: - return sympy.floor(base * divisor.q) + # gcd in sympy is over polynomials, so you'll end up with rationals if + # you do this. Don't. + """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd + """ try: gcd = sympy.gcd(base, divisor) @@ -189,6 +207,19 @@ class Where(sympy.Function): nargs = (3,) + def _eval_is_integer(self): + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self): + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -197,28 +228,27 @@ def eval(cls, c, p, q): return q -class Mod(sympy.Function): - """ - We maintain this so that we avoid SymPy correctness issues, such as: - https://github.com/sympy/sympy/issues/25146 - """ - +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): nargs = (2,) + is_integer = True + @classmethod def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q if q.is_zero: raise ZeroDivisionError("Modulo by zero") - # If either of them is NaN or infinite. - if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: - return S.NaN + # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + if p is S.Zero or p in (q, -q) or q == 1: return S.Zero # Evaluate if they are both literals. 
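As the "NB: This is Python-style floor division" comment above states, the reworked `FloorDiv.eval` only constant-folds integer literals and does so with Python's `//`, rounding toward negative infinity. A small hedged check of that behavior, assuming the class is imported from `torch.utils._sympy.functions` as in this file:

```python
# Hedged sketch: FloorDiv matches Python's // (round toward -inf), not C's
# truncating integer division.
import sympy
from torch.utils._sympy.functions import FloorDiv

print(FloorDiv(sympy.Integer(7), sympy.Integer(2)))   # 3
print(FloorDiv(sympy.Integer(-7), sympy.Integer(2)))  # -4, where C would give -3
x = sympy.Symbol("x", integer=True)
print(FloorDiv(x, sympy.Integer(1)))                  # simplifies to x
```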
@@ -247,10 +277,7 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero - def _eval_is_integer(self): - p, q = self.args - return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] - + # NB: args[1] for PythonMod def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -258,6 +285,58 @@ def _eval_is_nonpositive(self): return True if self.args[1].is_negative else None # type: ignore[attr-defined] +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -267,6 +346,36 @@ class CleanDiv(FloorDiv): pass +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + class CeilDiv(sympy.Function): """ Div used in indexing that rounds up. @@ -275,6 +384,8 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -282,6 +393,8 @@ def __new__(cls, base, divisor): class LShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -290,6 +403,8 @@ def eval(cls, base, shift): class RShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -297,28 +412,107 @@ def eval(cls, base, shift): return base // 2**shift -# Overloaded to be compatible with regular Python. 
-# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 + + result = half_exp * half_exp + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base**exp + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) + if isinstance(exp, sympy.Integer): + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby preventing complex numbers from +# occurring +class FloatPow(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + if divisor.is_zero: raise ZeroDivisionError("division by zero") - else: - return base / divisor + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(int(base) / int(divisor)) # TODO: As an indicator, this != 0 implies == 1 (and vice versa).
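`safe_pow` above saturates at `sys.maxsize - 1` instead of materializing arbitrarily large integers, and `PowByNatural.eval` folds literal integer bases and exponents through it. A hedged sketch of both behaviors (the inputs are arbitrary):

```python
# Hedged sketch: PowByNatural folds integer literals via safe_pow, which clamps
# results that would exceed the int64-style bound used for symbolic sizes.
import sys
import sympy
from torch.utils._sympy.functions import PowByNatural, safe_pow

print(PowByNatural(sympy.Integer(3), sympy.Integer(4)))  # 81
print(safe_pow(2, 10))                                   # 1024
print(safe_pow(-2, 3))                                   # -8, sign handled separately
print(safe_pow(2, 200) == sys.maxsize - 1)               # True: saturated
```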
@@ -334,64 +528,133 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function): def eval(cls, *args): assert len(args) % 2 == 0 dim = len(args) // 2 - # TODO: it is possible to make progress evaluating this guard - # even if not all of the inputs are known. For example, a 2D - # tensor with non-0/1 sizes but strides (0, 1) is definitely - # false, because we know its numel > 1 but it's broadcasted - # in dim 0. + sizes = args[0:dim] + strides = args[dim:] + + # sym_node imported in torch.__init__. Local import to avoid an import cycle + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + if all(isinstance(a, sympy.Integer) for a in args): - # sym_node imported in torch.__init__. Local import to avoid an import cycle - from torch.fx.experimental.symbolic_shapes import ( - eval_is_non_overlapping_and_dense, + return eval_is_non_overlapping_and_dense( + [int(a) for a in sizes], [int(a) for a in strides] ) - size_args = args[0:dim] - stride_args = args[dim:] - return eval_is_non_overlapping_and_dense( - [int(a) for a in size_args], [int(a) for a in stride_args] + if dim == 1: + # Manually implement the rank one short circuit + if strides[0].is_Number and strides[0] == 1: + return 1 + + if sizes[0].is_Number and sizes[0] < 2: + return 1 + + # return 0 case covered by case above + + # TODO: Inability to access size-obliviousness sucks: if we have a + # size oblivious test on a size-like unbacked SymInt, we could + # confidently return zero when we have a size-like u0 stride + # and a size-like u1 size. Maybe a fancy ValueRanges analysis for + # this function could help figure this out. + + if all(isinstance(a, sympy.Integer) for a in strides): + assert dim != 0 + # When all strides are integral, we can sort, and the size for the + # largest stride doesn't matter and can be arbitrarily symbolic + s_sizes, s_strides = zip( + *sorted(zip(sizes, strides), key=operator.itemgetter(1)) ) + # Put something arbitrary in the max size spot, it'll be ignored + if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): + s_sizes = s_sizes[:-1] + (42,) + # We can reuse the regular eval, because it is invariant to + # permutation of dimensions + return eval_is_non_overlapping_and_dense( + [int(a) for a in s_sizes], [int(a) for a in s_strides] + ) + return None -class Trunc(sympy.Function): +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -class Round(sympy.Function): +# This is float -> int +class RoundToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): - return 
sympy.Integer(round(float(number))) + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float): + return sympy.Integer(round(float(number), 0)) + - def __int__(self): - # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and - # no longer an expression. If it were, the float call would fail and the caller would handle this further. - return round(float(self.args[0])) # type: ignore[arg-type] +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } +# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, number, ndigits): - if number.is_integer and ndigits >= 0: + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: return number - elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = ( - (int, sympy.Integer) - if isinstance(number, sympy.Integer) - else (float, sympy.Float) - ) - return output_type(round(value_type(number), int(ndigits))) + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 806e91cfe281..640b991cd104 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This is a simple interpreter for Sympy expressions that dispatches to classes following the torch._inductor.virtualized calling convention. @@ -15,16 +16,23 @@ import torch from .functions import ( + CeilToInt, CleanDiv, + FloatPow, + FloatTrueDiv, FloorDiv, + FloorToInt, + IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - Pow, - Round, + PowByNatural, + PythonMod, RoundDecimal, - TrueDiv, - Trunc, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, Where, ) @@ -49,30 +57,39 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - TrueDiv: "truediv", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "div", - Trunc: "trunc", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - Pow: "pow", - sympy.Pow: "pow", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. 
Needs to be cleaned up along with + # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", - sympy.floor: "floor", - sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - Round: "round", - RoundDecimal: "round", + RoundDecimal: "round_decimal", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -84,7 +101,11 @@ def handlers(): def sympy_interp( - analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] + analysis, + env: Dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, ): # Handle base cases dtype = None @@ -105,9 +126,32 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) + if isinstance(expr, ToFloat): + return analysis.to_dtype( + sympy_interp(analysis, env, expr.args[0]), torch.float64 + ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 8bd688b0c0c9..156891ac5497 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,12 +1,26 @@ +# mypy: allow-untyped-defs import math +import operator + import sympy import torch from torch.utils._sympy.functions import ( + _keep_float, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, ) @@ -62,18 +76,41 @@ def not_(a): @staticmethod def reciprocal(x): - return 1 / x + return FloatTrueDiv(1.0, x) @staticmethod def square(x): - return x * x + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod def mod(x, y): - ret = abs(x) % abs(y) - if x < 0: - ret *= 
-1 - return ret + return Mod(x, y) @staticmethod def abs(x): @@ -85,37 +122,31 @@ def neg(x): @staticmethod def truediv(a, b): - return a / b + return FloatTrueDiv(a, b) @staticmethod - def div(a, b): - return ReferenceAnalysis.truediv(a, b) + def int_truediv(a, b): + return IntTrueDiv(a, b) @staticmethod def floordiv(a, b): - if b == 0: - return sympy.nan if a == 0 else sympy.zoo - return a // b + return FloorDiv(a, b) @staticmethod def truncdiv(a, b): - result = a / b - if result.is_finite: - result = sympy.Integer(result) - - return result + raise NotImplementedError("TODO: truncdiv") @staticmethod def add(a, b): - return a + b + return _keep_float(operator.add)(a, b) @staticmethod def mul(a, b): - return a * b + return _keep_float(operator.mul)(a, b) @staticmethod def sub(a, b): - return a - b + return _keep_float(operator.sub)(a, b) @staticmethod def exp(x): @@ -131,39 +162,27 @@ def sqrt(x): @staticmethod def pow(a, b): - return a**b + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) @staticmethod def minimum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Min(result_type(a), result_type(b)) + return sympy.Min(a, b) @staticmethod def maximum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Max(result_type(a), result_type(b)) + return sympy.Max(a, b) @staticmethod - def floor(x): - return sympy.floor(x) + def round_to_int(a, dtype): + return RoundToInt(a) @staticmethod - def ceil(x): - return sympy.ceiling(x) + def round_decimal(a, b): + return RoundDecimal(a, b) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -189,10 +208,20 @@ def not_(a): def floordiv(a, b): return a // b + @staticmethod + def mod(x, y): + return x % y + @staticmethod def truncdiv(a, b): return a / b + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -214,9 +243,40 @@ def maximum(a, b): return torch.sym_max(a, b) @staticmethod - def floor(x): + def floor_to_int(x, dtype): return math.floor(x) @staticmethod - def ceil(x): + def ceil_to_int(x, dtype): return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. 
In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py index 870bda554e74..1b5e8a96104f 100644 --- a/torch/utils/_sympy/singleton_int.py +++ b/torch/utils/_sympy/singleton_int.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sympy from sympy.multipledispatch import dispatch diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 6276c696293c..02ddf7c34219 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,6 +88,7 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/symbol.py b/torch/utils/_sympy/symbol.py index 89908a09e197..bd853faee6d2 100644 --- a/torch/utils/_sympy/symbol.py +++ b/torch/utils/_sympy/symbol.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """ This file contains canonical definitions for our symbol naming conventions, across torch.fx.experimental.symbolic_shapes and torch._inductor. The diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7cc96beb980..97f47c4f28ac 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import dataclasses @@ -5,6 +6,7 @@ import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -25,17 +27,20 @@ from torch._prims_common import dtype_to_type from .functions import ( - OpaqueUnaryFn_acos, - OpaqueUnaryFn_asinh, - OpaqueUnaryFn_atan, - OpaqueUnaryFn_cosh, + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, - OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - OpaqueUnaryFn_tanh, - Round, + PowByNatural, RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, ) from .interp import sympy_interp @@ -83,7 +88,10 @@ def sympy_generic_le(lower, upper): return lower <= upper else: # only negative condition is True > False - assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean) + assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( + lower, + upper, + ) return not (lower and not upper) @@ -120,6 +128,11 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool + is_int: bool + is_float: bool + + def __repr__(self) -> str: + return f"VR[{self.lower}, {self.upper}]" @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -142,8 +155,39 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - assert isinstance(upper, SympyBoolean) == self.is_bool + if self.is_bool: + assert isinstance(upper, SympyBoolean), (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. 
This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + + object.__setattr__( + self, + "is_int", + not self.is_bool + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -184,6 +228,8 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -353,7 +399,12 @@ def constant(value, dtype): # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - return ValueRanges.unknown() + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -369,7 +420,18 @@ def constant(value, dtype): # dtype is intXX assert value.is_integer - return ValueRanges.wrap(value) + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) @staticmethod def not_(a): @@ -428,7 +490,9 @@ def ge(cls, a, b): @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) @classmethod def mul(cls, a, b): @@ -448,11 +512,20 @@ def safe_mul(a, b): else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) - @classmethod - def div(cls, a, b): - return cls.truediv(a, b) + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) @staticmethod def truediv(a, b): @@ -463,18 +536,22 @@ def truediv(a, b): ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) + return ValueRanges.coordinatewise_monotone_map( + a, b, 
_keep_float(FloatTrueDiv) + ) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -523,17 +600,51 @@ def modular_indexing(cls, a, b, c): @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() + return ValueRanges.unknown() # TODO: type here is wrong @classmethod - def pow(cls, a, b): - def is_integer(val): - return isinstance(val, int) or ( - hasattr(val, "is_integer") and val.is_integer + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) + ) + + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + # We could implement all this, but for floating point pow, is there + # really a point? + """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + # Not implemented yet. 
It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -553,8 +664,7 @@ def is_integer(val): if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - type_ = sympy.Float if a.lower.is_real else sympy.Integer - return ValueRanges.wrap(type_(1)) + return ValueRanges.wrap(1.0) if b < 0: a = cls.reciprocal(a) @@ -563,21 +673,12 @@ def is_integer(val): if a == ValueRanges.unknown(): return ValueRanges.unknown() - # Here b > 0 - if not is_integer(b): - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) - else: - return ValueRanges.unknown() + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) else: - # b > 0 integer - if b % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map(a, lambda x: x**b) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: x**b) + return ValueRanges.unknown() + """ @staticmethod def reciprocal(x): @@ -586,7 +687,7 @@ def reciprocal(x): if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) # type: ignore[operator] @staticmethod def abs(x): @@ -615,45 +716,64 @@ def maximum(cls, a, b): def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) - # Performs upcasting first - def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: - # Poorman's version of upcasting in Sympy - # Inf is not a float... - if x.is_Integer and y.is_Integer: - result_type = sympy.Integer - elif x.is_rational and y.is_rational: - result_type = sympy.Rational - else: - assert x.is_real or not x.is_finite or y.is_real or not y.is_finite - result_type = sympy.Float - return fn(result_type(x), result_type(y)) + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) - return ValueRanges.coordinatewise_increasing_map(a, b, fn_) + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits of the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent.
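The representability argument in the comment above can be sanity-checked in isolation. A small illustrative snippet, not part of the patch, with arbitrarily chosen values:

import math

# Below 2**53 every integer is exactly representable as a float, so the
# floor/ceil of a representable float converts back to float losslessly.
for x in (2.5, 2**52 - 0.5, float(2**53 - 1)):
    f, c = math.floor(x), math.ceil(x)
    assert float(f) == f and float(c) == c
    print(x, "->", f, c)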
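Similarly, the even/odd case split used by the new pow_by_natural handler earlier in this file can be mirrored in a standalone sketch (illustrative only, not the ValueRanges implementation): for an even natural exponent, x**n is convex with its minimum at zero, while for an odd exponent it is monotonically increasing, so interval endpoints map directly.

# Interval bound for x**n with a natural exponent n, mirroring the even/odd split.
def pow_interval(lo: int, hi: int, n: int) -> tuple:
    assert n >= 0
    if n % 2 == 1:
        # Odd exponent: x**n is increasing, so endpoints map directly.
        return lo**n, hi**n
    # Even exponent: x**n is convex with its minimum at zero.
    lower = 0 if lo <= 0 <= hi else min(lo**n, hi**n)
    return lower, max(lo**n, hi**n)

print(pow_interval(-3, 2, 2))  # (0, 9)
print(pow_interval(-3, 2, 3))  # (-27, 8)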
@classmethod def floor(cls, x): - return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) @classmethod def ceil(cls, x): return ValueRanges.increasing_map( - x, sympy.functions.elementary.integers.ceiling + x, _keep_float(sympy.functions.elementary.integers.ceiling) ) @classmethod - def round(cls, number, ndigits=None): - if ndigits is None: - fn = Round - else: - assert ndigits.is_singleton() - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + # It's used in some models on symints @staticmethod def sqrt(x): @@ -708,12 +828,15 @@ def cos(x): @staticmethod def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) + """ @staticmethod def sin(x): @@ -723,7 +846,8 @@ def sin(x): @staticmethod def sinh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def tan(x): @@ -731,32 +855,37 @@ def tan(x): @staticmethod def tanh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() + """ @staticmethod def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() + """ @staticmethod def atan(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - def trunc(x): - return sympy.Integer(x) if x.is_finite else x - - return ValueRanges.increasing_map(x, trunc) + return ValueRanges.increasing_map(x, TruncToFloat) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -791,9 +920,10 @@ def store(self, name, index, value, mode=None): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - def index_expr(self, index, dtype): + @classmethod + def index_expr(cls, index, dtype): assert isinstance(index, ValueRanges) - return index + return cls.to_dtype(index, dtype) @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -802,6 +932,8 @@ def to_dtype(x, dtype: 
torch.dtype, src_dtype: Optional[torch.dtype] = None): if dtype == torch.bool: if x.is_singleton(): return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x elif 0 not in x: return ValueRanges.wrap(sympy.true) else: @@ -830,12 +962,15 @@ def cast(x, dtype): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: y * y) + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -856,6 +991,7 @@ def __getattr__(self, name): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr) diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 9f4d04c55105..aa3944d41708 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from types import TracebackType from typing import List, Optional import tempfile diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 9184f782cc73..ff8a5fc73b64 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import hashlib @@ -61,7 +62,9 @@ def triton_hash_with_backend(): backend = triton_backend() key = f"{triton_key()}-{backend.hash()}" - return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # Hash is upper case so that it can't contain any Python keywords. 
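A quick aside on the comment just above, as a standalone illustration with a made-up key: an upper-cased hex digest can never spell a Python keyword, because every keyword contains at least one lowercase letter.

import hashlib
import keyword

key = "example-key"  # hypothetical; the real key combines triton_key() and the backend hash
digest = hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
# No Python keyword is all-uppercase, so none can appear in the digest.
assert not any(kw in digest for kw in keyword.kwlist)
print(digest)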
+ return hashlib.sha256(key.encode("utf-8")).hexdigest().upper() def dtype_to_string(dtype): diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index f37ddb449878..c7dd6445fabe 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import glob import os diff --git a/torch/utils/backcompat/__init__.py b/torch/utils/backcompat/__init__.py index fdd16eec5aca..6a53076c90a6 100644 --- a/torch/utils/backcompat/__init__.py +++ b/torch/utils/backcompat/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import _set_backcompat_broadcast_warn from torch._C import _get_backcompat_broadcast_warn from torch._C import _set_backcompat_keepdim_warn diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 6a4cbcb8436b..6f3444116f3a 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.overrides import ( handle_torch_function, diff --git a/torch/utils/benchmark/examples/blas_compare_setup.py b/torch/utils/benchmark/examples/blas_compare_setup.py index c08acb50950f..323138d19ddd 100644 --- a/torch/utils/benchmark/examples/blas_compare_setup.py +++ b/torch/utils/benchmark/examples/blas_compare_setup.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import collections import os import shutil @@ -183,7 +184,7 @@ def main(): check_run = subprocess.run( # Shameless abuse of `python -c ...` f"source activate {env_path} && " - "python -c \"" + 'python -c "' "import torch;" "from torch.utils.benchmark import Timer;" "print(torch.__config__.show());" diff --git a/torch/utils/benchmark/examples/compare.py b/torch/utils/benchmark/examples/compare.py index 6f99d9d06ad5..5d797a5b0a2b 100644 --- a/torch/utils/benchmark/examples/compare.py +++ b/torch/utils/benchmark/examples/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of Timer and Compare APIs: $ python -m examples.compare diff --git a/torch/utils/benchmark/examples/fuzzer.py b/torch/utils/benchmark/examples/fuzzer.py index 9728bf3d26c9..ee2c9f9c04ed 100644 --- a/torch/utils/benchmark/examples/fuzzer.py +++ b/torch/utils/benchmark/examples/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of the Timer and Fuzzer APIs: $ python -m examples.fuzzer diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py index e2f0861d20ac..cdf3a7853d73 100644 --- a/torch/utils/benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example use of Timer and op fuzzers to measure kernel performance. 
$ python -m examples.op_benchmark diff --git a/torch/utils/benchmark/examples/simple_timeit.py b/torch/utils/benchmark/examples/simple_timeit.py index 81aaa6dee981..390b88f59e70 100644 --- a/torch/utils/benchmark/examples/simple_timeit.py +++ b/torch/utils/benchmark/examples/simple_timeit.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Trivial use of Timer API: $ python -m examples.simple_timeit diff --git a/torch/utils/benchmark/examples/sparse/compare.py b/torch/utils/benchmark/examples/sparse/compare.py index 4adbd6d2b35e..640912e0167e 100644 --- a/torch/utils/benchmark/examples/sparse/compare.py +++ b/torch/utils/benchmark/examples/sparse/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of Timer and Compare APIs: $ python -m examples.sparse.compare diff --git a/torch/utils/benchmark/examples/sparse/fuzzer.py b/torch/utils/benchmark/examples/sparse/fuzzer.py index 38421474ccf8..8f3885839d3f 100644 --- a/torch/utils/benchmark/examples/sparse/fuzzer.py +++ b/torch/utils/benchmark/examples/sparse/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example of the Timer and Sparse Fuzzer APIs: $ python -m examples.sparse.fuzzer diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py b/torch/utils/benchmark/examples/sparse/op_benchmark.py index f998f6d5db47..3efb75e8ea13 100644 --- a/torch/utils/benchmark/examples/sparse/op_benchmark.py +++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Example use of Timer and sparse op fuzzers to measure kernel performance. $ python -m examples.sparse.op_benchmark diff --git a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py index 3ac54059416c..a3c8cbe5b12c 100644 --- a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py +++ b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Microbenchmarks for the torch.fft module""" from argparse import ArgumentParser from collections import namedtuple diff --git a/torch/utils/benchmark/op_fuzzers/binary.py b/torch/utils/benchmark/op_fuzzers/binary.py index 91289d88db8a..75f394179b3e 100644 --- a/torch/utils/benchmark/op_fuzzers/binary.py +++ b/torch/utils/benchmark/op_fuzzers/binary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/torch/utils/benchmark/op_fuzzers/sparse_binary.py index 984493fe4a71..014361877dea 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_binary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_binary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/torch/utils/benchmark/op_fuzzers/sparse_unary.py index 70b5ae3cd3a5..f6fe622183f6 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_unary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_unary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/op_fuzzers/spectral.py b/torch/utils/benchmark/op_fuzzers/spectral.py index 29359ba3edb6..2b9e92d7a2c7 100644 --- a/torch/utils/benchmark/op_fuzzers/spectral.py +++ b/torch/utils/benchmark/op_fuzzers/spectral.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import torch diff --git a/torch/utils/benchmark/op_fuzzers/unary.py b/torch/utils/benchmark/op_fuzzers/unary.py index a0f810d0b9fa..e780b421f24c 100644 --- 
a/torch/utils/benchmark/op_fuzzers/unary.py +++ b/torch/utils/benchmark/op_fuzzers/unary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np import torch diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 20122df66718..36c5a77cd1eb 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Display class to aggregate and print the results of many measurements.""" import collections import enum diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index dcee32ace403..fa8f6b63b437 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch __all__ = ["bench_all", "benchmark_compile"] diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index 7d1ee8ebb8f8..08206efce377 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools import itertools as it from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index eac6a6baf910..5d3cd051e1de 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional, Tuple, Union from numbers import Number import torch diff --git a/torch/utils/bottleneck/__main__.py b/torch/utils/bottleneck/__main__.py index 4444211a0f87..9b23b1483fe0 100644 --- a/torch/utils/bottleneck/__main__.py +++ b/torch/utils/bottleneck/__main__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import argparse import cProfile import pstats diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index 201a000b3006..21fa4e50396d 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence, Dict, Callable import textwrap import torch diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index a98c9b2059b8..5cbfd1543cf4 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import contextlib import platform import uuid @@ -1395,7 +1396,7 @@ def recompute_fn(*inputs): device_autocast_ctx = torch.amp.autocast( device_type=device, **device_autocast_kwargs ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext() - with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] + with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] fn(*args, **kwargs) new_frame = _CheckpointFrame( diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 6cbf598156b0..ed0e02c4c1b9 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Unlike the rest of the PyTorch this file must be python2 compliant. # This script outputs relevant system environment info @@ -434,6 +435,21 @@ def is_xnnpack_available(): return "N/A" def get_env_info(): + """ + Collects environment information to aid in debugging. 
+ + The returned environment information contains details on torch version, whether it is a debug build + or not, cuda compiled version, gcc version, clang version, cmake version, operating + system, libc version, python version, python platform, CUDA availability, CUDA + runtime version, CUDA module loading config, GPU model and configuration, Nvidia + driver version, cuDNN version, pip version and versions of relevant pip and + conda packages, HIP runtime version, MIOpen runtime version, + Caching allocator config, XNNPACK availability and CPU information. + + Returns: + SystemEnv (namedtuple): A tuple containing various environment details + and system information. + """ run_lambda = run pip_version, pip_list_output = get_pip_packages(run_lambda) @@ -599,6 +615,17 @@ def maybe_start_on_next_line(string): def get_pretty_env_info(): + """ + Returns a pretty string of environment information. + + This function retrieves environment information by calling the `get_env_info` function + and then formats the information into a human-readable string. The retrieved environment + information is listed in the docstring of `get_env_info`. + This function is used by `python collect_env.py`, which should be executed when reporting a bug. + + Returns: + str: A pretty string of the environment information. + """ return pretty_str(get_env_info()) diff --git a/torch/utils/cpp_backtrace.py b/torch/utils/cpp_backtrace.py index 40dbbb5b913a..af4a7fcb63e2 100644 --- a/torch/utils/cpp_backtrace.py +++ b/torch/utils/cpp_backtrace.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch._C import _get_cpp_backtrace def get_cpp_backtrace(frames_to_skip=0, maximum_number_of_frames=64) -> str: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 913947ea84c7..1904f8c3ecae 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import glob import importlib diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py index 62cfdf91f1ea..7c2b452c15cb 100644 --- a/torch/utils/data/_utils/__init__.py +++ b/torch/utils/data/_utils/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. A lot of multiprocessing is used in data loading, which only supports running diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 4c17597bd6f1..1f705c09f0f4 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These methods are used to collate samples fetched from dataset into Tensor(s). diff --git a/torch/utils/data/_utils/fetch.py b/torch/utils/data/_utils/fetch.py index 553c516ff3ce..3fa6c49404f6 100644 --- a/torch/utils/data/_utils/fetch.py +++ b/torch/utils/data/_utils/fetch.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset. This logic is shared in both single- and multi-processing data loading.
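As a usage note on the collect_env docstrings added above, the report they describe can be printed directly; this is effectively what `python -m torch.utils.collect_env` does when filing a bug report:

# Print the environment report described in the new docstrings.
from torch.utils.collect_env import get_pretty_env_info

print(get_pretty_env_info())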
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index 9de645cd7ee7..ecb7f8875f23 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. These **needs** to be in global scope since Py2 doesn't support serializing diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py index da8f3780bed2..6f0219e91c27 100644 --- a/torch/utils/data/_utils/signal_handling.py +++ b/torch/utils/data/_utils/signal_handling.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Signal handling for multiprocessing data loading. NOTE [ Signal handling in multiprocessing data loading ] diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 137791c4c436..849f4b9300fe 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing diff --git a/torch/utils/data/backward_compatibility.py b/torch/utils/data/backward_compatibility.py index f51418265f41..e8f1c4e30ef7 100644 --- a/torch/utils/data/backward_compatibility.py +++ b/torch/utils/data/backward_compatibility.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing_extensions import deprecated as _deprecated diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 80784f2ec362..9ad0db898a04 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. 
To support these two classes, in `./_utils` we define many utility methods and diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 93ef42076c21..9c5b25d7f22d 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect from functools import wraps from typing import Any, Callable, Optional, Type, Union, get_type_hints diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index 49e17438d60e..00b44cbede61 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import functools from enum import Enum diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 08d54bfb31ad..f3fe402690b6 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # Taking reference from official Python typing # https://github.com/python/cpython/blob/master/Lib/typing.py diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py index 9a03a8f00efc..67c5b5408b50 100644 --- a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Optional _pandas: Any = None diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index a93ea6ba2d82..677104538b23 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Any, Dict, List, Optional from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index a75cc5c7a7c2..de0bb8246fb5 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/dataframe/structures.py b/torch/utils/data/datapipes/dataframe/structures.py index 507a04e491d3..ad5f6f6d588e 100644 --- a/torch/utils/data/datapipes/dataframe/structures.py +++ b/torch/utils/data/datapipes/dataframe/structures.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes.datapipe import DataChunk from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 6b3cbe34b46a..4d03665d5d66 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection # The output file will be "datapipe.pyi". 
This is executed as part of torch/CMakeLists.txt # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index c0f8a801bd07..e2b3ad966a21 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import pathlib from collections import defaultdict @@ -44,10 +45,10 @@ def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str def extract_method_name(line: str) -> str: """Extract method name from decorator in the form of "@functional_datapipe({method_name})".""" - if "(\"" in line: - start_token, end_token = "(\"", "\")" - elif "(\'" in line: - start_token, end_token = "(\'", "\')" + if '("' in line: + start_token, end_token = '("', '")' + elif "('" in line: + start_token, end_token = "('", "')" else: raise RuntimeError(f"Unable to find appropriate method name within line:\n{line}") start, end = line.find(start_token) + len(start_token), line.find(end_token) @@ -71,9 +72,9 @@ def parse_datapipe_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str], method_name, class_name, signature = "", "", "" skip = False for line in f: - if line.count("\"\"\"") % 2 == 1: + if line.count('"""') % 2 == 1: skip = not skip - if skip or "\"\"\"" in line: # Saving docstrings + if skip or '"""' in line: # Saving docstrings doc_string_dict[method_name].append(line) continue if "@functional_datapipe" in line: diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 9a67cc0592ff..f29c96e886e6 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import functools from collections import namedtuple diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 16d2f5444dcd..b86b28f9d7e1 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random import torch diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 9a4365516a33..878d885c2042 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod diff --git a/torch/utils/data/datapipes/iter/filelister.py b/torch/utils/data/datapipes/iter/filelister.py index bb10fe4c4965..7384a3a26cb8 100644 --- a/torch/utils/data/datapipes/iter/filelister.py +++ b/torch/utils/data/datapipes/iter/filelister.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterator, List, Sequence, Union diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 67e9797fe335..b58ee14a4378 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from io import IOBase from typing import Iterable, Tuple, Optional diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index c11804ea2cc0..31aa90af5451 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,3 +1,4 
@@ +# mypy: allow-untyped-defs import warnings from collections import defaultdict from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index fee74582e61b..5910ab0da2ec 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Callable, Iterator, Tuple, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/iter/sharding.py b/torch/utils/data/datapipes/iter/sharding.py index f5bd3261fc1b..f493af685fb4 100644 --- a/torch/utils/data/datapipes/iter/sharding.py +++ b/torch/utils/data/datapipes/iter/sharding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import ( Dict, Sized, diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py index 9fd80e94e509..4e379db92bc5 100644 --- a/torch/utils/data/datapipes/iter/streamreader.py +++ b/torch/utils/data/datapipes/iter/streamreader.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Tuple from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py index 3794f7f0e778..096188b1369e 100644 --- a/torch/utils/data/datapipes/iter/utils.py +++ b/torch/utils/data/datapipes/iter/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import warnings from torch.utils.data.datapipes.datapipe import IterDataPipe diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index c9202bb1eefb..9ddd51ba9bb1 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from typing import Callable, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe diff --git a/torch/utils/data/datapipes/map/combinatorics.py b/torch/utils/data/datapipes/map/combinatorics.py index c21d532d4925..7b435ce7c130 100644 --- a/torch/utils/data/datapipes/map/combinatorics.py +++ b/torch/utils/data/datapipes/map/combinatorics.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import random import torch diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 809b44dc96cd..731418239ba0 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe from typing import Sized, Tuple, TypeVar diff --git a/torch/utils/data/datapipes/map/grouping.py b/torch/utils/data/datapipes/map/grouping.py index a94cc7b5679e..d5d216158acd 100644 --- a/torch/utils/data/datapipes/map/grouping.py +++ b/torch/utils/data/datapipes/map/grouping.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk from typing import List, Sized, TypeVar diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 
18d4fd18a193..d22e708c1538 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import copy import warnings from torch.utils.data.datapipes.datapipe import MapDataPipe diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 3c466d3392ad..3e8e99c4b154 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import fnmatch import functools import inspect diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index 0211a8fe4ba4..b465f3a0aaa6 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs # This file takes partial of the implementation from NVIDIA's webdataset at here: # https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py @@ -28,7 +29,34 @@ ################################################################ # handle basic datatypes ################################################################ -def basichandlers(extension, data): +def basichandlers(extension: str, data): + """Transforms raw data (byte stream) into python objects. + + Looks at the extension and loads the data into a python object supporting + the corresponding extension. + + Args: + extension (str): The file extension + data (byte stream): Data to load into a python object. + + Returns: + object: The data loaded into a corresponding python object + supporting the extension. + + Example: + >>> import pickle + >>> data = pickle.dumps('some data') + >>> new_data = basichandlers('pickle', data) + >>> new_data + some data + + The transformation of data for extensions are: + - txt, text, transcript: utf-8 decoded data of str format + - cls, cls2, class, count, index, inx, id: int + - json, jsn: json loaded data + - pickle, pyd: pickle loaded data + - pt: torch loaded data + """ if extension in "txt text transcript": return data.decode("utf-8") diff --git a/torch/utils/data/datapipes/utils/snapshot.py b/torch/utils/data/datapipes/utils/snapshot.py index 02487d0da573..8b2266d15d62 100644 --- a/torch/utils/data/datapipes/utils/snapshot.py +++ b/torch/utils/data/datapipes/utils/snapshot.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from torch.utils.data.datapipes._hook_iterator import _SnapshotState from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.graph_settings import apply_random_seed diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index b3cf9d92943d..6ce4b67bfb06 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import bisect import itertools import math diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index cd78db474d5e..d3a882e58595 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import io import pickle import warnings diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 573069279201..f9de29df288e 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import inspect import warnings diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 4c4c967ef9a9..c6ad6933fb49 100644 --- 
a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -1,5 +1,5 @@ +# mypy: allow-untyped-defs import torch -from torch import Tensor from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union @@ -212,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]): [0, 1, 4, 3, 2] """ - weights: Tensor + weights: torch.Tensor num_samples: int replacement: bool diff --git a/torch/utils/deterministic.py b/torch/utils/deterministic.py index 98a6d30b067b..a055c43be531 100644 --- a/torch/utils/deterministic.py +++ b/torch/utils/deterministic.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import sys import types diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index b55db82b8532..77ee5091b3f7 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import os import time diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index d7080c9e4e38..a4f05c6c720b 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from .module_tracker import ModuleTracker @@ -242,7 +243,9 @@ def sdpa_flop_count(query_shape, key_shape, value_shape): return total_flops -@register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention]) +@register_flop_formula([aten._scaled_dot_product_efficient_attention, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_cudnn_attention]) def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention.""" # NB: We aren't accounting for causal attention here @@ -434,7 +437,9 @@ def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape return total_flops -@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward]) +@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, + aten._scaled_dot_product_flash_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward]) def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: """Count flops for self-attention backward.""" return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) @@ -515,8 +520,10 @@ def _efficient_attention_backward_flop( aten.convolution_backward: conv_backward_flop, aten._scaled_dot_product_efficient_attention: sdpa_flop, aten._scaled_dot_product_flash_attention: sdpa_flop, + aten._scaled_dot_product_cudnn_attention: sdpa_flop, aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop, aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop, + aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop, aten._flash_attention_forward: _flash_attention_forward_flop, aten._efficient_attention_forward: _efficient_attention_forward_flop, aten._flash_attention_backward: _flash_attention_backward_flop, diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 39e7070144aa..755a50404055 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs """ The Python Hipify script. ## # Copyright (c) 2015-2016 Advanced Micro Devices, Inc. 
All rights reserved. @@ -973,7 +974,7 @@ def repl(m): hipify_result.current_state = CurrentState.DONE return hipify_result except PermissionError as e: - print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}", + print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}', file=sys.stderr) hipify_result.hipified_path = fin_path hipify_result.status = "[skipped, no permissions]" diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index f70a43ad6857..ee828034bdf6 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch from collections import OrderedDict import weakref diff --git a/torch/utils/jit/log_extract.py b/torch/utils/jit/log_extract.py index 2e89a769eff0..51894f495e8e 100644 --- a/torch/utils/jit/log_extract.py +++ b/torch/utils/jit/log_extract.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from contextlib import contextmanager from typing import Any, List, Tuple, cast import random diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py index 2d1d8cd89ff5..06ca96d2de9a 100644 --- a/torch/utils/mkldnn.py +++ b/torch/utils/mkldnn.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index 038572806f41..6d2230da8ae1 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module contains utility method for mobile model optimization and lint.""" import torch diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index a8d491ed6b3a..7e2bc36d2e71 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs """ model_dump: a one-stop shop for TorchScript model inspection. 
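Stepping back to the flop_counter change above, which registers aten._scaled_dot_product_cudnn_attention and its backward with the existing SDPA formulas: a minimal sketch of exercising the counter is below. The shapes are arbitrary, and the counter reports the flops of whatever attention path actually dispatches.

import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

# Arbitrary (batch, heads, seq, head_dim) tensors just to drive the counter.
q = k = v = torch.randn(2, 4, 128, 64)
with FlopCounterMode(display=False) as counter:
    F.scaled_dot_product_attention(q, k, v)
print(counter.get_total_flops())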
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index f2d83fb36f92..9feef40ca4da 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import weakref from typing import Set diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py index 24ea1eb4e1e9..66549fac2673 100644 --- a/torch/utils/show_pickle.py +++ b/torch/utils/show_pickle.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: allow-untyped-defs import sys import pickle import struct diff --git a/torch/utils/tensorboard/_convert_np.py b/torch/utils/tensorboard/_convert_np.py index 9368464c2491..80a3c684579d 100644 --- a/torch/utils/tensorboard/_convert_np.py +++ b/torch/utils/tensorboard/_convert_np.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """This module converts objects into numpy array.""" import numpy as np diff --git a/torch/utils/tensorboard/_embedding.py b/torch/utils/tensorboard/_embedding.py index afbe68191aa9..44cb6c41b017 100644 --- a/torch/utils/tensorboard/_embedding.py +++ b/torch/utils/tensorboard/_embedding.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import math import numpy as np from ._convert_np import make_np diff --git a/torch/utils/tensorboard/_onnx_graph.py b/torch/utils/tensorboard/_onnx_graph.py index 5c923fcb0ee5..c744ca8719f3 100644 --- a/torch/utils/tensorboard/_onnx_graph.py +++ b/torch/utils/tensorboard/_onnx_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from tensorboard.compat.proto.graph_pb2 import GraphDef from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.versions_pb2 import VersionDef diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index 3c0d15723d24..30140a22cff6 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Optional from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.attr_value_pb2 import AttrValue diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index f4274199ffd3..d3d2f37cad74 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from collections import OrderedDict import contextlib from typing import Dict, Any diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index f79f59749f53..30984cfadf17 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import numpy as np diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 4d94c3e6158b..55a74f3f8771 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import json import logging import os diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index c646ce0c0c11..cdc4c565734a 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs """Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization.""" import os diff --git a/torch/utils/throughput_benchmark.py b/torch/utils/throughput_benchmark.py index 5607fadee9e9..2778b37b5a78 100644 --- a/torch/utils/throughput_benchmark.py +++ 
b/torch/utils/throughput_benchmark.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import torch._C diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index f17348e401c3..8c1b9da7a6ad 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import gc import sys from typing import Any, Dict, List, NamedTuple, Optional, Tuple diff --git a/torch/utils/weak.py b/torch/utils/weak.py index a5e33a34d7aa..cc272a7f2637 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from __future__ import annotations import weakref diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 3e7f43b87d4a..6049a11861d2 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs r""" This package introduces support for the XPU backend, specifically tailored for Intel GPU optimization. diff --git a/torch/xpu/random.py b/torch/xpu/random.py index 733c55b658cd..1ebdd476ed8c 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs from typing import Iterable, List, Union import torch diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index f4e35a376e7c..19a7cda162f4 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -1,3 +1,4 @@ +# mypy: allow-untyped-defs import ctypes import torch diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 1a55211b9990..10b011741d55 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -321,6 +321,17 @@ def is_foreach_func(f: NativeFunction) -> bool: "_foreach_mul.Tensor", "_foreach_div.Tensor", } +# The following do not support the alpha kwarg, which the nonforeach versions support. 
+_skip_argument_len_check = { + "_foreach_add.Scalar", + "_foreach_add_.Scalar", + "_foreach_add.ScalarList", + "_foreach_add_.ScalarList", + "_foreach_sub.Scalar", + "_foreach_sub_.Scalar", + "_foreach_sub.ScalarList", + "_foreach_sub_.ScalarList", +} # Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function @@ -335,6 +346,11 @@ def is_reference_for_foreach( not function_schema.name.name.inplace or str(f.func.name) in _foreach_with_inplace_ref ) + and ( + str(f.func.name) in _skip_argument_len_check + or len(f.func.arguments.flat_non_out) + == len(function_schema.arguments.flat_non_out) + ) and all( ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) for arg, ref_arg in zip( diff --git a/torchgen/gen.py b/torchgen/gen.py index d715361146ea..a1c1a8f957f3 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -165,9 +165,11 @@ def parse_native_yaml_struct( rs: List[NativeFunction] = [] bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) for e in es: + assert isinstance(e, dict), f"expected to be dict: {e}" assert isinstance(e.get("__line__"), int), e loc = Location(path, e["__line__"]) funcs = e.get("func") + assert funcs is not None, f"missed 'func' in {e}" with context(lambda: f"in {loc}:\n {funcs}"): func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys) rs.append(func) @@ -268,7 +270,11 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: base_func_map[f.func.name.name].append(f) for f in funcs: if f.structured_delegate is not None: - delegate_func = func_map[f.structured_delegate] + delegate_func = func_map.get(f.structured_delegate) + assert delegate_func is not None, ( + f"{f.func.name} is marked as a structured_delegate pointing to " + f"{f.structured_delegate}, but {f.structured_delegate} is missing." + ) assert delegate_func.structured, ( f"{f.func.name} is marked as a structured_delegate pointing to " f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. 
" diff --git a/torchgen/model.py b/torchgen/model.py index 2706f234c56b..bed8f262f592 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -626,6 +626,9 @@ def from_yaml( assert device_check_s is None or isinstance( device_check_s, str ), f"not a str: {device_check_s}" + assert ( + device_check_s is None or device_check_s in DeviceCheckType.__members__ + ), f"illegal device_check: {device_check_s}" device_check: DeviceCheckType if device_check_s is None: device_check = DeviceCheckType.ExactSame @@ -706,7 +709,12 @@ def from_yaml( for ks, v in raw_dispatch.items(): if ks == "__line__": continue # not worth tracking line numbers for dispatch entries - assert isinstance(ks, str), e + assert isinstance( + ks, str + ), f"illegal dispatch key '{ks}' in {raw_dispatch}" + assert isinstance( + v, str + ), f"illegal dispatch value '{v}' in {raw_dispatch}" for k in ks.split(","): dispatch_key = DispatchKey.parse(k.strip()) num_dispatch_keys += 1 @@ -2006,8 +2014,12 @@ def alias_info(self) -> Optional[Annotation]: def parse(arg: str) -> "Argument": name: str default: Optional[str] + assert " " in arg, f"illegal argument '{arg}'" type_and_annot, name_and_default = arg.rsplit(" ", 1) if "=" in name_and_default: + assert ( + name_and_default.count("=") == 1 + ), f"illegal argument with default value: '{name_and_default}'" name, default = name_and_default.split("=") else: name = name_and_default @@ -2792,6 +2804,9 @@ def parse(src: object) -> "Precompute": ) arg, with_list_raw = raw_replace_item.split(" -> ") + assert ( + " " not in arg + ), f"illegal kernel param name '{arg}' in precomputed parameters'" with_list = with_list_raw.split(",") with_list_args = [Argument.parse(name.strip()) for name in with_list] replace[arg] = with_list_args diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index 407165147e35..da6e2a21c2a3 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -383,6 +383,6 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N return if op_name in ("diagonal", "linalg_diagonal"): arg_map["offset"] = "0" - arg_map["dim0"] = "1" arg_map["dim1"] = "2" + arg_map["dim2"] = "1" return diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index e709450b48d3..b068af7728aa 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -222,6 +222,17 @@ def has_alias( "special_spherical_bessel_j0", "_foobar", "_nested_tensor_strides", + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched ) )