5 changes: 5 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -535,6 +535,7 @@
"IndexSelectNegativeDimModule_basic",
"IndexSelectStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorModule3dInputStatic_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
@@ -986,6 +987,7 @@
"ReduceAmaxKeepDim_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ElementwiseLog2Module_basic",
@@ -1054,6 +1056,7 @@
"BaddbmmWithBetaModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
@@ -1090,6 +1093,7 @@
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorModule3dInputStatic_basic",
"ElementwiseWhereScalarModule_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullModuleDefaultDtype_basic",
@@ -1359,6 +1363,7 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule3dInputStatic_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 1827 files
159 changes: 159 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3993,6 +3993,11 @@ class ConvertAtenIndexTensorOpNone
op.getLoc(), "unimplemented: index must be ranked tensor");
}

if (indices.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op.getLoc(), "unimplemented: index must be 1d tensor");
}

auto input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy || !inputTy.hasStaticShape())
@@ -5514,7 +5519,159 @@ class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
std::string implementedWithOpAttr;
};

class SimplifyAtenIndexTensorWithSliceIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AtenIndexTensorOp op,
PatternRewriter &rewriter) const override {
auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

SmallVector<Value> indices;
if (!getListConstructElements(op.getIndices(), indices))
return failure();

TypedValue<BaseTensorType> input =
dyn_cast<TypedValue<BaseTensorType>>(op.getSelf());
if (!input) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

if (llvm::count_if(indices, [](Value v) {
return !isa<Torch::NoneType>(v.getType());
}) == 1) {
return rewriter.notifyMatchFailure(op, "only one non-None index; nothing to simplify");
}

auto loc = op->getLoc();

for (size_t i = 0; i < indices.size(); ++i) {
if (isa<Torch::NoneType>(indices[i].getType()))
continue;

auto indicesTy = dyn_cast<BaseTensorType>(indices[i].getType());
if (!indicesTy || !indicesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "requires indices with static shape");
}
int64_t numIndices = std::accumulate(
indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1,
[&](int64_t a, int64_t b) { return a * b; });
if (numIndices != 1)
continue;

auto inputTy = input.getType();
SmallVector<int64_t> slicedShape{inputTy.getSizes()};
slicedShape[i] = 1;
auto slicedType =
inputTy.getWithSizesAndDtype(slicedShape, inputTy.getDtype());

auto none = rewriter.create<Torch::ConstantNoneOp>(loc);
SmallVector<Value> sliceIndices(inputTy.getSizes().size(), none);
sliceIndices[i] = reshapeTo(loc, rewriter, indices[i], {1});

Value sliceIndicesV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), sliceIndices);
auto slicedInput = rewriter.create<AtenIndexTensorOp>(
loc, slicedType, input, sliceIndicesV);

SmallVector<int64_t> reshapedShape = slicedShape;
reshapedShape.erase(reshapedShape.begin() + i);

auto reshaped = reshapeTo(loc, rewriter, slicedInput, reshapedShape);

SmallVector<Value> newIndicesList{indices};
newIndicesList.erase(newIndicesList.begin() + i);

Value newIndicesListV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), newIndicesList);

rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(), reshaped,
newIndicesListV);
return success();
}
return failure();
}
};

class SimplifyAtenIndexTensorWithNdIndex
: public OpRewritePattern<AtenIndexTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(AtenIndexTensorOp op,
PatternRewriter &rewriter) const override {
auto outTy = dyn_cast<BaseTensorType>(op.getType());
if (!outTy) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}

SmallVector<Value> indices;
if (!getListConstructElements(op.getIndices(), indices))
return failure();

TypedValue<BaseTensorType> input =
dyn_cast<TypedValue<BaseTensorType>>(op.getSelf());
if (!input) {
return rewriter.notifyMatchFailure(op, "requires tensor type");
}
auto loc = op->getLoc();

if (llvm::count_if(indices, [](Value v) {
return !isa<Torch::NoneType>(v.getType());
}) != 1) {
return rewriter.notifyMatchFailure(op, "requires exactly one non-None index");
}

for (size_t i = 0; i < indices.size(); ++i) {
if (isa<Torch::NoneType>(indices[i].getType()))
continue;

auto indicesTy = dyn_cast<BaseTensorType>(indices[i].getType());
if (!indicesTy || !indicesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "requires indices with static shape");
}
if (indicesTy.getSizes().size() == 1) {
continue;
}

// flatten indices
int64_t numIndices = std::accumulate(
indicesTy.getSizes().begin(), indicesTy.getSizes().end(), 1,
[&](int64_t a, int64_t b) { return a * b; });

auto newIndices = reshapeTo(loc, rewriter, indices[i], {numIndices});

SmallVector<Value> newIndicesList{indices};
newIndicesList[i] = newIndices;

Value newIndicesListV = rewriter.create<PrimListConstructOp>(
loc, op.getIndices().getType(), newIndicesList);

SmallVector<int64_t> indexOpShape{outTy.getSizes()};
indexOpShape.erase(indexOpShape.begin() + i,
indexOpShape.begin() + i + indicesTy.getSizes().size());
indexOpShape.insert(indexOpShape.begin() + i, numIndices);

auto indexOpType =
outTy.getWithSizesAndDtype(indexOpShape, outTy.getOptionalDtype());
auto indexed = rewriter.create<AtenIndexTensorOp>(
loc, indexOpType, input, newIndicesListV);

auto reshaped =
reshapeTo(loc, rewriter, indexed, outTy.getSizes());
rewriter.replaceOp(op, reshaped);
return success();
}
return failure();
}
};
} // namespace

// -----------------------------------------------------------------------------
@@ -5557,6 +5714,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {

patterns.add<SimplifyAten_IndexPutImplOp>(context);
patterns.add<SimplifyAten_IndexPutImplOpNone>(context);
patterns.add<SimplifyAtenIndexTensorWithSliceIndex>(context);
patterns.add<SimplifyAtenIndexTensorWithNdIndex>(context);
patterns.add<ConvertAtenIndexTensorOpNone>(typeConverter, context);

#define INSERT_SIMPLIFY_OP_PATTERN(AtenOp) \
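For intuition, here is a rough eager-mode sketch of what SimplifyAtenIndexTensorWithSliceIndex does when one index tensor holds a single element: that index is peeled off into its own aten.index op, the resulting size-1 dimension is reshaped away, and the remaining indices are applied to the smaller tensor. The shapes here are illustrative only, not taken from the patch (a companion sketch for the Nd-index pattern follows the new test below):

import torch

x = torch.rand(5, 4, 3)
i = torch.tensor([2])        # single-element index for dim 0
j = torch.tensor([0, 2, 1])  # remaining index for dim 1

ref = torch.ops.aten.index(x, (i, j))

# Rewrite: index dim 0 on its own, drop the size-1 dim via reshape,
# then apply the remaining index list to the reshaped tensor.
sliced = torch.ops.aten.index(x, (i,)).reshape(4, 3)
out = torch.ops.aten.index(sliced, (j,))

assert torch.equal(ref, out)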
23 changes: 22 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2212,6 +2212,27 @@ def forward(self, x, index):
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))

# ==============================================================================


class IndexTensorModule3dInputStatic(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([5, 4, 3], torch.float32, True),
([2, 3], torch.int64, True),
])
def forward(self, x, index):
return torch.ops.aten.index(x, (index,))


@register_test_case(module_factory=lambda: IndexTensorModule3dInputStatic())
def IndexTensorModule3dInputStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
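A similar hedged sketch for SimplifyAtenIndexTensorWithNdIndex, using the shapes from this test: the 2-d index is flattened to 1-d (the form ConvertAtenIndexTensorOpNone now accepts after its rank check), the index op runs on the flat index, and the result is reshaped back. This mirrors the rewrite's effect but is not the patch's code:

import torch

x = torch.rand(5, 4, 3)
index = torch.randint(0, 5, (2, 3))

ref = torch.ops.aten.index(x, (index,))              # shape (2, 3, 4, 3)

# Rewrite: flatten the index, run a 1-d index op, reshape the result.
flat = torch.ops.aten.index(x, (index.reshape(6),))  # shape (6, 4, 3)
out = flat.reshape(2, 3, 4, 3)

assert torch.equal(ref, out)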

# ==============================================================================

@@ -4228,4 +4249,4 @@ def forward(self, x):

@register_test_case(module_factory=lambda: Im2Col_Module())
def Im2ColModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3,4,5,2))