diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 7a752867f596..831212f52803 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -535,6 +535,7 @@
     "IndexSelectNegativeDimModule_basic",
     "IndexSelectStaticModule_basic",
     "IndexTensorStaticModule_basic",
+    "IndexTensorModule3dInputStatic_basic",
     "IndexTensorMultiIndexStaticModule_basic",
     "LayerNormLastDimModule_basic",
     "LayerNormModule_basic",
@@ -986,6 +987,7 @@
     "ReduceAmaxKeepDim_basic",
     "NativeLayerNormModule4D_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
+    "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "ElementwiseLog2Module_basic",
@@ -1054,6 +1056,7 @@
     "BaddbmmWithBetaModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
+    "NumpyTRank0Module_basic",
     "NumpyTRank1Module_basic",
     "NumpyTRank2Module_basic",
     "NumpyTRankNStaticModule_basic",
@@ -1090,6 +1093,7 @@
     "IndexPutImpl1DIntNonAccumulateModule_basic",
     "IndexTensorStaticModule_basic",
     "IndexTensorMultiIndexStaticModule_basic",
+    "IndexTensorModule3dInputStatic_basic",
     "ElementwiseWhereScalarModule_basic",
     "FullLikeModuleFloat3DStatic_basic",
     "FullModuleDefaultDtype_basic",
@@ -1359,6 +1363,7 @@
    "IndexPutImpl3DFloatNonAccumulateModule_basic",
    "IndexPutImplIndexWithNoneModule_basic",
    "IndexTensorModule3dInput_basic",
+    "IndexTensorModule3dInputStatic_basic",
    "IndexTensorModule_basic",
    "IndexTensorStaticModule_basic",
    "IndexTensorMultiIndexStaticModule_basic",
diff --git a/externals/mlir-hlo b/externals/mlir-hlo
index ac26bdba7a5e..a4ac6990f751 160000
--- a/externals/mlir-hlo
+++ b/externals/mlir-hlo
@@ -1 +1 @@
-Subproject commit ac26bdba7a5edfe6060ba5be528b9d20c987297d
+Subproject commit a4ac6990f7519a569a380452d7c1d3764aad7e59
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index c214f3832bb5..cc4f49f7b732 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3993,6 +3993,11 @@ class ConvertAtenIndexTensorOpNone
           op.getLoc(), "unimplemented: index must be ranked tensor");
     }
 
+    if (indices.getType().getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "unimplemented: index must be 1d tensor");
+    }
+
     auto input = adaptor.getSelf();
     auto inputTy = dyn_cast<RankedTensorType>(input.getType());
     if (!inputTy || !inputTy.hasStaticShape())
@@ -5514,7 +5519,159 @@ class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
   std::string implementedWithOpAttr;
 };
 
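+// Splits an aten.index.Tensor op that mixes several index tensors: whenever
+// one index tensor holds exactly one element, the input is sliced along the
+// corresponding dimension, that dimension is reshaped away, and a new
+// aten.index.Tensor op is emitted with the remaining indices.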
"requires indices with static shape"); + } + int64_t numIndices = std::accumulate( + indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1, + [&](int64_t a, int64_t b) { return a * b; }); + if (numIndices != 1) + continue; + + auto inputTy = input.getType(); + SmallVector slicedShape{inputTy.getSizes()}; + slicedShape[i] = 1; + auto slicedType = + inputTy.getWithSizesAndDtype(slicedShape, inputTy.getDtype()); + + auto none = rewriter.create(op->getLoc()); + SmallVector sliceIndices{inputTy.getSizes().size(), none}; + sliceIndices[i] = reshapeTo(loc, rewriter, indices[i], {1}); + Value sliceIndicesV = rewriter.create( + loc, op.getIndices().getType(), sliceIndices); + auto slicedInput = rewriter.create( + loc, slicedType, input, sliceIndicesV); + + SmallVector reshapedShape = slicedShape; + reshapedShape.erase(reshapedShape.begin() + i); + + auto reshaped = reshapeTo(loc, rewriter, slicedInput, reshapedShape); + + SmallVector newIndicesList{indices}; + newIndicesList.erase(newIndicesList.begin() + i); + + Value newIndicesListV = rewriter.create( + loc, op.getIndices().getType(), newIndicesList); + + rewriter.replaceOpWithNewOp(op, op.getType(), reshaped, + newIndicesListV); + return success(); + } + return failure(); + } +}; +class SimplifyAtenIndexTensorWithNdIndex + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AtenIndexTensorOp op, + PatternRewriter &rewriter) const override { + auto outTy = dyn_cast(op.getType()); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + + SmallVector indices; + if (!getListConstructElements(op.getIndices(), indices)) + return failure(); + + TypedValue input = + dyn_cast>(op.getSelf()); + if (!input) { + return rewriter.notifyMatchFailure(op, "requires tensor type"); + } + auto loc = op->getLoc(); + + if (llvm::count_if(indices, [](Value v) { + return !isa(v.getType()); + }) != 1) { + return rewriter.notifyMatchFailure(op, "can only handle single None"); + } + + for (size_t i = 0; i < indices.size(); ++i) { + if (isa(indices[i].getType())) + continue; + + auto indicesTy = dyn_cast(indices[i].getType()); + if (!indicesTy || !indicesTy.areAllSizesKnown()) { + return rewriter.notifyMatchFailure( + op, "requires indices with static shape"); + } + if (indicesTy.getSizes().size() == 1) { + continue; + } + + // flatten indices + int64_t numIndices = std::accumulate( + indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1, + [&](int64_t a, int64_t b) { return a * b; }); + + auto newIndices = + reshapeTo(op.getLoc(), rewriter, indices[i], {numIndices}); + + SmallVector newIndicesList{indices}; + newIndicesList[i] = newIndices; + + Value newIndicesListV = rewriter.create( + loc, op.getIndices().getType(), newIndicesList); + + SmallVector indexOpShape{outTy.getSizes()}; + indexOpShape.erase(indexOpShape.begin() + i, + indexOpShape.begin() + i + indicesTy.getSizes().size()); + indexOpShape.insert(indexOpShape.begin() + i, numIndices); + + auto indexOpType = + outTy.getWithSizesAndDtype(indexOpShape, outTy.getOptionalDtype()); + auto indexed = rewriter.create( + loc, indexOpType, input, newIndicesListV); + + auto reshaped = + reshapeTo(loc, rewriter, indexed, outTy.getSizes()); + rewriter.replaceOp(op, reshaped); + return success(); + } + return failure(); + } +}; } // namespace // ----------------------------------------------------------------------------- @@ -5557,6 +5714,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { 
+class SimplifyAtenIndexTensorWithNdIndex
+    : public OpRewritePattern<AtenIndexTensorOp> {
+public:
+  using OpRewritePattern<AtenIndexTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenIndexTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    auto outTy = dyn_cast<BaseTensorType>(op.getType());
+    if (!outTy) {
+      return rewriter.notifyMatchFailure(op, "requires tensor type");
+    }
+
+    SmallVector<Value> indices;
+    if (!getListConstructElements(op.getIndices(), indices))
+      return failure();
+
+    TypedValue<BaseTensorType> input =
+        dyn_cast<TypedValue<BaseTensorType>>(op.getSelf());
+    if (!input) {
+      return rewriter.notifyMatchFailure(op, "requires tensor type");
+    }
+    auto loc = op->getLoc();
+
+    if (llvm::count_if(indices, [](Value v) {
+          return !isa<Torch::NoneType>(v.getType());
+        }) != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "can only handle a single non-None index");
+    }
+
+    for (size_t i = 0; i < indices.size(); ++i) {
+      if (isa<Torch::NoneType>(indices[i].getType()))
+        continue;
+
+      auto indicesTy = dyn_cast<BaseTensorType>(indices[i].getType());
+      if (!indicesTy || !indicesTy.areAllSizesKnown()) {
+        return rewriter.notifyMatchFailure(
+            op, "requires indices with static shape");
+      }
+      if (indicesTy.getSizes().size() == 1) {
+        continue;
+      }
+
+      // flatten indices
+      int64_t numIndices = std::accumulate(
+          indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1,
+          [&](int64_t a, int64_t b) { return a * b; });
+
+      auto newIndices =
+          reshapeTo(op.getLoc(), rewriter, indices[i], {numIndices});
+
+      SmallVector<Value> newIndicesList{indices};
+      newIndicesList[i] = newIndices;
+
+      Value newIndicesListV = rewriter.create<PrimListConstructOp>(
+          loc, op.getIndices().getType(), newIndicesList);
+
+      SmallVector<int64_t> indexOpShape{outTy.getSizes()};
+      indexOpShape.erase(indexOpShape.begin() + i,
+                         indexOpShape.begin() + i +
+                             indicesTy.getSizes().size());
+      indexOpShape.insert(indexOpShape.begin() + i, numIndices);
+
+      auto indexOpType =
+          outTy.getWithSizesAndDtype(indexOpShape, outTy.getOptionalDtype());
+      auto indexed = rewriter.create<AtenIndexTensorOp>(
+          loc, indexOpType, input, newIndicesListV);
+
+      auto reshaped = reshapeTo(loc, rewriter, indexed, outTy.getSizes());
+      rewriter.replaceOp(op, reshaped);
+      return success();
+    }
+    return failure();
+  }
+};
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -5557,6 +5714,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     patterns.add(context);
     patterns.add(context);
+    patterns.add<SimplifyAtenIndexTensorWithSliceIndex>(context);
+    patterns.add<SimplifyAtenIndexTensorWithNdIndex>(context);
     patterns.add(typeConverter, context);
 
 #define INSERT_SIMPLIFY_OP_PATTERN(AtenOp)                                     \
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 56a8c6746352..b6d22ce83395 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2212,6 +2212,27 @@ def forward(self, x, index):
 def IndexTensorModule3dInput_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
 
+# ==============================================================================
+
+
+class IndexTensorModule3dInputStatic(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 4, 3], torch.float32, True),
+        ([2, 3], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, (index,))
+
+
+@register_test_case(module_factory=lambda: IndexTensorModule3dInputStatic())
+def IndexTensorModule3dInputStatic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
 
 # ==============================================================================
 
@@ -4228,4 +4249,4 @@ def forward(self, x):
 
 @register_test_case(module_factory=lambda: Im2Col_Module())
 def Im2ColModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3,4,5,2))
\ No newline at end of file
+    module.forward(tu.rand(3,4,5,2))