diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index aa63625b9896..e0528274b9b2 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -2,9 +2,9 @@ name: Build and Test on: pull_request: - branches: [ main ] + branches: [ feature/misc_fixes ] push: - branches: [ main ] + branches: [ feature/misc_fixes ] workflow_dispatch: # Ensure that only a single job or workflow using the same @@ -25,9 +25,9 @@ jobs: strategy: fail-fast: true matrix: - os-arch: [ubuntu-x86_64, macos-arm64, windows-x86_64] - llvm-build: [in-tree, out-of-tree] - torch-binary: [ON, OFF] + os-arch: [ubuntu-x86_64] # macos-arm64, windows-x86_64 + llvm-build: [in-tree] # out-of-tree + torch-binary: [ON] # OFF torch-version: [nightly, stable] exclude: # Exclude llvm in-tree and pytorch source @@ -51,11 +51,11 @@ jobs: include: # Specify OS versions - os-arch: ubuntu-x86_64 - os: a100 - - os-arch: macos-arm64 - os: macos-latest - - os-arch: windows-x86_64 - os: windows-latest + os: ubuntu-latest # a100 + #- os-arch: macos-arm64 + # os: macos-latest + #- os-arch: windows-x86_64 + # os: windows-latest runs-on: ${{ matrix.os }} steps: @@ -71,7 +71,6 @@ jobs: uses: actions/checkout@v3 with: submodules: 'true' - fetch-depth: 0 - name: Fetch PyTorch commit hash if: ${{ matrix.os-arch != 'windows-x86_64' }} diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 1af748879e43..a8f95ef91415 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -13,11 +13,16 @@ on: jobs: build_linux: name: Manylinux x86_64 Build - runs-on: a100 + runs-on: ubuntu-latest + permissions: + contents: write + actions: write + packages: write strategy: matrix: - package: [ torch-mlir, torch-mlir-core ] - py_version: [ cp38-cp38, cp311-cp311 ] + package: [ torch-mlir ] + py_version: [ cp38-cp38, cp310-cp310 ] # cp311-cp311 + torch-version: [stable] # nightly exclude: - package: torch-mlir-core py_version: cp38-cp38 @@ -36,7 +41,6 @@ jobs: uses: actions/checkout@v3 with: submodules: 'true' - fetch-depth: 0 - uses: ./.github/actions/setup-build with: @@ -46,7 +50,11 @@ jobs: cd $GITHUB_WORKSPACE TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh + TM_SKIP_TESTS=ON \ + TM_PYTHON_VERSIONS=${{ matrix.py_version }} \ + TM_PACKAGES=${{ matrix.package }} \ + TM_TORCH_VERSION="${{ matrix.torch-version }}" \ + ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. @@ -55,7 +63,7 @@ jobs: id: upload-release-assets uses: dwenegar/upload-release-assets@v1 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl @@ -66,7 +74,7 @@ jobs: id: publish_release uses: eregon/publish-release@v1 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} - name: Create dist directory @@ -86,6 +94,7 @@ jobs: path: dist build_linux_arm64: + if: false name: Manylinux arm64 Build runs-on: linux-arm64 strategy: @@ -155,6 +164,7 @@ jobs: path: dist build_macos: + if: false name: MacOS Build runs-on: macos-latest strategy: @@ -215,6 +225,7 @@ jobs: path: dist build_windows: + if: false name: Windows Build runs-on: windows-latest strategy: @@ -286,11 +297,15 @@ jobs: publish_releases: runs-on: ubuntu-latest + permissions: + contents: write + actions: write + packages: write needs: - build_linux - - build_linux_arm64 - - build_macos - - build_windows + #- build_linux_arm64 + #- build_macos + #- build_windows # Publish even if one of the builds failed if: ${{ always() }} @@ -300,7 +315,7 @@ jobs: uses: benc-uk/workflow-dispatch@v1 with: workflow: Publish releases page - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + token: ${{ secrets.GITHUB_TOKEN }} # Wheels must be published from a linux environment. # diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index c6df475cca4d..5ee7047c5d8d 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -8,9 +8,11 @@ jobs: scrape_and_publish_releases: name: "Scrape and publish releases" runs-on: ubuntu-latest + permissions: + contents: write # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + if: github.repository == 'xilinx/torch-mlir' steps: - name: Prepare workspace @@ -20,10 +22,8 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Run scrape releases script - run: python ./build_tools/scrape_releases.py llvm torch-mlir > /tmp/index.html + run: python ./build_tools/scrape_releases.py xilinx torch-mlir > /tmp/index.html shell: bash - run: git fetch --all - run: git switch github-pages diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index 46832ce9c667..bec2e21282f0 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -8,7 +8,7 @@ jobs: name: "Tag snapshot release" runs-on: ubuntu-latest # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + #if: github.repository == 'llvm/torch-mlir' steps: - name: Prepare workspace run: | @@ -16,10 +16,11 @@ jobs: # existing lock files. sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checking out repository + - name: Checkout torch-mlir uses: actions/checkout@v3 with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + submodules: 'true' + fetch-depth: 0 - name: Compute version run: | diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index c18eff88d32f..0bf45adad584 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -2,7 +2,7 @@ name: Release snapshot package on: schedule: - - cron: '0 11 * * *' + - cron: '17 4 * * *' workflow_dispatch: @@ -11,7 +11,12 @@ jobs: name: "Tag snapshot release" runs-on: ubuntu-latest # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + #if: github.repository == 'llvm/torch-mlir' + permissions: + contents: write + actions: write + env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} steps: - name: Prepare workspace @@ -22,8 +27,6 @@ jobs: - name: Checking out repository uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Compute version run: | @@ -40,15 +43,15 @@ jobs: - name: Pushing changes uses: ad-m/github-push-action@v0.6.0 with: - github_token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - branch: main + github_token: ${{ secrets.GITHUB_TOKEN }} + branch: ${{ env.BRANCH_NAME }} tags: true - name: Create Release id: create_release uses: actions/create-release@v1 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ env.tag_name }} release_name: torch-mlir snapshot ${{ env.tag_name }} @@ -57,17 +60,15 @@ jobs: draft: true prerelease: false - - name: "Invoke workflow :: Build and Test" - uses: benc-uk/workflow-dispatch@v1 - with: - workflow: Build and Test - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - ref: "${{ env.tag_name }}" + # - name: "Invoke workflow :: Build and Test" + # uses: benc-uk/workflow-dispatch@v1 + # with: + # workflow: Build and Test + # ref: "${{ env.tag_name }}" - name: "Invoke workflow :: Release Build" uses: benc-uk/workflow-dispatch@v1 with: workflow: Release Build - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} ref: "${{ env.tag_name }}" inputs: '{"release_id": "${{ steps.create_release.outputs.id }}", "python_package_version": "${{ env.package_version }}"}' diff --git a/.gitmodules b/.gitmodules index 81c66a441907..5b0f4e7479eb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "externals/llvm-project"] path = externals/llvm-project - url = https://github.com/llvm/llvm-project.git + url = https://github.com/Xilinx/llvm-project.git + branch = misc_fixes [submodule "externals/mlir-hlo"] path = externals/mlir-hlo url = https://github.com/tensorflow/mlir-hlo.git diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index bd9c0cee61e5..2d5d38568cf6 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -116,9 +116,9 @@ function run_on_host() { docker run --rm \ -v "${repo_root}:/main_checkout/torch-mlir" \ -v "${TM_OUTPUT_DIR}:/wheelhouse" \ - -v "${HOME}:/home/${USER}" \ + -v "${PWD}:$PWD" \ --user ${USERID}:${GROUPID} \ - --workdir="/home/$USER" \ + --workdir="$PWD" \ --volume="/etc/group:/etc/group:ro" \ --volume="/etc/passwd:/etc/passwd:ro" \ --volume="/etc/shadow:/etc/shadow:ro" \ @@ -275,7 +275,6 @@ function test_in_tree() { cd /main_checkout/torch-mlir/ export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir" - case $torch_version in nightly) echo ":::: Test with nightly torch" @@ -300,7 +299,10 @@ function test_in_tree() { exit 1 ;; esac - + + echo ":::: Run make_fx + TOSA e2e integration tests" + python -m e2e_testing.main --config=make_fx_tosa -v + echo ":::: Run TorchDynamo e2e integration tests" python -m e2e_testing.main --config=torchdynamo -v @@ -422,7 +424,7 @@ function build_torch_mlir() { ;; stable) echo ":::: Using stable dependencies" - python3 -m pip install --no-cache-dir torch torchvision + python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 13e7ba7c892d..3893edee4765 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -29,6 +29,7 @@ from .xfail_sets import ( LINALG_XFAIL_SET, + MAKE_FX_TOSA_PASS_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, @@ -42,7 +43,7 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -94,6 +95,10 @@ def main(): config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET crashing_set = set() + elif args.config == "make_fx_tosa": + config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True) + xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET + crashing_set = set() elif args.config == "stablehlo": config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) xfail_set = all_test_unique_names - STABLEHLO_PASS_SET diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index c077657a6047..7ace41ffead3 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -11,8 +11,17 @@ # might be used to keep more elaborate sets of testing configurations). from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS +from torch_mlir._version import torch_version_for_comparison, version -LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS +LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { + "Conv1dNoPaddingModule_basic", + "Conv1dNoPaddingTransposeModule_basic", + "Conv1dNoPaddingGroupModule_basic", + "RepeatInterleaveStaticModule_basic", + "RepeatInterleaveFillModule_basic", + # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic" +} TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors @@ -267,7 +276,23 @@ "ScatterValueFloatModule_basic", # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} "ScatterValueIntModule_basic", + # ERROR: Unsupported: dynamic shape operator: aten.repeat_interleave.Tensor + "RepeatInterleaveModule_basic", + "RepeatInterleaveFillModule_basic", + + # failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal + "Conv1dNoPaddingModule_basic", + "Conv1dNoPaddingTransposeModule_basic", + "Conv1dNoPaddingGroupModule_basic", + + # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0 + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + + # failed to legalize operation 'torch.aten.clamp' that was explicitly marked illegal + "ElementwiseClampIntModule_basic", + # failed to legalize operation 'torch.constant.int' + "RepeatInterleaveStaticModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -372,9 +397,11 @@ "BatchNorm3DModule_basic", "BatchNorm1DStaticShapeModule_basic", "ResNet18StaticModule_basic", + "AtenToDtypeModule_basic", "BmmModule_basic", "BroadcastToModule_basic", "BroadcastToSameRankStaticModule_basic", + "BroadcastToDifferentRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", "BroadcastListConstructWithMinusOneModule_basic", "BucketizeTensorStaticFloatModule_basic", @@ -399,6 +426,7 @@ "ElementwiseClampModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampMaxModule_basic", + "ElementwiseClampIntModule_basic", "ElementwiseSignModule_basic", "ElementwisePowModule_basic", "ElementwisePowTensorStaticModule_basic", @@ -430,6 +458,7 @@ "ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqFloatScalarModule_basic", "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", "ElementwiseNeFloatScalarModule_basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", @@ -506,7 +535,9 @@ "IndexSelectWholeDimensionModule_basic", "IndexSelectWholeTensorModule_basic", "IndexSelectNegativeDimModule_basic", + "IndexSelectStaticModule_basic", "IndexTensorStaticModule_basic", + "IndexTensorModule3dInputStatic_basic", "IndexTensorMultiIndexStaticModule_basic", "LayerNormLastDimModule_basic", "LayerNormModule_basic", @@ -567,6 +598,7 @@ "TensorsConcatPromoteDTypeModule_basic", "TensorsConcatStaticModule_basic", "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeStaticModule_basic", "TensorsStackModule_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", @@ -590,9 +622,12 @@ "NumToTensorFloatModule_basic", "AtenToDeviceModule_basic", "AvgPool2dStaticModule_basic", + "Conv1dNoPaddingModule_basic", + "Conv1dNoPaddingGroupModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "Convolution2DStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "Convolution2DGroupsStatic_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneModule_basic", @@ -622,14 +657,18 @@ "SliceModule_basic", "SliceNegIdxModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", "SliceSizeTwoStepModule_basic", + "SliceSizeTwoStepDivisibleStaticModule_basic", "SliceWholeTensorModule_basic", "SliceScatterModule_basic", "SliceScatterNegativeDimModule_basic", "SliceScatterNegativeEndModule_basic", "SliceScatterStaticModule_basic", + "SliceEndSleStartStaticModule_basic", "SliceScatterStepVariationModule_basic", "SliceScatterZeroDimModule_basic", "SqueezeDimModule_static", @@ -651,6 +690,7 @@ "EmptyModule_falsePinMemory", "EmptyModule_int", "EmptyModule_float", + "NewEmptyModuleBool_basic", "NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleFalsePinMemory_basic", "NewEmptyModuleFloat2D_basic", @@ -720,6 +760,7 @@ "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", + "PrimsSumFloatModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", "ReduceSumFloatModule_basic", @@ -745,6 +786,9 @@ "NumpyTRank2Module_basic", "NumpyTRankNStaticModule_basic", "NumpyTRankNDynamicModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", "TModuleRank2_basic", "TensorLiteralModule_basic", "TensorsConcatModule_basic", @@ -818,6 +862,15 @@ "ElementwiseMinimumIntModule_basic", "ElementwiseMaximumModule_basic", "ElementwiseMaximumIntModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseAcosTensorFloatModule_basic", + "ElementwiseAsinTensorFloatModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampIntModule_basic", "ViewDoubleMergeStaticModule_basic", "ViewCollapseOnesMiddleModule_basic", "ViewFiveTestStaticModule_basic", @@ -854,6 +907,12 @@ "ReturnTwoTensorF32I64_basic", "ElementwiseSignModule_basic", "ElementwisePowModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorModule_basic", + "ElementwisePowTensorStaticModule_basic", + "AtenToDtypeModule_basic", "BmmModule_basic", "MmDagModule_basic", "Matmul4dStatic_basic", @@ -879,6 +938,10 @@ "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseGtFloatTensorModule_basic", "ElementwiseGtIntTensorModule_basic", + "ElementwiseLeFloatIntScalarModule_basic", + "ElementwiseLeFloatScalarModule_basic", + "ElementwiseLeIntScalarModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", "ElementwiseLtFloatScalarModule_basic", "ElementwiseLtIntScalarModule_basic", "ElementwiseLtDiffWidthScalarModule_basic", @@ -886,6 +949,7 @@ "ElementwiseLtIntTensorModule_basic", "ElementwiseEqFloatScalarModule_basic", "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqFloatTensorModule_basic", "ElementwiseEqIntTensorModule_basic", @@ -907,6 +971,8 @@ "ElementwiseReciprocalModule_basic", "ElementwiseIsnanModule_basic", "TypePromotionAlphaWiderModule_basic", + "Conv1dNoPaddingModule_basic", + "Conv1dNoPaddingGroupModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "BatchNorm1DModule_basic", "BatchNorm1DWith2DInputModule_basic", @@ -923,6 +989,7 @@ "ReduceAmaxKeepDim_basic", "NativeLayerNormModule4D_basic", "LayerNormNormalizeOverAllDimsModule_basic", + "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", "ElementwiseLog2Module_basic", @@ -973,10 +1040,12 @@ "ViewNoChangeStaticModule_basic", "UnsafeViewExpandModule_basic", "ReshapeCollapseModule_basic", + "ElementwiseErfModule_basic", "ElementwiseGeluModule_basic", "GeluBackwardModule_basic", "ElementwiseNeIntScalarModule_basic", "Convolution2DStaticModule_basic", + "Convolution2DGroupsStatic_basic", "ElementwiseNegModule_basic", "TestMultipleTensorReturn_basic", "TypeAsSameModule_basic", @@ -989,11 +1058,13 @@ "BaddbmmWithBetaModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", + "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", "NumpyTRank2Module_basic", "NumpyTRankNStaticModule_basic", "NumpyTRankNDynamicModule_basic", "EmbeddingModuleI32Static_basic", + "EmbeddingModule1DIndices_basic", "TModuleRank2_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", @@ -1014,15 +1085,30 @@ "TypePromotionSameCategoryZeroRankWider_basic", "TypePromotionZeroRankHigherCategoryModule_basic", "GatherStaticModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", "IndexTensorStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorModule3dInputStatic_basic", "ElementwiseWhereScalarModule_basic", "FullLikeModuleFloat3DStatic_basic", "FullModuleDefaultDtype_basic", "FullModuleFloat3D_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleInt2D_basic", "MaskedFillScalarDefaultModule_basic", + "MaskedFillScalarFloatValueModule_basic", + "MaskedFillScalarFloatValueStaticModule_basic", "NumToTensorFloatModule_basic", "LiftFreshCopyModule_basic", + "PrimsSumFloatModule_basic", + "PrimsSqueezeModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", @@ -1031,10 +1117,19 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "BroadcastToDifferentRankStaticModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", "BroadcastListConstructWithMinusOneModule_basic", "SliceStaticModule_basic", + "SliceSizeTwoStepDivisibleStaticModule_basic", + "SliceOutOfLowerBoundStartIndexStaticModule_basic", "ArangeStartStepIntModule_basic", "ArangeDtypeFloatModule_basic", "ArangeIntModule_basic", @@ -1042,6 +1137,13 @@ "ArangeStartIntModule_basic", "ArangeStartNegativeStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartStepFloatModule_basic", "NumToTensorIntModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic", "ToCopyBoolDTypeStaticModule_basic", @@ -1070,6 +1172,9 @@ "FullModuleFloat2D_basic", "ElementwiseAbsModule_basic", "RepeatModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", "ConstantPad2dStaticModule_basic", "ConstantPadNdModule_basic", "ConstantPadNdPartialStaticModule_basic", @@ -1096,15 +1201,97 @@ "UnbindIntGetItem_Module_basic", "TensorsConcatStaticModule_basic", "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeStaticModule_basic", "AtenComplex64Module_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseSqrtModule_basic", + "EmptyModule_defaultDtype", + "EmptyModule_int", + "EmptyModule_float", + "EmptyModule_contiguous", + "EmptyModule_falsePinMemory", + "NewEmptyModuleBool_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "Fill_TensorFloat64WithFloat32Static_basic", "SplitTensorGetItem_Module_basic", "SplitTensorListUnpackModule_basic", "ChunkListUnpack_Module_basic", "ChunkListUnpackUneven_Module_basic", + "RepeatInterleaveStaticModule_basic", + "RepeatInterleaveFillModule_basic", } +MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { +### Tests additionally passing in make_fx_tosa + "CumsumStaticModule_basic", + "CumsumStaticNegativeDimModule_basic", + "NativeGroupNormBackwardModule_basic", + "SliceWholeTensorModule_basic", + "TensorFloatModule_basic", + "TensorIntModule_basic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexSelectStaticModule_basic", + "LinalgVectorNormModule_basic", + "LinalgVectorNormKeepDimModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalizeModule_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", +}) - { +### Test failing in make_fx_tosa but not in tosa + + # 'tosa.const' op failed to verify that all of {value, output} have same shape + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + + # 'tensor.empty' op incorrect number of dynamic sizes, has 1, expected 0 + "BatchNorm1DStaticShapeModule_basic", + + # Dynamic shape, has extra unsupported broadcast ops + "Matmul_3d", + + # failed to legalize operation 'torch.aten.max_pool2d_with_indices + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticModule_basic", + "ResNet18StaticModule_basic", + + # Unimplemented operator 'aten._index_put_impl_.hacked_twin' + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + + # failed to legalize operation 'torch.aten.index.Tensor' + "Im2ColModule_basic", +} + +if torch_version_for_comparison() < version.parse("2.1.0.dev"): + MAKE_FX_TOSA_PASS_SET -= { + # 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1 + "ReshapeCollapseModule_basic", + } + LTC_CRASHING_SET = { # https://github.com/llvm/torch-mlir/issues/2186 + "Conv1dNoPaddingModule_basic", + "Conv1dNoPaddingTransposeModule_basic", + "Conv1dNoPaddingGroupModule_basic", "Add_Module_basic" } @@ -1150,6 +1337,8 @@ "IndexPut2DFloatNonAccumulateModule_basic", "IndexPut2DIntAccumulateModule_basic", "IndexPut2DIntNonAccumulateModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", "IndexPut3DFloatAccumulateModule_basic", "IndexPut3DFloatNonAccumulateModule_basic", "IndexPut3DIntAccumulateModule_basic", @@ -1177,6 +1366,7 @@ "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "IndexTensorModule3dInput_basic", + "IndexTensorModule3dInputStatic_basic", "IndexTensorModule_basic", "IndexTensorStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic", @@ -1206,7 +1396,9 @@ "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", + "SliceEndSleStartStaticModule_basic", "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", "SqrtIntModule_basic", "SubFloatModule_basic", @@ -1223,6 +1415,7 @@ "TensorsConcatModule_basic", "TensorsConcatStaticModule_basic", "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeStaticModule_basic", "UniformModule_basic", "UniformNoCorrelationModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", @@ -1245,6 +1438,7 @@ "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", "PrimsConvertElementTypeModule_basic", + "PrimsSumFloatModule_basic", "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", @@ -1289,10 +1483,16 @@ "SplitTensorListUnpackModule_basic", "UnbindIntListUnpack_Module_basic", "UnbindIntGetItem_Module_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", "ChunkListUnpack_Module_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", + "RepeatInterleaveModule_basic", + "RepeatInterleaveFillModule_basic", + "Im2ColModule_basic", } diff --git a/externals/llvm-project b/externals/llvm-project index 2b4807ba0442..ae98bc3601d6 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 2b4807ba044230ed6243f5c3a1329a9344de758d +Subproject commit ae98bc3601d6ad0ac41b2d46087cdcfca4bd539d diff --git a/externals/mlir-hlo b/externals/mlir-hlo index ac26bdba7a5e..a4ac6990f751 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit ac26bdba7a5edfe6060ba5be528b9d20c987297d +Subproject commit a4ac6990f7519a569a380452d7c1d3764aad7e59 diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index c3ab1d474222..a91074d43178 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -38,6 +38,10 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, Value conv_val, ShapedType input_type, ShapedType weight_type, ShapedType output_type); +// Create a TOSA slice op from \p start with \p size +Value buildSlice(PatternRewriter &rewriter, Value &input, + llvm::ArrayRef start, llvm::ArrayRef size); + // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type); @@ -55,7 +59,7 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // To create INT48 TOSA constant, need to pass in llvm::APInt instead. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape); + ArrayRef vec, ArrayRef shape, std::optional dtype = {}); LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result); @@ -116,6 +120,13 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, rewriter.replaceOp(op, result->getResults()); } +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + +TypedValue transposeBy(Location loc, + PatternRewriter &rewriter, Value val, + ArrayRef permutation); + // Get accumulator type for AvgPool2dOp. LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType); diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 3d7bee6d1183..d26f07190129 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -837,6 +837,96 @@ def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [ }]; } +def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::asin : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::asin_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsin_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsin_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcosOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcos_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenNegOp : Torch_Op<"aten.neg", [ AllowsTypeRefinement, HasValueSemantics, @@ -7328,6 +7418,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ @@ -7550,6 +7641,30 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ }]; } +def Torch_AtenRepeatInterleaveTensorOp : Torch_Op<"aten.repeat_interleave.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$repeats, + AnyTorchOptionalIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRepeatInterleaveTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRepeatInterleaveTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ AllowsTypeRefinement, ReadOnly @@ -12016,6 +12131,31 @@ def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [ }]; } +def Torch_PrimsSumOp : Torch_Op<"prims.sum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `prims::sum : (Tensor, int[]?, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$inp, + AnyTorchOptionalListOfTorchIntType:$dims, + AnyTorchOptionalIntType:$output_dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsSumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void PrimsSumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [ AllowsTypeRefinement, ReadOnly diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 37aaed9cd704..0ae1bf607a61 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -16,10 +16,24 @@ namespace mlir { namespace torch { namespace Torch { +class BaseTensorType; int64_t toPositiveDim(int64_t dim, int64_t inputRank); bool isValidDim(int64_t dim, int64_t inputRank); bool getListConstructElements(Value v, SmallVectorImpl &elems); + +/// Returns a torch.list of the given vals as torch.constant.int. +Value toTorchList(Location loc, PatternRewriter &rewriter, + ArrayRef vals); + +/// Broadcast the given value of tensor type to the new shape. +TypedValue broadcastTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + +/// Reshapes the given value of tensor type to the new shape. +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + /// Returns the index indicated by `v` for a list of given `length`. /// If the index is negative, it is adjusted to `length` + `v`. /// `None` is returned the index is not an integer in the range [0,`length). diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0884fbe7a910..5007786b5fef 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -48,6 +48,8 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, return b.create(loc, iupred, lhs, rhs); if (intType.isSigned()) return b.create(loc, ispred, lhs, rhs); + assert(intType.getWidth() == 1); + return b.create(loc, iupred, lhs, rhs); } llvm_unreachable("Unhandled element type for comparison"); } @@ -656,6 +658,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( divTensorMode.emitError("invalid rounding mode"); return nullptr; } + if (auto pow = dyn_cast(op)) { + if (!pow.getType() + .cast() + .getDtype() + .isa()) { + pow.emitError("unimplemented: non-floating point dtype"); + return nullptr; + } + Type dtype = pow.getExponent().getType().cast().getDtype(); + Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype); + return b.create(loc, selfPromoted, payloadArgs[0]); + } if (auto pow = dyn_cast(op)) { if (!pow.getType() .cast() @@ -1158,7 +1172,7 @@ class ConvertElementwiseOp : public ConversionPattern { AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, - AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, + AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index fc1efa364bcf..f0d9e9beb2ad 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,6 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -19,11 +21,14 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/SmallVector.h" +#include using namespace mlir; using namespace mlir::torch; @@ -145,16 +150,26 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, if (dtype.isa()) { tosaTensor = tosa::getConstTensor( - rewriter, op, (isFloat ? doubleValue : intValue), dshape) + rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); - if (w != 32 && w != 64) + if (w!= 1 && w != 32 && w != 64) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "Unsupported integer type: " << intType; }); - if (w == 32) { + if (w == 1) { + if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { + return rewriter.notifyMatchFailure( + op, "Supplied value of scalar constant exceeds limits " + "of destination type"); + } + bool d = isFloat ? static_cast(doubleValue) + : static_cast(intValue); + tosaTensor = + tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + } else if (w == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -200,8 +215,9 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unsupported integer value for alpha"); - alphaTensor = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue); + alphaTensor = tosa::getConstTensor( + rewriter, op, {static_cast(alphaValue)}, {}, dtype) + .value(); return success(); } @@ -359,7 +375,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { auto rhsTensor = rhsTy ? rhs : rhsAsTensor; // There is no Lesser operator in TOSA. auto swapLhsRhs = (std::is_same() || - std::is_same()); + std::is_same() || + std::is_same() || + std::is_same()); // Promote lhs and rhs dtypes for bitwise operators. TensorType resultTy = OpConversionPattern::getTypeConverter() @@ -604,7 +622,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()).value(); auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -990,6 +1008,40 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp { } }; +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPowScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Value exp = adaptor.getExponent(); + auto expTy = exp.getType().template dyn_cast(); + + if (!expTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Pow"); + + if (!expTy.getElementType().isa()) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); + + Value selfTensor; + Value selfScalar = op.getSelf(); + if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor, + expTy.getElementType(), {}))) + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Pow operation"); + + auto outType = + getTypeConverter()->convertType(op.getType()).template cast(); + + auto powOp = tosa::createBinaryOpAndCast(rewriter, op, outType, + selfTensor, exp); + rewriter.replaceOp(op, powOp.getResult()); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorScalarOp op, OpAdaptor adaptor, @@ -1024,6 +1076,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPowTensorTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Value self = adaptor.getSelf(); + auto selfTy = self.getType().template cast(); + + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Pow"); + + if (!selfTy.getElementType().isa()) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); + + auto outType = + getTypeConverter()->convertType(op.getType()).template cast(); + + Value expTensor = adaptor.getExponent(); + if (expTensor.getType() != selfTy) { + expTensor = rewriter.createOrFold( + op->getLoc(), + RankedTensorType::get(outType.getShape(), selfTy.getElementType()), + expTensor); + } + + auto powOp = tosa::createBinaryOpAndCast(rewriter, op, outType, + self, expTensor); + rewriter.replaceOp(op, powOp.getResult()); + return success(); +} + // Perform the basic n-dim matmul operation encompassing the handling of // broadcasting and dynamic shape propagation. // All PyTorch ops that leverage matrix multiplication will derive this and @@ -1829,6 +1914,58 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +/// tosa.conv2d does not support group convolution. +/// Therefore, we create multiple ops where the input, kernel +/// and bias are slices of the original inputs. +/// Afterwards we concat the results into a single tensor. +/// This is inspired by the legalization done in onnx-mlir. +Value createConvInGroups(PatternRewriter &rewriter, Operation *op, + Type &resultType, + const llvm::ArrayRef weightShape, + Value &input, Value &weights, Value &bias, + const int64_t groups, DenseI64ArrayAttr &pads, + DenseI64ArrayAttr &strides, + DenseI64ArrayAttr &dilations) { + // Set up constants outside of loop + const int64_t sizeOfSliceInput = weightShape[1]; + const int64_t sizeOfSliceKernel = weightShape[0] / groups; + auto inputShape = input.getType().cast().getShape(); + + llvm::SmallVector inputSize = { + inputShape[0], inputShape[1], inputShape[2], sizeOfSliceInput}; + llvm::SmallVector kernelSize = {sizeOfSliceKernel, weightShape[2], + weightShape[3], weightShape[1]}; + llvm::SmallVector sliceValues; + Type outputType = RankedTensorType::get( + llvm::SmallVector(4, ShapedType::kDynamic), + resultType.cast().getElementType()); + for (int64_t i = 0; i < groups; i++) { + // Slice input + Value sliceInput = tosa::buildSlice( + rewriter, input, {0, 0, 0, i * sizeOfSliceInput}, inputSize); + + // Slice kernel + Value sliceWeight = tosa::buildSlice( + rewriter, weights, {i * sizeOfSliceKernel, 0, 0, 0}, kernelSize); + + // Slice bias + Value sliceBias = tosa::buildSlice(rewriter, bias, {i * sizeOfSliceKernel}, + {sizeOfSliceKernel}); + + // Create conv + Value tempConv2D = tosa::CreateOpAndInfer( + rewriter, input.getLoc(), outputType, sliceInput, sliceWeight, + sliceBias, pads, strides, dilations); + // Add value to vector + sliceValues.push_back(tempConv2D); + } + + constexpr int64_t channelDim = 3; + // Create concat op + return tosa::CreateOpAndInfer( + rewriter, op->getLoc(), outputType, sliceValues, channelDim); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenConvolutionOp op, OpAdaptor adaptor, @@ -1875,7 +2012,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } else { SmallVector zeroVec(weightShape[0], 0); bias = tosa::getConstTensor(rewriter, op, zeroVec, - {static_cast(weightShape[0])}) + {static_cast(weightShape[0])}, + inputElemTy) .value(); } } else { @@ -1895,6 +2033,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "transpose must be a bool constant"); + + if (transposed) + return rewriter.notifyMatchFailure( + op, "Unimplemented: only non-transposed convolutions supported"); + + int64_t groups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) + return rewriter.notifyMatchFailure( + op, "non-const group convolution unsupported"); + // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. @@ -1954,18 +2107,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // quantized input is i32, which gets rescaled down to quantized output range. SmallVector outputShape = {transposedInputShape[0], outputHDim, outputWDim, transposedWeightShape[0]}; - auto convOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); - Value convOpResult = - rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(convOpTy), - transposedInput, transposedWeight, bias, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) - .getResult(); + DenseI64ArrayAttr paddingAttr = rewriter.getDenseI64ArrayAttr(padding); + DenseI64ArrayAttr strideAttr = rewriter.getDenseI64ArrayAttr(stride); + DenseI64ArrayAttr dilationAttr = rewriter.getDenseI64ArrayAttr(dilation); + Value convOpResult; + if (groups == 1) { + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); + convOpResult = + rewriter + .create(op->getLoc(), + getTypeConverter()->convertType(convOpTy), + transposedInput, transposedWeight, bias, + paddingAttr, strideAttr, dilationAttr) + .getResult(); + } else { + convOpResult = createConvInGroups( + rewriter, op, outputTy, weightShape, transposedInput, transposedWeight, + bias, groups, paddingAttr, strideAttr, dilationAttr); + } std::optional nhwcToNchwTransposeConst = tosa::getConstTensor(rewriter, op, @@ -2159,7 +2320,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); + tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, + meanType.getElementType()) + .value(); auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, @@ -2263,7 +2427,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto elemCntConst = tosa::getConstTensor(rewriter, op.getOperation(), - {static_cast(elemCnt)}, {1}) + {static_cast(elemCnt)}, {1}, elemTy) .value(); Value elemCntRcp = rewriter.create( op.getLoc(), elemCntConst.getType(), elemCntConst); @@ -2318,7 +2482,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); + tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, elemTy) + .value(); // Compute layer norm. auto layerNorm = @@ -2471,9 +2637,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); - auto ln2Op = - tosa::getConstTensor(rewriter, op, {0.69314718056}, ln2Shape) - .value(); + auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, + ln2Shape, selfType.getElementType()) + .value(); auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); @@ -2688,7 +2854,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } static Value approximateErfOp(ConversionPatternRewriter &rewriter, - Operation *op, Value x) { + Operation *op, Value x, Type dtype) { // Using: // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = @@ -2699,24 +2865,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto outType = x.getType().cast(); auto loc = op->getLoc(); auto absX = rewriter.create(loc, outType, x); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).value(); + auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}, dtype).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).value(); + auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).value(); + auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).value(); + auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2739,9 +2905,10 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, } static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, - Operation *op, Value x) { - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + Operation *op, Value x, Type dtype) { + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); + auto loc = op->getLoc(); // buildNormalCdf, mean = zero, sigma = one @@ -2750,12 +2917,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}).value(); + tosa::getConstTensor(rewriter, op, 0.70710678, {}, dtype).value(); + Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); - Value erf = approximateErfOp(rewriter, op, erfArg); + Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}).value(); + Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); return normalCdf; @@ -2786,7 +2955,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf()); + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); + cdf = rewriter.createOrFold( + op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, /*shift=*/0); @@ -2827,16 +3000,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( const double kAlpha = cstAlpha0 * cstAlpha1; Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}).value(); + tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}).value(); + tosa::getConstTensor(rewriter, op, -0.5, {}, selfElemTy).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( loc, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf()); + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); Value dinputInput = rewriter.create( loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0); Value dinputInputAlpha = rewriter.create( @@ -2900,7 +3073,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); } - Value replace = tosa::getConstTensor(rewriter, op, 0, {}).value(); + Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( @@ -2939,9 +3112,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); - if (indicesType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "indices must be of rank 2"); - auto weightType = weight.getType().cast(); if (weightType.getRank() != 2) return op.emitError("weight must be of rank 2"); @@ -3155,6 +3325,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); + auto outTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "output type must be ranked"); + } + if (outTy.hasStaticShape() && outTy.getNumElements() == 0) { + return rewriter.notifyMatchFailure(op, + "tosa.slice does not support zero size"); + } + // Only statically deducible values are currently supported int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -3165,46 +3345,66 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim must less than tensor rank"); + auto sizeOfDim = selfType.getDimSize(dim); + int64_t start; if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - if (start < 0) - return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); + // support for start < 0 + start = toPositiveDim(start, sizeOfDim); + start = std::clamp(start, (int64_t)0, sizeOfDim); int64_t end; - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) - return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { + if (isa(op.getEnd().getDefiningOp())) + end = sizeOfDim; + else + return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); + } + // support for end < 0 end = toPositiveDim(end, selfType.getShape()[dim]); // support for end out of upper bound end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end); - - // FIXME: add support for start < 0 and end < start - if (end < start) - return rewriter.notifyMatchFailure(op, - "Currently unsupported: end < start"); + // Handle start > end + end = std::clamp(end, (int64_t)0, sizeOfDim); int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "step must be a Scalar constant"); - if (step != 1) - return rewriter.notifyMatchFailure( - op, "step value other than 1 is currently unsupported"); + if (sizeOfDim % step != 0) { + return rewriter.notifyMatchFailure(op, "size must be divisible by step"); + } - SmallVector startSlice(selfType.getRank(), 0); - SmallVector sizeSlice = - llvm::to_vector(makeShapeTorchCompatible(selfType.getShape())); + // We handle step by splitting the dimension dim into two dimensions, + // where the second one has size 'step'. + // E.g. to take slice with step 3 out of dim=0 of [6, 10], we first + // reshape into [2, 3, 10]. + SmallVector newShape{selfType.getShape()}; + newShape[dim] /= step; + newShape.insert(newShape.begin() + dim+1, step); - startSlice[dim] = start; - sizeSlice[dim] = end - start; + auto reshaped = + tosa::reshapeTo(op->getLoc(), rewriter, adaptor.getSelf(), newShape); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sizeSlice)); + SmallVector startSlice(reshaped.getType().getRank(), 0); + + startSlice[dim] = start / step; + startSlice[dim+1] = start % step; + + SmallVector sliceShape{outTy.getShape()}; + sliceShape.insert(sliceShape.begin() + dim+1, 1); + auto slice = rewriter.create( + op.getLoc(), outTy.cloneWith(sliceShape, outTy.getElementType()), + reshaped, rewriter.getDenseI64ArrayAttr(startSlice), + rewriter.getDenseI64ArrayAttr(sliceShape)); + + auto out = tosa::reshapeTo(op->getLoc(), rewriter, slice, outTy.getShape()); + + rewriter.replaceOp(op, out); return success(); } @@ -3287,8 +3487,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( tosa::getZerosLikeTensor(rewriter, op, resultType).value(); // Use add broadcast - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), - zeroTensor); + auto newOp = rewriter.createOrFold( + op.getLoc(), resultType, adaptor.getSelf(), zeroTensor); + rewriter.replaceOp(op, newOp); return success(); } return rewriter.notifyMatchFailure( @@ -3383,6 +3584,482 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Turn a torch.aten._index_put_impl where some entries in the indices list are +// none into multiple _index_put_impl across all elements of that dimension. +// +// Example: +// a = torch.aten._index_put_impl(in, [idx0, None, idx1], values) +// where in is a 7x3x5 tensor, is equivalent to +// tmp = torch.aten._index_put_impl(in, [idx0, [0], idx1], values) +// tmp2 = torch.aten._index_put_impl(tmp, [idx0, [1], idx1], values) +// a = torch.aten._index_put_impl(tmp2, [idx0, [2], idx1], values) +class SimplifyAten_IndexPutImplOpNone + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Aten_IndexPutImplOp op, + PatternRewriter &rewriter) const override { + + auto outTy = dyn_cast(op.getType()); + if (!outTy || !outTy.areAllSizesKnown()) + return failure(); + + SmallVector indices; + if (!getListConstructElements(op.getIndices(), indices)) + return failure(); + + for (size_t i=0; i < indices.size(); ++i) { + if (isa(indices[i].getType())) { + Value newIndexPut = op.getSelf(); + auto si64Type = IntegerType::get(rewriter.getContext(), 64, IntegerType::Signed); + Type indexType = + ValueTensorType::get(rewriter.getContext(), {{}}, si64Type); + for( int64_t d=0; d < outTy.getSizes()[i]; ++d) { + SmallVector newIndices = indices; + + newIndices[i] = rewriter.create(op.getLoc(), indexType, + rewriter.create( + op->getLoc(), d)); + + Value newIndicesList = + rewriter.create(op->getLoc(), op.getIndices().getType(), newIndices); + + newIndexPut = rewriter.create(op.getLoc(), op.getType(), newIndexPut, newIndicesList, op.getValues(), + op.getAccumulate(), op.getUnsafe()); + } + rewriter.replaceOp(op, newIndexPut); + return success(); + } + } + return failure(); + } +}; + +// Turn a torch.aten._index_put_impl on a 2d [1, n] tensor into a +// torch.aten._index_put_impl on a 1d [n] tensor. +class SimplifyAten_IndexPutImplOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Aten_IndexPutImplOp op, + PatternRewriter &rewriter) const override { + + auto ty = op.getType().dyn_cast(); + if (!ty || !ty.areAllSizesKnown()) { + return rewriter.notifyMatchFailure(op, "Required ranked tensor type"); + } + + auto shape = ty.getSizes(); + if (shape.size() != 2 || shape[0] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: non-2d output with leading dimension of size 1"); + } + int64_t numSelfElements = shape[1]; + + auto valuesTy = op.getValues().getType().dyn_cast(); + if (!valuesTy || !valuesTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure(op, "Required ranked tensor type for values"); + } + + auto valuesShape = valuesTy.getSizes(); + if (valuesShape.size() > 2) { + return rewriter.notifyMatchFailure( + op, "unimplemented: nd values with n>=2"); + } + if (valuesShape.size() == 2 && valuesShape[0] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: 2d values with leading dimension of size 1"); + } + auto numValues = valuesShape.empty() ? 1 : valuesShape.back(); + + SmallVector indicesList; + if (!getListConstructElements(op.getIndices(), indicesList)) { + return op.emitError( + "unimplemented: the indices list is not from list construct"); + } + // There is one indices tensor for each dimension of self. + // Here, we know that self is 1xN, so we are only interested for the indices + // of the 2nd dimension. + auto indices = indicesList[1]; + auto indicesTy = indices.getType().dyn_cast(); + if (!indicesTy || !indicesTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure( + op, "Required ranked tensor type for indices"); + } + if (indicesTy.getSizes().size() > 1) { + return rewriter.notifyMatchFailure( + op, "Required 0d or 1d tensor for indices"); + } + auto numIndices = + indicesTy.getSizes().empty() ? 1 : indicesTy.getSizes()[0]; + + if (indicesTy.getSizes().empty()) { + indices = reshapeTo(op.getLoc(), rewriter, indices, {1}); + } + + // Broadcast so that values and indices have the same size + if (numIndices == 1 && numValues > numIndices) { + indices = broadcastTo(op.getLoc(), rewriter, indices, {numValues}); + } + + Value newIndicesList = rewriter.create( + op->getLoc(), op.getIndices().getType(), SmallVector{indices}); + + auto reshapedSelf = + reshapeTo(op.getLoc(), rewriter, op.getSelf(), {numSelfElements}); + + auto values = reshapeTo(op.getLoc(), rewriter, op.getValues(), {numValues}); + + // Broadcast so that values and indices have the same size + if (numValues == 1 && numIndices > numValues) { + values = broadcastTo(op.getLoc(), rewriter, values, {numIndices}); + } + + auto put = rewriter.create( + op.getLoc(), reshapedSelf.getType(), reshapedSelf, newIndicesList, + values, op.getAccumulate(), op.getUnsafe()); + + rewriter.replaceOp(op, reshapeTo(op.getLoc(), rewriter, put, shape)); + + return success(); + } +}; + +// Handle Aten_IndexPutImplOp on 1d tensors +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + Aten_IndexPutImplOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // TOSA scatter: + // // Copy the values_in tensor to the values_out tensor. + // // Values not written by the scatter operation are unchanged in the output. + // for_each(0 <= n < N, 0 <= k < K, 0 <= c < C) { + // value_t value = tensor_read(values_in, [N,K,C], [n,k,c]); + // tensor_write(values_out, [N,K,C], [n, k, c], value); + // } + // // Now perform the SCATTER operation, modifying the positions from the + // indices tensor for_each(0 <= n < N, 0 <= w < W, 0 <= c < C) { + // index_t k = tensor_read(indices, [N,W], [n,w]); + // REQUIRE(0 <= k && k < K); + // value_t value = tensor_read(input, [N,W,C], [n,w,c]); + // tensor_write(values_out, [N,K,C], [n, k, c], value); + // output_modified[n,k,c] = true; + // } + + auto loc = op.getLoc(); + + // Not a tensor type. + auto self = dyn_cast>(adaptor.getSelf()); + if (!self) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + if (self.getType().getRank() != 1) { + return rewriter.notifyMatchFailure( + op, "Only 1d input tensor are currently supported"); + } + + auto values = dyn_cast>(adaptor.getValues()); + if (!values) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + // Deal with torch.prim.ListConstruct of non const value to get the index + SmallVector indicesTorchType; + if (!getListConstructElements(op.getIndices(), indicesTorchType)) + return op.emitError( + "unimplemented: the tensor list is not from list construct"); + + // Convert indicesTorchType to TOSA types + auto indexTensors = getTypeConvertedValues( + rewriter, op->getLoc(), getTypeConverter(), indicesTorchType); + + // the number of tensors in indexTensors is equal to the rank of outType + if (indexTensors.size() != 1) { + return rewriter.notifyMatchFailure(op, "Expected 1 indices "); + } + + auto indices0 = indexTensors[0]; + auto indicesTy = dyn_cast(indices0.getType()); + + if (!indicesTy || indicesTy.getShape() != values.getType().getShape()) + return rewriter.notifyMatchFailure( + op, "Expected indices to have same shape as values"); + + + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!outType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + + auto numInElements = self.getType().getShape()[0]; + auto numValues = values.getType().getShape()[0]; + + // TOSA scatter requires 3d in and 2d indices & values + SmallVector scatterInOutShape {1, numInElements, 1}; + SmallVector scatterIndicesShape {1, numValues}; + SmallVector scatterInputShape {1, numValues, 1}; + + auto in = mlir::tosa::reshapeTo(loc, rewriter, self, scatterInOutShape); + auto indices = + mlir::tosa::reshapeTo(loc, rewriter, indices0, scatterIndicesShape); + auto input = mlir::tosa::reshapeTo(loc, rewriter, values, scatterInputShape); + + // TOSA scatter requires 32 bit indices + // TODO: This might break on large (sparse?) tensors that require 64 bit indices + auto indices32Ty = RankedTensorType::get(indices.getType().getShape(), rewriter.getI32Type()); + auto indices32 = rewriter.create(loc, indices32Ty, indices); + + auto scatterTy = RankedTensorType::get(scatterInOutShape, self.getType().getElementType()); + auto scatter = rewriter.create(loc, scatterTy, in, indices32, input); + + auto reshaped = + mlir::tosa::reshapeTo(loc, rewriter, scatter, outType.getShape()); + + rewriter.replaceOp(op, reshaped); + return success(); +} + +// This defines a template to simplify legalization of certain ops. +template +class SimplifyAtenOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +template <> +LogicalResult SimplifyAtenOp::matchAndRewrite( + AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // TOSA doesn't supports 1D convolutions. + // We model them through a combination of AtenViewOp and 2D Convolution. + // A Conv1D is replaced by: + // %view = AtenViewOp (%input) : (3D type) -> (4D Type) + // %conv2d = AtenConvolution (%view) : (4D type) -> (4D type) + // %view2 = AtenViewOp (%conv2d) : (4D type) -> (3D type) + + auto inputTy = adaptor.getInput().getType().cast(); + auto weightTy = adaptor.getWeight().getType().cast(); + auto outputTy = getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + auto ty = op.getType().dyn_cast_or_null(); + if (!ty || !ty.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: input must have known sizes"); + + if (!inputTy || !weightTy || !outputTy) + return rewriter.notifyMatchFailure( + op, "Input, weight and output to Convolution must be ranked tensors"); + + if (!weightTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: TOSA only supports static weight"); + + if (inputTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Unimplemented: only simplify 1D convolution"); + + auto loc = op->getLoc(); + + auto getListConstructElementsPlusValue = + [&](Value listConstruct, int64_t addedValue) -> std::optional { + SmallVector values; + if (!getListConstructElements(listConstruct, values)) { + return std::nullopt; + } + + Type ty = listConstruct.getType(); + values.push_back( + rewriter.create(op->getLoc(), addedValue)); + return rewriter.create(op->getLoc(), ty, values); + }; + + auto stride = getListConstructElementsPlusValue(op.getStride(), 1); + if (!stride.has_value()) + return rewriter.notifyMatchFailure(op, "non-const stride list unsupported"); + + auto dilation = getListConstructElementsPlusValue(op.getDilation(), 1); + if (!dilation.has_value()) + return rewriter.notifyMatchFailure(op, + "non-const dilation list unsupported"); + + auto paddingValue = getListConstructElementsPlusValue(op.getPadding(), 0); + if (!paddingValue.has_value()) + return rewriter.notifyMatchFailure(op, + "non-const padding list unsupported"); + + auto outputPaddingValue = + getListConstructElementsPlusValue(op.getOutputPadding(), 0); + if (!outputPaddingValue.has_value()) { + return rewriter.notifyMatchFailure( + op, "non-const output padding list unsupported"); + } + + auto addDimOneToSizes = [&](BaseTensorType ty) { + SmallVector newSizes(ty.getSizes()); + newSizes.push_back(1); + return newSizes; + }; + + auto input = op.getInput(); + auto weight = op.getWeight(); + + auto newSizes = addDimOneToSizes(cast(input.getType())); + Value view1dTo2d = reshapeTo(loc, rewriter, input, newSizes); + + auto newWeightSizes = addDimOneToSizes(cast(weight.getType())); + weight = reshapeTo(loc, rewriter, weight, newWeightSizes); + + auto convSizes = addDimOneToSizes(cast(ty)); + auto convTy = ty.getWithSizesAndDtype(convSizes, ty.getOptionalDtype()); + auto conv2dOp = rewriter.create( + loc, convTy, view1dTo2d, weight, op.getBias(), *stride, + *paddingValue, *dilation, op.getTransposed(), *outputPaddingValue, + op.getGroups()); + + Value view2dTo1d = reshapeTo(loc, rewriter, conv2dOp, ty.getSizes()); + rewriter.replaceOp(op, view2dTo1d); + return success(); +} + +// The goal of this pattern is to handle the case where the indices for all +// dimensions except one are None. +class ConvertAtenIndexTensorOpNone + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // To do so, we rewrite index.Tensor like that : + // - To match tosa format of NxKxC, with K the dimension to extract from: + // - Transpose the dim to extract into position 'K' + // - flatten the other dimensions + // - Reshape to insert a 1x dimension as the N - The format should be + // 1xKxC with C the flattened dimensions + // - Insert a tosa.gather + // - Bring back to the original format: + // - Reshape + // - Transpose + auto loc = op->getLoc(); + auto outTy = dyn_cast( + getTypeConverter()->convertType(op.getType())); + if (!outTy || !outTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op.getLoc(), + "unimplemented: Only static shapes are currently supported"); + + SmallVector torchIndices; + if (!getListConstructElements(op.getIndices(), torchIndices)) + return rewriter.notifyMatchFailure( + op.getLoc(), + "unimplemented: the tensor list is not from list construct"); + + auto indicesList = + getTypeConvertedValues(rewriter, loc, typeConverter, torchIndices); + + // Check that all indices are none but one. + int64_t indexDim = -1; + for (size_t i = 0; i < indicesList.size(); ++i) { + if (!indicesList[i]) + continue; + if (indexDim != -1) { + return rewriter.notifyMatchFailure( + op.getLoc(), "unimplemented: only one dimension must be set in " + "indices for this pattern to work"); + } + indexDim = i; + } + if (indexDim == -1) { + return rewriter.notifyMatchFailure(op.getLoc(), + "unimplemented: all indices are none"); + } + + auto indices = + dyn_cast>(indicesList[indexDim]); + if (!indices) { + return rewriter.notifyMatchFailure( + op.getLoc(), "unimplemented: index must be ranked tensor"); + } + + if (indices.getType().getRank() != 1) { + return rewriter.notifyMatchFailure( + op.getLoc(), "unimplemented: index must be 1d tensor"); + } + + auto input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op.getLoc(), "unimplemented: input must have static shapes"); + auto inputElemTy = inputTy.getElementType(); + + // Transpose indexDim into dimension 0 + SmallVector transposePerm; + for (int64_t i = 0; i < inputTy.getRank(); ++i) + transposePerm.push_back(i); + transposePerm[0] = indexDim; + transposePerm[indexDim] = 0; + + auto transposedInput = tosa::transposeBy(loc, rewriter, input, transposePerm); + + // Flatten matrix [k, ...] -> [1, k, c] + auto transposedShape = transposedInput.getType().getShape(); + int64_t k = transposedShape[0]; + int64_t c = std::accumulate(transposedShape.begin() + 1, transposedShape.end(), 1, + [&](int64_t a, int64_t b) { + return a * b; + }); + + SmallVector reshapedFormat = {1, k, c}; + // Reshapes the input to 1xKx(flattened_dims) + auto reshapedInput = + tosa::reshapeTo(loc, rewriter, transposedInput, reshapedFormat); + + auto w = indices.getType().getDimSize(0); + auto reshapedIndices = tosa::reshapeTo(loc, rewriter, indices, {1, w}); + + // And cast indices to i32 + TensorType promotedType = + reshapedIndices.getType().cloneWith(reshapedIndices.getType().getShape(), rewriter.getI32Type()); + auto castedIndices = rewriter.create(op->getLoc(), promotedType, reshapedIndices); + + SmallVector gatherShape = {1, w, c}; + auto gatherOp = rewriter.create( + op->getLoc(), RankedTensorType::get(gatherShape, inputElemTy), + reshapedInput, castedIndices); + + // Unflatten [1, w, c] -> [w, ...] + SmallVector unflattenedShape{transposedShape}; + unflattenedShape[0] = w; + auto unflattened = + tosa::reshapeTo(loc, rewriter, gatherOp, unflattenedShape); + + SmallVector inversePermutation(transposePerm.size(), 0); + for (size_t i = 0; i < transposePerm.size(); ++i) + inversePermutation[transposePerm[i]] = i; + + + // Transpose 'w' back in the original position of 'k' + auto unTranspose = + tosa::transposeBy(loc, rewriter, unflattened, inversePermutation); + + rewriter.replaceOp(op, unTranspose); + return success(); + } +}; + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorOp op, OpAdaptor adaptor, @@ -3625,23 +4302,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - int64_t int_min, int_max; - if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) + int64_t intMin = 0; + int64_t intMax = 0; + double fpMin = 0.0; + double fpMax = 0.0; + + auto min = op.getMin(); + auto isIntMin = matchPattern(min, m_TorchConstantInt(&intMin)); + auto isFloatMin = matchPattern(min, m_TorchConstantFloat(&fpMin)); + auto isNoneTypeMin = min.getType().isa(); + + auto max = op.getMax(); + auto isIntMax = matchPattern(max, m_TorchConstantInt(&intMax)); + auto isFloatMax = matchPattern(max, m_TorchConstantFloat(&fpMax)); + auto isNoneTypeMax = max.getType().isa(); + + if (!(isIntMin || isFloatMin || isNoneTypeMin)) return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant int"); + op, "unimplemented: value `int_min` should be a torch constant " + "int/float or Torch::NoneType"); - if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) + if (!(isIntMax || isFloatMax || isNoneTypeMax)) return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant int"); + op, "unimplemented: value `int_max` should be a torch constant " + "int/float or Torch::NoneType"); + + // Adjust min and max to their numeric_limits if type == Torch::NoneType. + if (isNoneTypeMin) { + intMin = std::numeric_limits::min(); + fpMin = std::numeric_limits::lowest(); + } + if (isNoneTypeMax) { + intMax = std::numeric_limits::max(); + fpMax = std::numeric_limits::max(); + } - IntegerAttr min_int = rewriter.getI64IntegerAttr(int_min); - IntegerAttr max_int = rewriter.getI64IntegerAttr(int_max); - FloatAttr min_fp = rewriter.getF32FloatAttr(float(int_min)); - FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max)); + // If we are using integer for min and max values, + // import them from their fp counterparts. + if (isIntMin) + fpMin = static_cast(intMin); + + if (isIntMax) + fpMax = static_cast(intMax); auto outType = getTypeConverter()->convertType(op.getType()); + + // It is safe to static_cast to float since tosa doesn't support fp64. + FloatAttr minFp = rewriter.getF32FloatAttr(static_cast(fpMin)); + FloatAttr maxFp = rewriter.getF32FloatAttr(static_cast(fpMax)); + IntegerAttr minInt = rewriter.getI64IntegerAttr(intMin); + IntegerAttr maxInt = rewriter.getI64IntegerAttr(intMax); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), - min_int, max_int, min_fp, max_fp); + minInt, maxInt, minFp, maxFp); return success(); } @@ -3669,28 +4382,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + auto matchIntOrDouble = + [&](Value val) -> std::tuple { + // Match int or fp values. The one used depends on the resultType. + // Therefore `valueInt` and `valueDouble` will have similar values (but may + // be truncated due to casting). + int64_t valueInt = 0; + double valueDouble = 0.0; + if (matchPattern(val, m_TorchConstantInt(&valueInt))) + return {success(), valueInt, static_cast(valueInt)}; + if (matchPattern(val, m_TorchConstantFloat(&valueDouble))) + return {success(), static_cast(valueDouble), valueDouble}; + return {failure(), valueInt, valueDouble}; + }; + + auto [matchStart, startInt, startDouble] = matchIntOrDouble(op.getStart()); + if (failed(matchStart)) return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, + "unimplemented: value `start` should be a torch constant int or float"); - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + auto [matchEnd, endInt, endDouble] = matchIntOrDouble(op.getEnd()); + if (failed(matchEnd)) return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, + "unimplemented: value `end` should be a torch constant int or float"); - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + auto [matchStep, stepInt, stepDouble] = matchIntOrDouble(op.getStep()); + if (failed(matchStep)) return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); + op, + "unimplemented: value `step` should be a torch constant int or float"); // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + auto elementType = resultType.getElementType(); + Value result; + if (isa(elementType)) { + int64_t resultShape = ceil(static_cast(endInt - startInt) / + static_cast(stepInt)); + SmallVector values(resultShape, startInt); + for (unsigned i = 1; i < resultShape; i++) + values[i] += i * stepInt; + result = tosa::getConstTensor(rewriter, op, values, resultShape) + .value(); + } else { + int64_t resultShape = ceil((endDouble - startDouble) / stepDouble); + SmallVector values(resultShape, startDouble); + for (unsigned i = 1; i < resultShape; i++) + values[i] += static_cast(i) * stepDouble; + result = tosa::getConstTensor(rewriter, op, values, resultShape) + .value(); + } rewriter.replaceOpWithNewOp(op, resultType, result); return success(); @@ -4347,8 +5091,9 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Supplied value must be a Scalar constant"); - rewriter.replaceOpWithNewOp(op, outType, constOp); - + auto newOp = + rewriter.createOrFold(op.getLoc(), outType, constOp); + rewriter.replaceOp(op, newOp); return success(); } }; @@ -4545,12 +5290,389 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto builtinTensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); + for(auto &in: builtinTensors) + in = tosa::promoteType(rewriter, in, outType); + auto result = tosa::CreateOpAndInfer( rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim)); rewriter.replaceOp(op, result.getResult()); return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Converts AtenSqrtOp into pow(x, 0.5) + auto self = adaptor.getSelf(); + auto selfTy = self.getType().dyn_cast(); + if (!selfTy) + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); + + auto resultType = typeConverter->convertType(op.getType()) + .template cast(); + auto elementType = resultType.getElementType(); + + if (selfTy.getElementType().isa()) { + self = rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType), + self); + } + + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value(); + + rewriter.replaceOpWithNewOp(op, resultType, self, oneHalf); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenEmptyMemoryFormatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto loc = op.getLoc(); + MLIRContext* ctx = op->getContext(); + mlir::TypeConverter* typeConverter = this->getTypeConverter(); + + bool pinMemory; + if (!op.getPinMemory().getType().template isa() && + (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || + pinMemory)) { + return rewriter.notifyMatchFailure( + op, "Unsupported pin_memory, should be either None or false"); + } + + if (!op.getDevice().getType().template isa()) { + std::string device; + if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) + return rewriter.notifyMatchFailure( + op, "unimplemented: device must be a constant str"); + if (device != "cpu") + return rewriter.notifyMatchFailure( + op, "unimplemented: device is expected to be none or cpu"); + } + + if (!op.getLayout().getType().template isa()) { + int64_t tensorLayout; + if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) + return rewriter.notifyMatchFailure( + op, "unimplemented: layout must be a constant"); + if (tensorLayout != torch_upstream::Layout::Strided) + return rewriter.notifyMatchFailure( + op, "unimplemented: layout is expected to be strided"); + } + // Only `none`, `contiguous` and `preserve` memory_format are supported. + if (!op.getMemoryFormat().getType().template isa()) { + int64_t memoryFormat; + if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) + return rewriter.notifyMatchFailure( + op, "unimplemented: the memory format should be specified in " + "an integer constant"); + if (memoryFormat != torch_upstream::MemoryFormat::Contiguous && + memoryFormat != torch_upstream::MemoryFormat::Preserve) + return rewriter.notifyMatchFailure( + op, "unimplemented: only none, contiguous and preserve " + "memory_format is supported"); + } + + SmallVector size; + if (!getListConstructElements(op.getSize(), size)) + return rewriter.notifyMatchFailure( + op, "unimplemented: size must be a ListConstruct"); + SmallVector resultSize = getTypeConvertedValues(rewriter, loc, typeConverter, + size); + auto resultType = + typeConverter->convertType(op.getType()).template cast(); + + DenseElementsAttr emptyVal; + if (op.getDtype().getType().template isa()) { + emptyVal = DenseFPElementsAttr::get(resultType, {0.0F}); + } else { + int64_t dtypeInt; + if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: dtype must be a constant integer or none"); + FailureOr maybeResultElementType = getTypeForScalarType( + ctx, (torch_upstream::ScalarType)dtypeInt, + IntegerType::Signless); + if (failed(maybeResultElementType)) { + return rewriter.notifyMatchFailure( + op, "unable to convert `dtypeInt` to builtin type"); + } + if(maybeResultElementType->isSignedInteger(64) || maybeResultElementType->isIndex()) + emptyVal = DenseIntElementsAttr::get(resultType, {0L}); + if(maybeResultElementType->isSignedInteger(32)) + emptyVal = DenseIntElementsAttr::get(resultType, {0}); + else if (maybeResultElementType->isSignlessInteger(64)) + emptyVal = DenseIntElementsAttr::get(resultType, {0UL}); + else if (maybeResultElementType->isSignlessInteger(32)) + emptyVal = DenseIntElementsAttr::get(resultType, {0U}); + else if (maybeResultElementType->isSignedInteger(1) || + maybeResultElementType->isSignlessInteger(1)) + emptyVal = DenseIntElementsAttr::get(resultType, {false}); + else if (maybeResultElementType->isF64()) + emptyVal = DenseFPElementsAttr::get(resultType, {0.0}); + else if (maybeResultElementType->isF32()) + emptyVal = DenseFPElementsAttr::get(resultType, {0.0F}); + else + return rewriter.notifyMatchFailure(op, "unsupported: dtype used for empty.memory_format is unsupported"); + } + + rewriter.replaceOpWithNewOp(op, resultType, emptyVal); + return success(); + } + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRepeatInterleaveTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto outputTy = getTypeConverter() + ->convertType(op.getType()) + .dyn_cast(); + if (!outputTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type outputs permitted"); + + auto shape = outputTy.getShape(); + if (shape.size() != 1) + return rewriter.notifyMatchFailure(op, "Only rank 1 tensors are permitted"); + + int64_t outputSize; + if (!matchPattern(op.getOutputSize(), m_TorchConstantInt(&outputSize))) { + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "output_size in TOSA operation"); + } + + auto repeats = dyn_cast(adaptor.getRepeats().getDefiningOp()); + if (!repeats) + return rewriter.notifyMatchFailure( + op, "Currently only constants are supported for " + "repeats in TOSA operation"); + + auto attr = repeats.getValue(); + if (!attr.isSplat()) + return rewriter.notifyMatchFailure(op, "Only single values are supported."); + + auto elementTy = outputTy.getElementType(); + if (!elementTy.isa()) + return rewriter.notifyMatchFailure(op, + "Only integer values are supported."); + + int64_t numberOfRepeats = attr.getSplatValue().getSExtValue(); + + // Create an array of repeated values + auto createConstArrayOfRepeatedValues = [&](int64_t numOfRepeats) { + SmallVector values; + for (int64_t val = 0; val < outputSize / numberOfRepeats; ++val) { + SmallVector newValues(numberOfRepeats, val); + values.insert(values.end(), newValues.begin(), newValues.end()); + } + return values; + }; + + auto newOp = tosa::getConstTensor( + rewriter, op, createConstArrayOfRepeatedValues(numberOfRepeats), shape, + elementTy); + rewriter.replaceOp(op, *newOp); + return success(); +} + +template +class ConvertAtenOpToTosaCustomOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + + ConvertAtenOpToTosaCustomOp(TypeConverter &typeConverter, + MLIRContext *context, std::string opName, + std::string implementedWithOpAttr = "UNDEF") + : OpConversionPattern(typeConverter, context), + opName(std::move(opName)), + implementedWithOpAttr(std::move(implementedWithOpAttr)) {} + + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Set tosa.custom_op attributes. + // Only identifier needs to be known. Other attributes are not used. + auto *ctx = op->getContext(); + auto identifier = StringAttr::get(ctx, opName); + auto implementAttr = StringAttr::get(ctx, implementedWithOpAttr); + auto config = StringAttr::get(ctx, "UNDEF"); + + rewriter.replaceOpWithNewOp( + op, + TypeRange{OpConversionPattern::getTypeConverter()->convertType( + op.getType())}, + identifier, config, implementAttr, adaptor.getOperands()); + return success(); + } + +private: + std::string opName; + std::string implementedWithOpAttr; +}; + +class SimplifyAtenIndexTensorWithSliceIndex + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + + LogicalResult matchAndRewrite(AtenIndexTensorOp op, + PatternRewriter &rewriter) const override { + auto outTy = dyn_cast(op.getType()); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + + SmallVector indices; + if (!getListConstructElements(op.getIndices(), indices)) + return failure(); + + TypedValue input = + dyn_cast>(op.getSelf()); + if (!input) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + + if (llvm::count_if(indices, [](Value v) { + return !isa(v.getType()); + }) == 1) { + return rewriter.notifyMatchFailure(op, "nothing to do"); + } + + auto loc = op->getLoc(); + + for (size_t i = 0; i < indices.size(); ++i) { + if (isa(indices[i].getType())) + continue; + + auto indicesTy = dyn_cast(indices[i].getType()); + if (!indicesTy || !indicesTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure( + op, "requires indices with static shape"); + } + int64_t numIndices = std::accumulate( + indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1, + [&](int64_t a, int64_t b) { return a * b; }); + if (numIndices != 1) + continue; + + auto inputTy = input.getType(); + SmallVector slicedShape{inputTy.getSizes()}; + slicedShape[i] = 1; + auto slicedType = + inputTy.getWithSizesAndDtype(slicedShape, inputTy.getDtype()); + + auto none = rewriter.create(op->getLoc()); + SmallVector sliceIndices{inputTy.getSizes().size(), none}; + sliceIndices[i] = reshapeTo(loc, rewriter, indices[i], {1}); + + Value sliceIndicesV = rewriter.create( + loc, op.getIndices().getType(), sliceIndices); + auto slicedInput = rewriter.create( + loc, slicedType, input, sliceIndicesV); + + SmallVector reshapedShape = slicedShape; + reshapedShape.erase(reshapedShape.begin() + i); + + auto reshaped = reshapeTo(loc, rewriter, slicedInput, reshapedShape); + + SmallVector newIndicesList{indices}; + newIndicesList.erase(newIndicesList.begin() + i); + + Value newIndicesListV = rewriter.create( + loc, op.getIndices().getType(), newIndicesList); + + rewriter.replaceOpWithNewOp(op, op.getType(), reshaped, + newIndicesListV); + return success(); + } + return failure(); + } +}; + +class SimplifyAtenIndexTensorWithNdIndex + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtenIndexTensorOp op, + PatternRewriter &rewriter) const override { + auto outTy = dyn_cast(op.getType()); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + + SmallVector indices; + if (!getListConstructElements(op.getIndices(), indices)) + return failure(); + + TypedValue input = + dyn_cast>(op.getSelf()); + if (!input) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + auto loc = op->getLoc(); + + if (llvm::count_if(indices, [](Value v) { + return !isa(v.getType()); + }) != 1) { + return rewriter.notifyMatchFailure(op, "can only handle single None"); + } + + for (size_t i = 0; i < indices.size(); ++i) { + if (isa(indices[i].getType())) + continue; + + auto indicesTy = dyn_cast(indices[i].getType()); + if (!indicesTy || !indicesTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure( + op, "requires indices with static shape"); + } + if (indicesTy.getSizes().size() == 1) { + continue; + } + + // flatten indices + int64_t numIndices = std::accumulate( + indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1, + [&](int64_t a, int64_t b) { return a * b; }); + + auto newIndices = + reshapeTo(op.getLoc(), rewriter, indices[i], {numIndices}); + + SmallVector newIndicesList{indices}; + newIndicesList[i] = newIndices; + + Value newIndicesListV = rewriter.create( + loc, op.getIndices().getType(), newIndicesList); + + SmallVector indexOpShape{outTy.getSizes()}; + indexOpShape.erase(indexOpShape.begin() + i, + indexOpShape.begin() + i + indicesTy.getSizes().size()); + indexOpShape.insert(indexOpShape.begin() + i, numIndices); + + auto indexOpType = + outTy.getWithSizesAndDtype(indexOpShape, outTy.getOptionalDtype()); + auto indexed = rewriter.create( + loc, indexOpType, input, newIndicesListV); + + auto reshaped = + reshapeTo(loc, rewriter, indexed, outTy.getSizes()); + rewriter.replaceOp(op, reshaped); + return success(); + } + return failure(); + } +}; } // namespace // ----------------------------------------------------------------------------- @@ -4577,8 +5699,32 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); + // Mark constant ops as legal, so the error message about + // "failed to legalize" + // mentions the real problematic op and not the constants used by it. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalDialect(); + RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(typeConverter, context); + +#define INSERT_SIMPLIFY_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_SIMPLIFY_OP_PATTERN(AtenConvolutionOp) +#undef INSERT_SIMPLIFY_OP_PATTERN + #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ @@ -4596,6 +5742,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenErfOp, tosa::ErfOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ @@ -4619,6 +5766,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { patterns.add>(typeConverter, context); INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) @@ -4727,6 +5875,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_SCALAR_PATTERN(AtenFillScalarOp); #undef INSERT_FILL_SCALAR_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ @@ -4745,7 +5894,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenLeakyReluOp); INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenPowScalarOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenRsubScalarOp); INSERT_ATENOP_PATTERN(AtenConvolutionOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); @@ -4768,6 +5919,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp); INSERT_ATENOP_PATTERN(AtenIndexTensorOp); INSERT_ATENOP_PATTERN(AtenAbsOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); @@ -4780,6 +5932,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); + INSERT_ATENOP_PATTERN(AtenRepeatInterleaveTensorOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ @@ -4788,6 +5943,18 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); #undef INSERT_CLONE_ATENOP_PATTERN +#define INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenOp, opName, implementedWith) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + opName, implementedWith); + INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenAtan2Op, "math.atan2", + "linalg.generic"); + INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenSinOp, "math.sin", + "linalg.generic"); + INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN(AtenCosOp, "math.cos", + "linalg.generic"); +#undef INSERT_ATEN_TO_TOSA_CUSTOMOP_PATTERN + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 2bb6045d950d..afc041263174 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -429,6 +429,13 @@ std::optional convertReduceOpCommon( auto input_rank = input_shape.size(); Value val = input_value; + if (output_type.getElementType() != input_type.getElementType()) { + reduce_element_type = output_type.getElementType(); + val = rewriter.createOrFold(op->getLoc(), RankedTensorType::get( + input_shape, + reduce_element_type), val); + } + if (axes_elems.getNumElements() == 0) { // No axes means return the original tensor. auto identity_op = CreateOpAndInfer( diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1d026a62d414..b71378fa5ad4 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -133,6 +133,18 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, } } +Value buildSlice(PatternRewriter &rewriter, Value &input, + llvm::ArrayRef start, llvm::ArrayRef size) { + assert(start.size() == size.size() && + "Start and Size must have the same size"); + return tosa::CreateOpAndInfer( + rewriter, input.getLoc(), + RankedTensorType::get( + llvm::SmallVector(size.size(), ShapedType::kDynamic), + input.getType().cast().getElementType()), + input, start, size); +} + // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type) { return (output_element_type.getStorageTypeIntegralWidth() == 8); @@ -174,23 +186,32 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // Default template creates a constant tensor in T. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape) { + ArrayRef vec, ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; } - if (vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } + auto width = sizeof(T) * 8; + if constexpr(std::is_same_v) + width = 1; + auto const_type = - RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); + RankedTensorType::get(shape, rewriter.getIntegerType(width)); auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } @@ -198,13 +219,13 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape) { + ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; } - if (vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } @@ -215,6 +236,11 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } @@ -222,13 +248,13 @@ std::optional getConstTensor(PatternRewriter &rewriter, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape) { + ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; } - if (vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } @@ -238,33 +264,53 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } -static LogicalResult checkValidityOfCast(Type src, Type dest) { - if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || - (src.isInteger(64) && dest.isInteger(8)) || - (src.isInteger(64) && dest.isInteger(1)) || - (src.isInteger(64) && dest.isF32()) || - (src.isInteger(32) && dest.isInteger(64)) || - (src.isInteger(32) && dest.isInteger(1)) || - (src.isInteger(32) && dest.isF32()) || - (src.isInteger(32) && dest.isBF16()) || - (src.isInteger(16) && dest.isBF16()) || - (src.isInteger(8) && dest.isInteger(1)) || - (src.isInteger(8) && dest.isBF16()) || - (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) || - (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) || - (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) || - (src.isF32() && dest.isInteger(64)) || - (src.isF32() && dest.isInteger(1)) || - (src.isBF16() && dest.isInteger(8)) || - (src.isBF16() && dest.isInteger(16)) || - (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) { - return success(); +// Template specialization for double +template <> +std::optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape, std::optional dtype) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return std::nullopt; + } + + auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } - return failure(); + return const_op.getResult(); +} + +static LogicalResult checkValidityOfCast(Type src, Type dest) { + if (src == dest) + return success(); + + auto isValid = [](Type ty) { + return ty.isInteger(1) || ty.isInteger(8) || ty.isInteger(16) || + ty.isInteger(32) || ty.isInteger(64) || ty.isBF16() || ty.isF16() || ty.isF32() || + ty.isF64(); + }; + + return success(isValid(src) && isValid(dest)); } // Template specialization for float @@ -294,14 +340,31 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, SmallVector values(num_total_elements, 0); constOp = tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isInteger(8)) { + SmallVector values(num_total_elements, 0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isInteger(16)) { + SmallVector values(num_total_elements, 0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isBF16()) { + SmallVector values(num_total_elements, 0.0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape, srcElemTy) + .value(); } else if (srcElemTy.isF32()) { SmallVector values(num_total_elements, 0.0); constOp = tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isInteger(8)) { - SmallVector values(num_total_elements, 0); + } else if (srcElemTy.isF64()) { + SmallVector values(num_total_elements, 0.0); constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else { + op->dump(); + op->emitError("Unsupported conversion to i1"); + return failure(); } Value equalToZero = rewriter.create(op->getLoc(), destType, src, constOp.value()); @@ -325,16 +388,56 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { return input; } +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape) { + + auto tensorTy = dyn_cast(val.getType()); + assert(tensorTy); + + auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType()); + return rewriter.create( + loc, newTy, val, rewriter.getDenseI64ArrayAttr(newShape)); +} + +TypedValue transposeBy(Location loc, PatternRewriter &rewriter, + Value val, + ArrayRef permutation) { + auto tensorTy = dyn_cast(val.getType()); + assert(tensorTy); + + auto permType = RankedTensorType::get({(int64_t)permutation.size()}, + rewriter.getI32Type()); + auto permAttr = DenseElementsAttr::get(permType, permutation); + auto permOp = rewriter.create(loc, permType, permAttr); + + SmallVector newShape{tensorTy.getShape()}; + for (size_t i = 0; i < newShape.size(); i++) + newShape[i] = tensorTy.getShape()[permutation[i]]; + + auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType()); + + auto v = rewriter.createOrFold(loc, newTy, val, permOp); + return cast>(v); +} + // Template instantiation +template std::optional getConstTensor(PatternRewriter &, + Operation *, + ArrayRef vec, + ArrayRef shape, + std::optional dtype); + template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, - ArrayRef shape); + ArrayRef shape, + std::optional dtype); template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, - ArrayRef shape); + ArrayRef shape, + std::optional dtype); LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType) { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 96d412f03b3c..9b293e0d1eee 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2282,8 +2282,17 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); + int64_t start, end, step; + if (matchPattern(getStart(), m_TorchConstantInt(&start)) && + matchPattern(getEnd(), m_TorchConstantInt(&end)) && + matchPattern(getStep(), m_TorchConstantInt(&step)) + && step == 1 + && start == 0 + && end == std::numeric_limits::max()) + return getOperand(0); + + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || @@ -2433,6 +2442,68 @@ void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +void AtenBroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenBroadcastToOp op, PatternRewriter &rewriter) { + auto selfTy = dyn_cast(op.getSelf().getType()); + + if (!selfTy || !selfTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure(op, + "only applies when selfTy is known"); + } + + SmallVector resultShape; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) { + return rewriter.notifyMatchFailure( + op, "size must consist of Scalar constants"); + } + + SmallVector selfShape{selfTy.getSizes()}; + if (resultShape.size() == selfShape.size()) { + return rewriter.notifyMatchFailure(op, "nothing to do"); + } + + if (resultShape.size() <= selfShape.size()) { + return rewriter.notifyMatchFailure( + op, "unexpected result rank smaller than self rank"); + } + + size_t extraDims = resultShape.size() - selfShape.size(); + for (unsigned i = 0; i < extraDims; i++) { + if (resultShape[i] != 1) { + return rewriter.notifyMatchFailure( + op, "unimplemented: broadcasts that increases rank must add " + "dimensions with size 1."); + } + } + + // Create 1, ..., 1, inputShape[0], inputShape[1], inputShape[2] + SmallVector reshapeShape = resultShape; + for (unsigned i = 0; i < selfShape.size(); i++) + reshapeShape[extraDims + i] = selfShape[i]; + + SmallVector sizes; + for (unsigned i = 0; i < reshapeShape.size(); i++) { + sizes.push_back(rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(reshapeShape[i]))); + } + + auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); + + Value dims = + rewriter.create(op->getLoc(), listType, sizes); + + auto input = rewriter.create( + op->getLoc(), + selfTy.getWithSizesAndDtype(reshapeShape, selfTy.getDtype()), + op.getSelf(), dims); + + rewriter.replaceOpWithNewOp(op, op.getType(), input, + op.getSize()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index dbc2bc617067..c90e9d152dfd 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6174,6 +6174,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardtanh\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6401,6 +6409,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pow.Scalar\"(%arg0: !torch.float, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6560,6 +6572,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %0 = torch.derefine %arg2 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6669,6 +6687,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" %4 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %4 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list\n" +" return %3 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.roll\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7865,6 +7898,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.acos\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -7981,6 +8024,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %1#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" @@ -8381,6 +8437,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9368,6 +9428,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.union, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b0dce438e074..abf866847aa2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3105,6 +3105,11 @@ class DecomposeAtenCopyOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } + auto srcTy = op.getSrc().getType().cast(); + if (!srcTy.hasSizes() || !srcTy.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "expected src type to have a known rank"); + } Type resultDtype = resultType.getDtype(); Value srcToDtype = convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype); @@ -4487,6 +4492,71 @@ class DecomposeAtenScatterValueOp }; } // namespace +namespace { +// Decompose `aten.asin/acos` op into a combination of `mul/sqrt/atan` ops. +template +class DecomposeAtenArcSinCosOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ArcASinCosOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto outType = op.getType().template dyn_cast(); + if (!outType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + // According to CORDIC algorithm: + // asin(x) = atan2 (x, sqrt ((1 + x) * (1 - x))) + // acos(x) = atan2 (sqrt ((1 + x) * (1 - x)), x) + Value self = op.getSelf(); + Value one; + if (outType.hasDtype() && isa(outType.getDtype())) { + one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + } else { + one = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + } + Value onePlusSelf = rewriter.create( + loc, outType, self, one, /*alpha*/ one); + Value minusSelf = rewriter.create(loc, outType, self); + Value oneMinusSelf = rewriter.create( + loc, outType, minusSelf, one, /*alpha*/ one); + + Value mult = rewriter.create(loc, outType, onePlusSelf, + oneMinusSelf); + Value sqrt = rewriter.create(loc, outType, mult); + + Value atan2; + if constexpr (std::is_same()) + atan2 = rewriter.create(loc, outType, self, sqrt); + else + atan2 = rewriter.create(loc, outType, sqrt, self); + + rewriter.replaceOp(op, atan2); + return success(); + } +}; +} // namespace + +namespace { +// Decompose prims.sum into aten.sum +class DecomposePrimsSumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimsSumOp op, + PatternRewriter &rewriter) const override { + Value cstFalse = rewriter.create(op.getLoc(), false); + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse, + op.getOutputDtype()); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.sign` op into comparisons and aten.where. class DecomposeAtenSignOp : public OpRewritePattern { @@ -4699,8 +4769,14 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal>( + patterns); + addPatternIfTargetOpIsIllegal>( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 7ec4594eb6ca..4890c6a8cad9 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -491,4 +491,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, auto opName = opOp->getAttr("name").cast().getValue(); return backendLegalOpsSet.contains(opName); }); + + // TODO: We need this for TOSA; other backends might be fine with this op + // having a dynamic sized output tensor. + target.addDynamicallyLegalOp( + [](AtenRepeatInterleaveTensorOp op) { + return op.getOutputSize().getDefiningOp(); + }); } diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index fe1b7a6a0847..c3e88e1a925d 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -55,6 +55,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern { newEnd = rewriter.create(op.getLoc(), dimSize, sliceOp.getEnd()); } + newEnd = rewriter.create(op.getLoc(), newEnd, dimSize); newStart = rewriter.create(op.getLoc(), newStart, dimSize); newEnd = rewriter.create(op.getLoc(), newEnd, dimSize); @@ -88,6 +89,9 @@ class RecomposeSliceCopy_ : public OpRewritePattern { op, sliceOpInput.getType(), sliceOpInput, indices, op.getSrc(), /*accumulate=*/falseVal, /*unsafe=*/falseVal); + if (sliceOp->use_empty()) + rewriter.eraseOp(sliceOp); + return success(); } }; @@ -124,7 +128,7 @@ class RecomposeSelectFill_ : public OpRewritePattern { // Create indicesVector for IndexPut_Op by TorchNone and indexTensor BaseTensorType tensorType = op->getResultTypes()[0].cast(); - SmallVector indicesVector(dim - 1, noneVal); + SmallVector indicesVector(dim, noneVal); indicesVector.push_back(indexTensor); Value indices = rewriter.create( @@ -203,6 +207,51 @@ class RecomposeUnbindGetItem : public OpRewritePattern { } }; +class RecomposeSplitTensorPrimListUnpackOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListUnpackOp op, + PatternRewriter &rewriter) const override { + + auto torchList = op.getOperand(); + if (isListPotentiallyMutated(torchList)) + return failure(); + + auto split = torchList.getDefiningOp(); + if (!split) + return failure(); + int64_t size = 0; + if (!matchPattern(split.getSplitSize(), m_TorchConstantInt(&size))) + return failure(); + + Value constOne = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(1)); + std::vector results; + int64_t start = 0; + + for (size_t i = 0; i < op->getNumResults(); ++i) { + results.push_back(rewriter.create( + op->getLoc(), + op.getResult(i).getType(), + split.getSelf(), + /*dim=*/split.getDim(), + /*start=*/ + rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(start)), + /*end=*/ + rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(start + size)), + /*step=*/constOne)); + start += size; + } + rewriter.replaceOp(op, results); + if (split->use_empty()) + rewriter.eraseOp(split); + + return success(); + } +}; + class RecomposeSplitTensorGetItemOp : public OpRewritePattern { public: @@ -344,6 +393,51 @@ class RecomposeChunkListUnpack : public OpRewritePattern { return success(); } }; +class RecomposeRepeatInterleave : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRepeatInterleaveTensorOp op, + PatternRewriter &rewriter) const override { + if (!op.getOutputSize().getDefiningOp()) + return failure(); + + auto repeatsTy = dyn_cast(op.getRepeats().getType()); + if (!repeatsTy || !repeatsTy.areAllSizesKnown() || repeatsTy.getSizes().size() != 1) { + return rewriter.notifyMatchFailure( + op, + "Expected 1d tensor with static shape"); + } + auto numElements = repeatsTy.getSizes()[0]; + + auto broadcast = op.getRepeats().getDefiningOp(); + if (!broadcast){ + return rewriter.notifyMatchFailure( + op, + "Expected broadcast op defining repeat_interleave input"); + } + + auto fill = broadcast.getSelf().getDefiningOp(); + if (!fill){ + return rewriter.notifyMatchFailure( + op, + "Expected fill op defining broadcast/repeat_interleave input"); + } + + int64_t fillValue; + if (!matchPattern(fill.getValue(), + m_TorchConstantInt(&fillValue))) { + return rewriter.notifyMatchFailure( + op, + "Expected fill value of fill.Scalar to be an integer constant"); + } + + auto outputSize = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(fillValue * numElements)); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getRepeats(), outputSize); + return success(); + } +}; + } // namespace namespace { @@ -361,7 +455,9 @@ class RecomposeComplexOpsPass patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); + patterns.add(context); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index b524166654bc..f4aafe773923 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -42,6 +42,43 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { return true; } +Value Torch::toTorchList(Location loc, PatternRewriter &rewriter, + ArrayRef vals) { + SmallVector intConsts; + for (int64_t v : vals) { + intConsts.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(v))); + } + + auto listType = + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())); + return rewriter.create(loc, listType, intConsts); +} + +TypedValue Torch::broadcastTo(Location loc, + PatternRewriter &rewriter, + Value val, + ArrayRef newShape) { + + auto ty = dyn_cast(val.getType()); + assert(ty); + auto newTy = ty.getWithSizesAndDtype(newShape, ty.getOptionalDtype()); + return cast>(rewriter.create( + loc, newTy, val, toTorchList(loc, rewriter, newShape)).getResult()); +} + +TypedValue Torch::reshapeTo(Location loc, + PatternRewriter &rewriter, + Value val, + ArrayRef newShape) { + + auto ty = dyn_cast(val.getType()); + assert(ty); + auto newTy = ty.getWithSizesAndDtype(newShape, ty.getOptionalDtype()); + return cast>(rewriter.create(loc, newTy, val, + toTorchList(loc, rewriter, newShape)).getResult()); +} + torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (type.isa()) return torch_upstream::ScalarType::Float; diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 157b5f91289e..20d8b336e8ac 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -51,6 +51,8 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) ADD_TO_PARENT TorchMLIRPythonSources SOURCES __init__.py + repro.py + fx_minifier.py _dynamo_fx_importer.py compiler_utils.py dynamo.py diff --git a/python/test/compile_api/do_test.py b/python/test/compile_api/do_test.py new file mode 100644 index 000000000000..7e5e4e245604 --- /dev/null +++ b/python/test/compile_api/do_test.py @@ -0,0 +1,39 @@ +# RUN: %PYTHON %s + +from dataclasses import dataclass +from typing import Optional +import torch_mlir +import torch + +class Model(torch.nn.Module): + def forward(self, x): + return 2 * x + +class ModelWithTuple(torch.nn.Module): + def forward(self, x): + return (2 * x,) + +class ModelWithNestedTuple(torch.nn.Module): + def forward(self, x): + return (2 * x, [x + x]) + +@dataclass +class ModelOutput(): + loss: Optional[torch.FloatTensor] = None + x: torch.FloatTensor = None + y: torch.FloatTensor = None + +class ModelWithDataclassOutput(torch.nn.Module): + def forward(self, x): + return ModelOutput(x=2 * x, y=x+x) + + +torch_mlir.do(Model(), torch.ones(5), output_type="torch") +torch_mlir.do(ModelWithTuple(), torch.ones(5), output_type="torch") +torch_mlir.do(ModelWithNestedTuple(), torch.ones(5), output_type="torch") +torch_mlir.do(ModelWithDataclassOutput(), torch.ones(5), output_type="torch") + + +torch_mlir.do(Model(), torch.ones(5), output_type="tosa") +torch_mlir.do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16) +torch_mlir.do(Model(), torch.ones(5), output_type="tosa", dtype=torch.bfloat16, output_prefix="out") diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 836d3fdfc1ce..221bf97b7416 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -3,8 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +import dataclasses from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable from enum import Enum +import importlib.metadata import sys from io import StringIO @@ -13,11 +15,20 @@ from torch._functorch.compile_utils import strip_overloads import torch import torch.fx +from torch_mlir.dynamo import _get_decomposition_table +from torch.fx.experimental.proxy_tensor import make_fx from .compiler_utils import run_pipeline_with_repro_report from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library +from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( + TOSA_TO_LINALG_FUNC_PIPELINE, + LinalgOnTensorsTosaBackend, + ) +from ._mlir_libs._mlir.ir import Module +from .repro import reproduce +from .compiler_utils import prepare_model class OutputType(Enum): """The kind of output that `torch_mlir.compile` can produce. @@ -225,8 +236,11 @@ def _get_for_tracing( # they know what they are doing and that their trace is # correct for any specific concrete size. shape = [s if s != -1 else 7 for s in arg.shape] - example_args_for_trace.append( - torch.ones(*shape, dtype=arg.dtype)) + if len(shape) == 0: + example_args_for_trace.append(torch.tensor(1)) + else: + example_args_for_trace.append( + torch.ones(*shape, dtype=arg.dtype)) else: assert isinstance(arg, torch.Tensor) example_args_for_trace.append(arg) @@ -313,7 +327,8 @@ def compile(model: torch.nn.Module, ignore_traced_shapes=False, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], - verbose: bool = False): + verbose: bool = False, + use_make_fx: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -367,6 +382,13 @@ def compile(model: torch.nn.Module, else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) + if use_make_fx: + args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"] + model = make_fx( + model, + decomposition_table=_get_decomposition_table())(*args) + + # For FX-based models, automatically strip overloads. if isinstance(model, torch.fx.GraphModule): strip_overloads(model) @@ -442,3 +464,116 @@ def compile(model: torch.nn.Module, ) return _lower_mlir_module(verbose, output_type, mb.module) + +def run_via_iree(module, *model_args): + try: + import iree_torch + except: + print("ERROR: Failed to import iree_torch") + print("pip install iree-compiler iree-runtime") + print("git clone https://github.com/iree-org/iree-torch && pip install iree-torch --no-deps") + sys.exit(1) + + backend = LinalgOnTensorsTosaBackend() + run_pipeline_with_repro_report( + module, + f"builtin.module(func.func({TOSA_TO_LINALG_FUNC_PIPELINE}))", + "Lowering TOSA backend contract to Linalg-on-Tensors backend contract") + + print("Loading inference function into IREE") + iree_vmfb = iree_torch.compile_to_vmfb( + module, "llvm-cpu") + invoker = iree_torch.load_vmfb(iree_vmfb, "llvm-cpu") + + print("Running inference on IREE") + return invoker.forward(*model_args) + +def run_and_compare(module, model_args, golden): + output = run_via_iree(module, *model_args) + if not isinstance(output, tuple): + golden = (golden, ) + output = (output, ) + + assert len(output) == len(golden) + for output_el, golden_el in zip(output, golden): + rel_err = torch.max((output_el - golden_el)/torch.abs(golden_el)) + print("Relative error: ", rel_err) + assert torch.allclose(output_el, golden_el, rtol=1e-2), "Accuracy issue" + return output + +def compile_and_run(model, model_args, output_type, golden = None): + compile_output_type = output_type + if compile_output_type == "check-tosa": + compile_output_type = "tosa" + + if compile_output_type == "run-tosa": + compile_output_type = "tosa" + + module = compile(model,model_args,output_type=compile_output_type, use_make_fx=True) + + if output_type == "run-tosa": + if golden is None: + golden = model(*model_args) + return run_and_compare(module, model_args, golden) + elif output_type == "check-tosa": + # TOSA lacks a bunch of verifiers. + # Our best way to find issues in the TOSA IR is to try to lower to Linalg + backend = LinalgOnTensorsTosaBackend() + backend.compile(module) + + return module + +@torch.no_grad() +def do(model: torch.nn.Module, + *model_args, + output_type: Union[str, "OutputType"] = OutputType.TORCH, + dtype = None, + output_prefix: Optional[str] = None, + verbose: bool = True, + **model_kwargs, + ): + """ + Converts the given model to torch/tosa. + WARNING: This modifies the model in-place! + """ + + if verbose: + try: + version = importlib.metadata.version('torch-mlir') + except importlib.metadata.PackageNotFoundError: + version = "dev" + print(f"Using torch-mlir {version}") + + model, golden = prepare_model(model, *model_args, dtype=dtype, **model_kwargs) + + compile_output_type = output_type + if compile_output_type in ("check-tosa", "run-tosa"): + compile_output_type = "tosa" + + module = compile(model,model_args,output_type=compile_output_type, use_make_fx=True) + if output_type == "run-tosa": + output = run_via_iree(module, *model_args) + if not isinstance(output, tuple): + golden = (golden, ) + output = (output, ) + + assert len(output) == len(golden) + for output_el, golden_el in zip(output, golden): + rel_err = torch.max((output_el - golden_el)/torch.abs(golden_el)) + print("Relative error: ", rel_err) + return output + + if output_prefix is not None: + prefix = f"{output_prefix}.{output_type}" + if dtype is not None: + assert dtype == torch.bfloat16 + prefix += ".bf16" + + if verbose: + print(f"Writing output files with prefix {prefix}") + with open(f"{prefix}.full.mlir", "w+") as f: + f.write(module.operation.get_asm()) + with open(f"{prefix}.mlir", "w+") as f: + f.write(module.operation.get_asm(large_elements_limit=10)) + + return module diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index 296c1caca99e..9ae050581965 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -3,14 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +import dataclasses from io import StringIO import os import sys import tempfile +from torch_mlir.dynamo import _get_decomposition_table from torch_mlir.passmanager import PassManager from torch_mlir.ir import StringAttr - +from torch.fx.experimental.proxy_tensor import make_fx +import torch def get_module_name_for_debug_dump(module): """Gets a name suitable for a debug dump. @@ -75,3 +78,73 @@ def run_pipeline_with_repro_report(module, raise TorchMlirCompilerError(trimmed_message) from None finally: sys.stderr = original_stderr + +def prepare_model(model, *model_args, dtype = None, **model_kwargs): + """ + Converts the given model to an FX graph. + WARNING: This modifies the model in-place! + """ + + assert len(model_kwargs) == 0, "model_kwargs are not supported yet" + + model.eval() + + if dtype is not None: + model.to(dtype) + + # Needed for models like bigbird-roberta-base that adjust their config during + # runtime saying, e.g. + # Attention type 'block_sparse' is not possible ... + # Changing attention type to 'original_full'..." + # Running the model once updates the config. If we trace while it updates + # the config, torch-mlir fails with + # error: unknown: unsupported by backend contract: module initializers + # See https://github.com/llvm/torch-mlir/issues/2165 + golden = model(*model_args, **model_kwargs) + + def flatten(S): + """ + Flattens a tree of list/tuples into a flat list. + Removes list entries that are None. + """ + if len(S) == 0: + return S + if isinstance(S[0], list) or isinstance(S[0], tuple): + return list(flatten(S[0])) + list(flatten(S[1:])) + if S[0] is None: + return list(flatten(S[1:])) + + return list(S[:1]) + list(flatten(S[1:])) + + class Wrapper(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, *args, **kwargs): + ret = self.model(*args, **kwargs) + + # Torch MLIR does not support return types that are dataclasses + # or lists or nested tuples. + # It also does not support tuples where some elements are None. + # Potential pytorch solution: + # ret, treespec = torch.utils._pytree.tree_flatten(ret) + # but unfortunately, pytree doesn't support dataclasses + # and it doesn't traverse base classes to see that transformer + # outputs derive from OrderedDicts. + # TODO: Remember the transformations done here, so we can revert + # them outside of the model to restore the original output type. + # See approach in make_simple_dynamo_backend. + + if dataclasses.is_dataclass(ret): + ret = tuple([ret.__dict__[field.name] for field in dataclasses.fields(ret)]) + + if isinstance(ret, list) or isinstance(ret, tuple): + ret = flatten(ret) + if len(ret) == 1: + return ret[0] + else: + return tuple(ret) + return ret + + return Wrapper(model), golden diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 0a7da2b00baf..bff145a0be62 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -92,6 +92,12 @@ def aten〇sin〡shape(self: List[int]) -> List[int]: def aten〇cos〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇asin〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇acos〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇hardtanh〡shape(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]: return upstream_shape_functions.unary(self) @@ -263,6 +269,9 @@ def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇pow〇Scalar〡shape(self: float, exponent: List[int]) -> List[int]: + return upstream_shape_functions.unary(exponent) + def aten〇pow〇Tensor_Scalar〡shape(self: List[int], exponent: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -372,6 +381,9 @@ def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) +def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype) + def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) @@ -437,6 +449,10 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]: for i in range(tensor_dim): out.append(self[i] * repeats[i + leading_rank]) return out + +def aten〇repeat_interleave〇Tensor〡shape(repeats: List[int], output_size: Optional[int] = None) -> List[int]: + assert output_size is not None + return [output_size] def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]: return upstream_shape_functions.unary(self) @@ -1269,6 +1285,16 @@ def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇acos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1335,6 +1361,15 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def prims〇sum〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> int: + # When invoking prims.sum() with the output_dtype argument, pytorch + # complains that the argument is not known. + # See https://github.com/pytorch/pytorch/issues/102610 + assert output_dtype is None + inp_rank, inp_dtype = inp_rank_dtype + return inp_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -1708,6 +1743,10 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇repeat_interleave〇Tensor〡dtype(repeats_rank_dtype: Tuple[int, int], output_size: Optional[int] = None) -> int: + repeats_rank, repeats_dtype = repeats_rank_dtype + return repeats_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2561,6 +2600,16 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) +@check_dtype_function([ + Invocation(2.0, TensorOfShape(3, 4, dtype=torch.float64)), + Invocation(2.0, TensorOfShape(3, 4, dtype=torch.bfloat16)), + Invocation(2, TensorOfShape(4, dtype=torch.int32))]) +def aten〇pow〇Scalar〡dtype(self: Union[int, float], exponent_rank_dtype: Tuple[int, int]) -> int: + exp_rank, exp_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [exp_rank, None] + dtypes = [exp_dtype, get_dtype_of_scalar(self)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 2bc06de12951..ece9cce0ba9a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -258,6 +258,8 @@ def emit_with_mutating_variants(key, **kwargs): "aten::cos : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", + "aten::asin : (Tensor) -> (Tensor)", + "aten::acos : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::floor : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", @@ -495,7 +497,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") - emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)") + emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_canonicalizer=True) emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") @@ -504,6 +506,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)") emit("aten::repeat : (Tensor, int[]) -> (Tensor)") + emit("aten::repeat_interleave.Tensor : (Tensor, int?) -> (Tensor)") emit("aten::reshape : (Tensor, int[]) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") @@ -711,6 +714,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") + emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)") emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) # ========================================================================== diff --git a/python/torch_mlir/dynamo.py b/python/torch_mlir/dynamo.py index d3d7978bbfee..9ae51e3b7ca4 100644 --- a/python/torch_mlir/dynamo.py +++ b/python/torch_mlir/dynamo.py @@ -64,6 +64,11 @@ def _get_decomposition_table(): aten.sigmoid_backward, aten._native_batch_norm_legit, aten.squeeze, + aten.cumsum, + aten.im2col, + aten.index_select, + aten.linalg_vector_norm, + aten.index_select, ] # TODO: enable test once 2.1.0 is stable if torch_version_for_comparison() >= version.parse("2.1.0.dev"): diff --git a/python/torch_mlir/fx_minifier.py b/python/torch_mlir/fx_minifier.py new file mode 100644 index 000000000000..f6cec8d9a527 --- /dev/null +++ b/python/torch_mlir/fx_minifier.py @@ -0,0 +1,321 @@ +# Patched version of the same file in pytorch +# Remove once https://github.com/pytorch/pytorch/issues/102169 is fixed +# upstream. +import torch.fx as fx +import copy +import torch +import math +import sys +from typing import Callable, List +from functools import wraps, partial +from dataclasses import dataclass +from torch._functorch.compile_utils import get_placeholders, get_outputs + +class ConcreteProp(torch.fx.Interpreter): + def run_node(self, n): + result = super().run_node(n) + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return obj + else: + return obj + + from torch.fx.node import map_aggregate + concrete_value = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta['concrete_value'] = concrete_value + return result + + def propagate(self, *args): + return super().run(*args) + + +# inplace modifies node/inps +def _convert_node_to_placeholder(node, inps): + if node.op == 'output' or node.op == "placeholder": + return + node.op = 'placeholder' + node.args = () + node.kwargs = {} + node.target = node.name + concrete_val = node.meta.get('concrete_value', None) + if isinstance(concrete_val, torch.Tensor): + inps.append(concrete_val) + else: + inps.append(torch.zeros(())) + for tuple_user in list(node.users): + _convert_node_to_placeholder(tuple_user, inps) + +def dump_state(fx_g, inps): + print(f""" +# Working Repro with {len(fx_g.graph.nodes)} nodes +inps = {[(i.shape, i.dtype, i.device.type) for i in inps]} +inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps] +{fx_g.code} +""") + +@dataclass +class ReproState: + graph: fx.Graph + inps: List[torch.Tensor] + +def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable = dump_state): + """ + Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails. + + Does 2 main strategies: + 1. Truncates suffix: Removes some suffix from the graph and sets a new output. + 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, + tries replacing quarter of the graph, etc. + + >>> # xdoctest: +SKIP(failing) + >>> failing_function = fx.symbolic_trace(f) + >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps)) + + note: module_fails returns True if it fails. + """ + failing_graph = fail_f.graph + cur_size = len(failing_graph.nodes) + + num_queries = 0 + + def deepcopy_fx_graph(fx_graph): + return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph + + + def graph_fails(graph, inps): + nonlocal num_queries + graph = copy.deepcopy(graph) + num_queries += 1 + mod = fx.GraphModule(fail_f, graph) + mod.graph.lint() + return module_fails(mod, inps) + + ConcreteProp(fail_f).propagate(*inps) + if not graph_fails(failing_graph, inps): + raise RuntimeError("Input graph did not fail the tester") + print(f"Started off with {cur_size} nodes", file=sys.stderr) + + def _register_strategy(strategy: Callable, name: str): + @wraps(strategy) + def new_func(old_state: ReproState, granularity=1): + print(file=sys.stderr) + print( + f"Strategy: {name} (G: {granularity}) " + f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)", + file=sys.stderr + ) + new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity) + if new_state is not None: + new_nodes = len(new_state.graph.nodes) + old_nodes = len(old_state.graph.nodes) + new_inps = len(new_state.inps) + old_inps = len(old_state.inps) + new_outs = len(get_outputs(new_state.graph)) + old_outs = len(get_outputs(old_state.graph)) + progress_made = False + if new_nodes < old_nodes: + progress_made = True + print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", file=sys.stderr) + if new_inps > old_inps: + progress_made = True + print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs", file=sys.stderr) + if new_outs < old_outs: + progress_made = True + print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs", file=sys.stderr) + + if not progress_made: + raise RuntimeError("Success raised but no progress made?") + + if not graph_fails(new_state.graph, new_state.inps): + print("WARNING: Something went wrong, not applying this minification", file=sys.stderr) + return None + return new_state + else: + print(f"FAIL: {name}", file=sys.stderr) + return None + + return new_func + + def register_strategy(name: str): + return partial(_register_strategy, name=name) + + @register_strategy("Truncate suffix") + def remove_suffix(cur_graph, cur_inps, granularity): + tested = set() + new_graph = fx.Graph() + env = {} + for idx, node in enumerate(cur_graph.nodes): + new_node = new_graph.node_copy(node, lambda x: env[x]) + if node.op not in ['placeholder', 'output']: + # If idx is divisible by (granularity * 2), it would have been checked already. + if idx % granularity == 0 and (idx % (granularity * 2) != 0) and idx not in tested: + output_node = new_graph.output(new_node) + if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(new_graph, cur_inps): + return ReproState(new_graph, cur_inps) + else: + tested.add(idx) + new_graph.erase_node(output_node) + env[node] = new_node + return None + + @register_strategy("Remove outputs") + def remove_outputs(cur_graph, cur_inps, granularity): + granularity = max(1, granularity // 2) + for idx, node in enumerate(cur_graph.nodes): + node.idx = idx + if node.op == 'output': + output = node + break + + if isinstance(output.args[0], fx.Node): + # Only one output, nothing to reduce + return None + + output_args = sorted(output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)) + if len(output_args) == 1: + return None + + for idx in range(0, len(output_args), granularity): + output.args = (output_args[:idx] + output_args[idx + granularity:],) + if len(output.args[0]) == 1: + output.args = (output.args[0][0],) + if graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + def remove_unused_inputs_unchecked(cur_state: ReproState): + cur_graph = cur_state.graph + cur_inps = cur_state.inps + ph_nodes = get_placeholders(cur_graph) + if len(ph_nodes) != len(cur_inps): + return None + assert len(ph_nodes) == len(cur_inps) + + new_inps = [] + for idx in range(len(ph_nodes)): + if len(ph_nodes[idx].users) == 0: + cur_graph.erase_node(ph_nodes[idx]) + else: + new_inps.append(cur_inps[idx]) + if len(new_inps) < len(cur_inps): + return ReproState(cur_graph, new_inps) + return None + + def remove_unused_inputs_checked(cur_state: ReproState): + new_state = remove_unused_inputs_unchecked(cur_state) + if new_state is not None and graph_fails(new_state.graph, new_state.inps): + return new_state + return None + + def _remove_unused_wrapper(cur_graph, cur_inps, granularity): + return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps)) + + remove_unused_inputs = register_strategy("Remove unused inputs")(_remove_unused_wrapper) + + @register_strategy("Eliminate dead code") + def eliminate_dead_code(cur_graph, cur_inps, granularity): + if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + + def _consolidate_placeholders(cur_graph): + new_graph = fx.Graph() + env = {} + for node in cur_graph.nodes: + if node.op == 'placeholder': + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + + for node in cur_graph.nodes: + if node.op != 'placeholder': + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + return new_graph + + @register_strategy("Delta Debugging") + def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity): + num_nodes = len(cur_graph.nodes) + for start_range in range(0, num_nodes, granularity): + is_removing = False + new_graph = deepcopy_fx_graph(cur_graph) + new_inps = cur_inps[:] + end_range = min(num_nodes, start_range + granularity) + for idx in range(start_range, end_range): + new_node = list(new_graph.nodes)[idx] + if new_node.op not in ['placeholder', 'output']: + is_removing = True + _convert_node_to_placeholder(new_node, new_inps) + if not is_removing: + continue + new_graph = _consolidate_placeholders(new_graph) + new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps)) + if new_state is None: + new_state = ReproState(new_graph, new_inps) + if graph_fails(new_state.graph, new_state.inps): + return ReproState(new_state.graph, new_state.inps) + + return None + + failing_state = ReproState(failing_graph, inps) + + def try_granularity(failing_state, granularity, use_non_granular): + print(f"Trying granularity {granularity}", file=sys.stderr) + + strategies = [] + num_nodes = len(failing_state.graph.nodes) + num_outputs = len(get_outputs(failing_state.graph)) + if num_outputs > num_nodes // 2: + strategies += [remove_outputs] + + if use_non_granular: + strategies += [eliminate_dead_code, remove_unused_inputs] + + strategies += [remove_suffix, delta_debugging] + + for strategy in strategies: + new_state = strategy(failing_state, granularity) + if new_state is not None: + return new_state + return None + + while True: + dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps) + granularity = int(2**(math.floor(math.log2(len(failing_state.graph.nodes))))) + new_state = try_granularity(failing_state, granularity, use_non_granular=True) + if new_state is not None: + failing_state = new_state + continue + + granularity //= 2 + has_progress = False + while granularity >= 1: + new_state = try_granularity(failing_state, granularity, use_non_granular=False) + if new_state is not None: + failing_state = new_state + has_progress = True + break + granularity //= 2 + if has_progress: + continue + + new_state = remove_outputs(failing_state, 1) + if new_state is not None: + failing_state = new_state + continue + + break + + if not graph_fails(failing_state.graph, failing_state.inps): + raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing") + + print(f"Made {num_queries} queries", file=sys.stderr) + failing_fx = fx.GraphModule(fail_f, failing_state.graph) + dump_state(failing_fx, failing_state.inps) + return failing_fx, failing_state.inps diff --git a/python/torch_mlir/repro.py b/python/torch_mlir/repro.py new file mode 100644 index 000000000000..398b0e695291 --- /dev/null +++ b/python/torch_mlir/repro.py @@ -0,0 +1,219 @@ +""" +Example: + +class Model(torch.nn.Module): + def forward(self, x): + x = x / 2.0 + x = x + 2 + x = x * 3 + return x, x *5 + +model = Model() +inputs = (torch.ones(5, 4), ) +out = model(*inputs) + +reproduce(model, inputs, output_type="tosa", expected_error="failed to legalize") +""" + + +import contextlib +import io +import re +from typing import List, Optional +import torch +import torch_mlir + +from torch_mlir.dynamo import _get_decomposition_table +from torch.fx.experimental.proxy_tensor import make_fx +import torch.fx as fx + +from .compiler_utils import prepare_model +from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( + LinalgOnTensorsTosaBackend, + ) + +# TODO: Switch to +# from functorch.compile import minifier +# once the bug mentioned at the top of fx_minifier.py is fixed. +from .fx_minifier import minifier + + +class bcolors: + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + +_REs = { + r"RuntimeError:": r"RuntimeError: ", # change so its kept + r"NameError:": r"NameError: ", + r"ImportError:": r"ImportError: ", + r"error: unknown:": r"error:", + r"assert torch.allclose": r"Did not match accuracy", + r'error: ["<>a-zA-Z0-9._/-]+:[0-9]+:[0-9]+: (.*)': r"error: \1", + r".*unsupported by backend contract: tensor with unknown rank": "unsupported by backend contract: tensor with unknown rank", + r"torch.initialize.global_slots.*": r"torch.initialize.global_slots", + r'note: ["<>a-zA-Z0-9._/-]+:[0-9]+:[0-9]+: (.*)': r"note: \1", + r"note: unknown:": r"note:", + r"note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py": "", + r"%(arg)?[0-9]+": "%SSA", + r"\[[0-9]+(,[0-9]+)*\]": r"[dims]", +} + + +def _reduce_error_msg(msg): + lines = [] + for line in msg.splitlines(): + orgline = line + for regex, replacement in _REs.items(): + line = re.sub(regex, replacement, line) + if line != "" and line != orgline: + lines.append(line) + if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""): + return msg + + return ", ".join(lines).strip() + + +def _obtain_errror(fx_g: fx.GraphModule, inputs, output_type: str): + """ + Runs the given module through torch_mlir and returns the error + message produced. + """ + # The minifer introduces functions that return a tuple with a single + # tensor, which is not supported by torch_mlir. + # Wrap the module to unpack those outputs. + # torch.jit.script doesn't support *args and **kwargs as used in + # the wrapper, so we also need to apply make_fx to the wrapped + # model. + # Both of those are implemented by prepare_model(). + # wrapped_g = prepare_model(model, *inputs) + _fix_single_output_tuple(fx_g) + with contextlib.redirect_stderr(io.StringIO()) as stderr: + try: + torch_mlir.compile_and_run(fx_g, inputs, output_type) + return "" + except Exception as e: + return str(e) + stderr.getvalue() + + +def _fix_single_output_tuple(fx_g: fx.GraphModule): + """ + torch_mlir.compile does not support modules that return a tuple of + a single tensor. + Change the module to return the tensor directly. + """ + for idx, node in enumerate(fx_g.graph.nodes): + node.idx = idx + if node.op == "output": + if isinstance(node.args[0], fx.Node): + # Only one output, nothing to reduce + return None + if len(node.args[0]) == 1: + node.args = (node.args[0][0], node.args[1:]) + fx_g.recompile() + + +def _dump_reproducer( + fx_g: fx.GraphModule, inps: List[torch.Tensor], output_type: str, dtype +): + _fix_single_output_tuple(fx_g) + + print("---- SNIP ----") + print("import torch") + print("from torch import tensor, device") # Used inside fx_g.code + print("import torch_mlir") + print("") + + print("class Model(torch.nn.Module):") + print(" ".join(fx_g.code.splitlines(True))) + + print() + print("model = Model()") + args = "" + for inp in inps: + if torch.all(inp == 0): + args += f"torch.zeros({inp.shape}, dtype={inp.dtype}), " + elif torch.all(inp == 1): + args += f"torch.ones({inp.shape}, dtype={inp.dtype}), " + else: + torch.set_printoptions(threshold=100000) + args += f"torch.tensor({str(inp)}, dtype={inp.dtype}), " + if dtype is not None: + print(f"model.to({dtype})") + print(f"inps = ({args})") + print("golden = model(*inps)") + print("# if you want to see the raw IR, you can print(torch_mlir.compile(model, inps, output_type='raw')") + print(f"torch_mlir.compile_and_run(model, inps, output_type='{output_type}', golden=golden)") + print("") + print("---- SNIP ----") + +def _reduce_inputs(inps, are_inputs_good): + for i in range(len(inps)): + new_inps = inps.copy() + new_inps[i] = torch.zeros(inps[i].shape, dtype=inps[i].dtype) + if are_inputs_good(new_inps): + inps = new_inps + return inps + +@torch.no_grad() +def reproduce( + model: torch.nn.Module, + inputs, + output_type="torch", + dtype=None, + expected_error: Optional[str] = None, + verbose=False, +): + """ + Reduces the given model while ensuring that the error message seen by passing + the model through torch_mlir.compile() doesn't change. + + When dtype is provided, calls model.to(dtype) as first step. + + This function tries to automatically determine the essential parts of the + error message. You can also pass it explicitly via the expected_error + parameter. + """ + + model, _ = prepare_model(model, *inputs, dtype=dtype) + fx_g = make_fx( + model, + decomposition_table=_get_decomposition_table())(*inputs) + + error = _obtain_errror(fx_g, inputs, output_type=output_type) + if error == "": + print("ERROR: torch_mlir.compile passes, nothing to reproduce") + return + + print(f"Found error:\n{error}\nEND") + + if expected_error is None: + expected_error = _reduce_error_msg(error) + + print( + f"Looking for error message '{bcolors.WARNING}{expected_error}{bcolors.ENDC}'" + ) + + def module_fails(fx_g, inputs): + error = _obtain_errror(fx_g, inputs, output_type=output_type) + reduced_error = _reduce_error_msg(error) + fails = expected_error in reduced_error + if verbose: + print( + f"Testing graph\n{fx_g.code}\nERROR: {error}\nREDUCED_ERROR: {reduced_error}\nModule fails?: {fails}" + ) + return fails + + + def show_reproducer(fx_g: fx.GraphModule, inps: List[torch.Tensor]): + inps = _reduce_inputs(inps, lambda inputs: module_fails(fx_g, inputs)) + _dump_reproducer(fx_g, inps, output_type, dtype) + + minifier(fx_g, inputs, module_fails, dump_state=show_reproducer) diff --git a/python/torch_mlir_e2e_test/configs/tosa_backend.py b/python/torch_mlir_e2e_test/configs/tosa_backend.py index 8b41cfeda535..89b90567b1d4 100644 --- a/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -23,14 +23,15 @@ class TosaBackendTestConfig(TestConfig): This class handles all the common lowering that torch-mlir does before reaching the linalg-on-tensors abstraction level. """ - def __init__(self, backend: TosaBackend): + def __init__(self, backend: TosaBackend, use_make_fx: bool = False): super().__init__() self.backend = backend + self.use_make_fx = use_make_fx def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torch_mlir.compile( - program, example_args, output_type="tosa") + program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) return self.backend.compile(module) diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index f4c4e5176cd0..23c727405a60 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -167,6 +167,8 @@ def invoke(*args): "expand-strided-metadata", "finalize-memref-to-llvm", "lower-affine", + "convert-bufferization-to-memref", + "finalize-memref-to-llvm", "func.func(convert-arith-to-llvm)", "convert-func-to-llvm", "convert-cf-to-llvm", diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 7add08a3ecee..89946c4858ee 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -14,6 +14,9 @@ "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "ReduceMaxAlongDimUnsignedInt_basic", + "RepeatInterleaveModule_basic", + "Im2ColModule_basic", + "ElementwiseClampIntModule_basic", } # TODO: Delete once torch 2.1.0 is released diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index c0009713913a..564625ff31b9 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -713,6 +713,32 @@ def TensorsConcatNegativeDimStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsConcatPromoteDTypeStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 2, 4], torch.bool, True), + ([2, 1, 4], torch.int32, True), + ([2, 3, 4], torch.int64, True), + ]) + def forward(self, x, y, z): + return torch.cat([x, y, z], dim=-2) + + +@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeStaticModule()) +def TensorsConcatPromoteDTypeStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(), + tu.randint(2, 1, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long()) + + +# ============================================================================== + + class TensorsStackModule(torch.nn.Module): def __init__(self): @@ -1304,6 +1330,28 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastToDifferentRankStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 8], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [1, 2, 8]) + + +@register_test_case(module_factory=lambda: BroadcastToDifferentRankStaticModule()) +def BroadcastToDifferentRankStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 8)) + + +# ============================================================================== + + class BroadcastToSameRankStaticModule(torch.nn.Module): def __init__(self): @@ -1412,6 +1460,70 @@ def RepeatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) # ============================================================================== +class RepeatInterleaveModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4], torch.int, True), + ]) + def forward(self, x): + z = torch.ops.aten.repeat_interleave(x, output_size=10) + y = torch.ops.aten.repeat_interleave(x) + return z, y + + +@register_test_case(module_factory=lambda: RepeatInterleaveModule()) +def RepeatInterleaveModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([3, 1, 2, 4], dtype=torch.int)) + +# ============================================================================== +class RepeatInterleaveFillModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1], torch.int, True), + ]) + def forward(self, x): + x = torch.ops.aten.fill_(x, 2) + x = torch.ops.aten.expand(x, [16]) + return torch.ops.aten.repeat_interleave(x) + + +@register_test_case(module_factory=lambda: RepeatInterleaveFillModule()) +def RepeatInterleaveFillModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([1], dtype=torch.int)) + + +# ============================================================================== + +class RepeatInterleaveStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + x = torch.ones((10), dtype=torch.int).fill_(3) + z = torch.ops.aten.repeat_interleave(x, output_size=30) + return z + + +@register_test_case(module_factory=lambda: RepeatInterleaveStaticModule()) +def RepeatInterleaveStaticModule_basic(module, tu: TestUtils): + module.forward() + +# ============================================================================== class ExpandModule(torch.nn.Module): @@ -2122,6 +2234,27 @@ def forward(self, x, index): def IndexTensorModule3dInput_basic(module, tu: TestUtils): module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3)) +# ============================================================================== + + +class IndexTensorModule3dInputStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 4, 3], torch.float32, True), + ([2, 3], torch.int64, True), + ]) + def forward(self, x, index): + return torch.ops.aten.index(x, (index,)) + + +@register_test_case(module_factory=lambda: IndexTensorModule3dInputStatic()) +def IndexTensorModule3dInputStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3)) # ============================================================================== @@ -3422,6 +3555,25 @@ def forward(self, val): def AtenToDeviceModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) + +# ============================================================================== +class AtenToDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2], torch.bool, True), + ]) + + def forward(self, val): + return torch.ops.aten.to(val, dtype=torch.int32, non_blocking=False) + +@register_test_case(module_factory=lambda: AtenToDtypeModule()) +def AtenToDtypeModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False], dtype=torch.bool)) + # ============================================================================== @@ -4100,3 +4252,23 @@ def forward(self, x): @register_test_case(module_factory=lambda: Add_Module()) def Add_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3)) + +# ============================================================================== + +class Im2Col_Module(torch.nn.Module): + + def __init__(self): + super().__init__() + self.tensor = torch.ones(2, 3) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.im2col(x, [9, 1], [1, 1], [4, 0], [1, 1]); + +@register_test_case(module_factory=lambda: Im2Col_Module()) +def Im2ColModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3,4,5,2)) diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index a4aa1e99bd10..b50a2a1f02cd 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1157,6 +1157,25 @@ def ZeroInt64Module_basic(module, tu: TestUtils): # ============================================================================== +class NewEmptyModuleBool(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) + + +@register_test_case(module_factory=lambda: NewEmptyModuleBool()) +def NewEmptyModuleBool_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3, high=2).to(dtype=torch.bool)) + + class NewEmptyModuleDefaultDtype(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index 006301b9fc79..64116d059cc2 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -10,6 +10,71 @@ # ============================================================================== +class Conv1dNoPaddingModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 768, 768], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ]) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], False, [0], 1) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingModule()) +def Conv1dNoPaddingModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768)) + +# ============================================================================== + +class Conv1dNoPaddingTransposeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 768, 768], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ]) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], True, [0], 1) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingTransposeModule()) +def Conv1dNoPaddingTransposeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768)) + +# ============================================================================== + +class Conv1dNoPaddingGroupModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1,3072,12], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ]) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], False, [0], 4) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingGroupModule()) +def Conv1dNoPaddingGroupModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1,3072,12), tu.rand(768, 768, 1), torch.ones(768)) + +# ============================================================================== class Conv2dNoPaddingModule(torch.nn.Module): @@ -475,6 +540,34 @@ def forward(self, inputVec, weight): def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) +# ============================================================================== + +class Convolution2DGroupsStatic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 32, 4, 4], torch.float32, True), + ([32, 8, 3, 3], torch.float32, True), + ([32], torch.float32, True), + ]) + def forward(self, x, weight, bias): + return torch.ops.aten.convolution(x, + weight, + bias=bias, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=4) + +@register_test_case(module_factory=lambda: Convolution2DGroupsStatic()) +def Convolution2DGroupsStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3), torch.ones(32)) + class ConvolutionModule2DGroups(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index ba3b0b34103a..2a68d4ba5883 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -685,6 +685,31 @@ def ElementwiseClampModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + int_min = torch.clamp(x, min=-3) + int_max = torch.clamp(x, max=3) + both = torch.clamp(x, min=-5, max=5) + return int_min, int_max, both + + +@register_test_case(module_factory=lambda: ElementwiseClampIntModule()) +def ElementwiseClampIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10)) + + +# ============================================================================== + + class ElementwiseClampMinModule(torch.nn.Module): def __init__(self): @@ -1116,6 +1141,96 @@ def ElementwiseLogModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAsinTensorFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinTensorFloatModule()) +def ElementwiseAsinTensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4)) + + +# ============================================================================== + + +class ElementwiseAsinTensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinTensorIntModule()) +def ElementwiseAsinTensorIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, low=1, high=10).type(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAcosTensorFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 4], torch.float32, True), + ]) + def forward(self, a): + return torch.acos(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcosTensorFloatModule()) +def ElementwiseAcosTensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4)) + + +# ============================================================================== + + +class ElementwiseAcosTensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + def forward(self, a): + return torch.acos(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcosTensorIntModule()) +def ElementwiseAcosTensorIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, low=1, high=10).type(torch.int32)) + + +# ============================================================================== + + class ElementwiseLogIntModule(torch.nn.Module): def __init__(self): @@ -1313,6 +1428,23 @@ def ElementwiseSignModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwisePowScalarModule(torch.nn.Module): + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True) + ]) + def forward(self, x): + return torch.ops.aten.pow(0.5, x) + +@register_test_case(module_factory=lambda: ElementwisePowScalarModule()) +def ElementwisePowScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwisePowModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 16fa9a196ece..67b168e7bf5e 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -477,6 +477,26 @@ def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseEqBoolScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ]) + def forward(self, x): + return torch.eq(x, 1) + + +@register_test_case(module_factory=lambda: ElementwiseEqBoolScalarModule()) +def ElementwiseEqBoolScalarModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=1, dtype=torch.bool)) + + +# ============================================================================== + class ElementwiseEqDiffWidthScalarModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/index_select.py b/python/torch_mlir_e2e_test/test_suite/index_select.py index 0fdda62a13a0..e76c85503a4a 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_select.py +++ b/python/torch_mlir_e2e_test/test_suite/index_select.py @@ -11,6 +11,26 @@ # ============================================================================== +class IndexSelectStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.tensor = torch.ones(2, 3) + + @export + @annotate_args([ + None, + ([3, 3], torch.float32, True), + ([1], torch.int, True), + ]) + def forward(self, x, y): + return torch.ops.aten.index_select(x, 0, y) + + +@register_test_case(module_factory=lambda: IndexSelectStaticModule()) +def IndexSelectStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3), torch.tensor([1], dtype=torch.int)) + class IndexSelectSingleIdxModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index dd2112110f6f..1f459affd5ec 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -68,6 +68,25 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): # ============================================================================== +class PrimsSumFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.prims.sum(a, (0, 1)) + + +@register_test_case(module_factory=lambda: PrimsSumFloatModule()) +def PrimsSumFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class ReduceSumDimIntListFloatModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/scatter.py index bc1df18322ac..5e3ea6e8c44f 100644 --- a/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -61,6 +61,59 @@ def forward(self, input, index, value): def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) +class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4], torch.int64, True), + ([3], torch.int64, True), + ([1, 3], torch.int64, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_(input, (None, index), + value, + accumulate=False, + unsafe=False) + + +@register_test_case( + module_factory=lambda: IndexPutImpl2DNoneIndexStaticModule()) +def IndexPutImpl2DNoneIndexStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3), tu.randint(1, 3, high=1)) + + +# ============================================================================== + +class IndexPutImpl2DNoneIndexBroadcastStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4], torch.int64, True), + ([3], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_(input, (None, index), + value, + accumulate=False, + unsafe=False) + + +@register_test_case( + module_factory=lambda: IndexPutImpl2DNoneIndexBroadcastStaticModule()) +def IndexPutImpl2DNoneIndexBroadcastStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3), torch.tensor(0)) + +# ============================================================================== + class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index a188aa3c52fe..25f3bca7a306 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -73,6 +73,28 @@ def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceOutOfUpperBoundIndexStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 4, 7], torch.float32, True), + ]) + def forward(self, x): + # TODO: remove hacky cat tensor once refbackend supports 0 size dim + result = x[:8, :5, 8:] + cat_tensor = torch.ones((6,4,1), dtype=torch.float32) + return torch.cat((result,cat_tensor), dim=2) + + +@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexStaticModule()) +def SliceOutOfUpperBoundIndexStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module): def __init__(self): super().__init__() @@ -111,6 +133,25 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceOutOfLowerBoundStartIndexStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 4, 7], torch.float32, True), + ]) + def forward(self, x): + return x[-8:3:1, :, :] + + +@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexStaticModule()) +def SliceOutOfLowerBoundStartIndexStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + class SliceEndSleStartModule(torch.nn.Module): def __init__(self): @@ -135,6 +176,30 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceEndSleStartStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 4, 7], torch.float32, True), + ]) + def forward(self, x): + # TODO: remove hacky cat tensor once refbackend supports 0 size dim + result = x[:, 4:3, :] + cat_tensor = torch.ones((6,1,7), dtype=torch.float32) + return torch.cat((result, cat_tensor), dim=1) + + +@register_test_case(module_factory=lambda: SliceEndSleStartStaticModule()) +def SliceEndSleStartStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + + +# ============================================================================== + + class SliceStartEqEndModule(torch.nn.Module): def __init__(self): super().__init__() @@ -157,6 +222,25 @@ def SliceStartEqEndModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceSizeTwoStepDivisibleStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([10, 6, 16], torch.float32, True), + ]) + def forward(self, x): + return x[0:5:2, 0:3:2, 0:4:2] + + +@register_test_case(module_factory=lambda: SliceSizeTwoStepDivisibleStaticModule()) +def SliceSizeTwoStepDivisibleStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10,6,16)) + +# ============================================================================== + class SliceSizeTwoStepModule(torch.nn.Module): def __init__(self): super().__init__() @@ -543,6 +627,28 @@ def forward(self, x, y): def SliceCopyNegative_Module_basic(module, tu: TestUtils): module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4)) +# ============================================================================== + +class SliceCopyMax_Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x, y): + # A slice without specified end uses the max. value of int64_t + xslice = torch.ops.aten.slice(x, 0, 0, 9223372036854775807, 1) + xslice.copy_(y) + return x + + +@register_test_case(module_factory=lambda: SliceCopyMax_Module()) +def SliceCopyMax_Module_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4, 4), tu.rand(4, 4, 4)) # ============================================================================== @@ -655,6 +761,76 @@ def forward(self, x): def UnbindIntGetItem_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) +# ============================================================================== + + +class TensorsSplitTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([6, 10, 12], torch.float32, True) + ]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 2, dim=0) + return s1 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorModule()) +def TensorsSplitTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 10, 12)) + +# ============================================================================== + + +class TensorsSplitTensorLastSmallerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([8, 10, 12], torch.float32, True) + ]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 3, dim=0) + return s2 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorLastSmallerModule()) +def TensorsSplitTensorLastSmallerModule_basic(module, tu: TestUtils): + # Splitting the first dimension with 8 elements into chunks of 3 + # will leave the last result to have 2 elements in that dimension. + module.forward(tu.rand(8, 10, 12)) + +# ============================================================================== + + +class TensorsSplitTensorNegativeDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([10, 12, 6], torch.float32, True) + ]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 2, -1) + return s1 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorNegativeDimModule()) +def TensorsSplitTensorNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 12, 6)) + +# ============================================================================== + # ============================================================================== diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index 6999989a6743..9317a3020624 100644 --- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -25,6 +25,7 @@ # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg, # that depend on TOSA as well as TOSA-to-Standard. "tosa-to-arith", + "tosa-to-scf", # Named ops must be legalized prior to general tosa-to-linalg "tosa-to-linalg-named", # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them diff --git a/requirements.txt b/requirements.txt index f346b53da470..6c86e58ae9c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ --r pytorch-requirements.txt -r build-requirements.txt +-r pytorch-requirements.txt +-r torchvision-requirements.txt -r test-requirements.txt diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir index 8a6ec8d7266a..206084873c81 100644 --- a/test/Conversion/TorchToStablehlo/view_like.mlir +++ b/test/Conversion/TorchToStablehlo/view_like.mlir @@ -7,8 +7,8 @@ // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[INT10:.*]] = torch.constant.int 10 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]] // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -48,8 +48,8 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> + %int10 = torch.constant.int 10 + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int10, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32> } diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 275435599eb3..0d0e95502e3a 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -113,13 +113,28 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // ----- +// CHECK-LABEL: torch.aten.div.Scalar$int_input_fp_output +// CHECK-SAME: %[[VAL_0:.*]]: tensor +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1.280000e+02> : tensor}> : () -> tensor +// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_4]]) <{shift = 0 : i32}> : (tensor, tensor<1x1xf32>) -> tensor +func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> { + %int128 = torch.constant.int 128 + %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + // CHECK-LABEL: torch.aten.pow.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor, tensor<1x1xf32>) -> tensor func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> { - %fp0 = torch.constant.float 3.123400e+00 + %fp0 = torch.constant.float 3.000000e+00 %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 8910caa864d9..b1e9886d369e 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1937,6 +1937,27 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor< return %0 : !torch.vtensor<[4],f32> } +// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice +// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[?],f32> +// CHECK: return %[[ARG0]] : !torch.vtensor<[?],f32> +func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> { + %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32> + return %0 : !torch.vtensor<[?],f32> +} + +// CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step +// CHECK: torch.aten.slice.Tensor +func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> { + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %arg0, %dim, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?], f32> + return %0 : !torch.vtensor<[?],f32> +} + // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %int-1 = torch.constant.int -1 // CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index abaa2860cb85..5fa1a5df5d08 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -79,3 +79,42 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.acos$int_type( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si32> { +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_2]], %[[VAL_2]] : !torch.vtensor<[2,2],si32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],si32> +// CHECK: %[[VAL_4:.*]] = torch.aten.neg %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> !torch.vtensor<[2,2],si32> +// CHECK: %[[VAL_5:.*]] = torch.aten.add.Scalar %[[VAL_4]], %[[VAL_2]], %[[VAL_2]] : !torch.vtensor<[2,2],si32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],si32> +// CHECK: %[[VAL_6:.*]] = torch.aten.mul.Tensor %[[VAL_3]], %[[VAL_5]] : !torch.vtensor<[2,2],si32>, !torch.vtensor<[2,2],si32> -> !torch.vtensor<[2,2],si32> +// CHECK: %[[VAL_7:.*]] = torch.aten.sqrt %[[VAL_6]] : !torch.vtensor<[2,2],si32> -> !torch.vtensor<[2,2],si32> +// CHECK: %[[VAL_8:.*]] = torch.aten.atan2 %[[VAL_7]], %[[VAL_0]] : !torch.vtensor<[2,2],si32>, !torch.vtensor<[2,2],si32> -> !torch.vtensor<[2,2],si32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[2,2],si32> +// CHECK: } + +func.func @torch.aten.acos$int_type(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si32> { + %0 = torch.aten.acos %arg0 : !torch.vtensor<[2, 2],si32> -> !torch.vtensor<[2, 2],si32> + return %0 : !torch.vtensor<[2, 2],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.acos$float_type( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> { +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_2]], %[[VAL_2]] : !torch.vtensor<[2,2],f32>, !torch.float, !torch.float -> !torch.vtensor<[2,2],f32> +// CHECK: %[[VAL_4:.*]] = torch.aten.neg %[[VAL_0]] : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> +// CHECK: %[[VAL_5:.*]] = torch.aten.add.Scalar %[[VAL_4]], %[[VAL_2]], %[[VAL_2]] : !torch.vtensor<[2,2],f32>, !torch.float, !torch.float -> !torch.vtensor<[2,2],f32> +// CHECK: %[[VAL_6:.*]] = torch.aten.mul.Tensor %[[VAL_3]], %[[VAL_5]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> +// CHECK: %[[VAL_7:.*]] = torch.aten.sqrt %[[VAL_6]] : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> +// CHECK: %[[VAL_8:.*]] = torch.aten.atan2 %[[VAL_7]], %[[VAL_0]] : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[2,2],f32> +// CHECK: } +func.func @torch.aten.acos$float_type(%arg0: !torch.vtensor<[2, 2],f32>, %arg1: !torch.vtensor<[2, 2],f32>) -> !torch.vtensor<[2, 2],f32> { + %0 = torch.aten.acos %arg0 : !torch.vtensor<[2, 2],f32> -> !torch.vtensor<[2, 2],f32> + return %0 : !torch.vtensor<[2, 2],f32> +}