diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d34f6dfcaff4..cc36ceeb953b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5185,6 +5185,30 @@ class DecomposeAtenUnflattenIntOp }; } // namespace +namespace { +template +class DecomposeAtenUpsampleNearestVecOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(UpsampleVecOp op, + PatternRewriter &rewriter) const override { + Value scales = op.getScaleFactors(); + static_assert(std::is_same_v || + std::is_same_v); + Value cstMode = rewriter.create( + op.getLoc(), rewriter.getStringAttr("nearest")); + Value cstNone = rewriter.create(op.getLoc()); + Value cstAntialias = + rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInput(), op.getOutputSize(), + op.getScaleFactors(), cstMode, cstNone, cstNone, cstAntialias); + return success(); + } +}; +} // namespace + // Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { @@ -12983,6 +13007,12 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenUpsampleNearestVecOp>( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenUpsampleNearestVecOp>( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index d746862193aa..cfc8bb96118b 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -593,6 +593,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e4a2e319d7fe..7e048f2b0143 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -497,7 +497,6 @@ "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", "IsInfiniteModule_basic", - "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", @@ -915,8 +914,12 @@ "TraceUnsignedIntModule_empty", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest1dVecNoneScales_basic", + "UpSampleNearest1dVecNoneShape_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dVecNoneScales_basic", + "UpSampleNearest2dVecNoneShape_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", # Error: `aten.as_strided` op is not supported @@ -3956,8 +3959,13 @@ "TransposedConv2dNegativePadding_basic", "TransposedConv3dNegativePadding_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "InterpolateDynamicModule_sizes_nearest", + "UpSampleNearest1dVecNoneScales_basic", + "UpSampleNearest1dVecNoneShape_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dVecNoneScales_basic", + "UpSampleNearest2dVecNoneShape_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 2ec87b9fee43..b9dc855b7c0a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1088,6 +1088,94 @@ def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4)) +class UpSampleNearest2dVecNoneShape(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec( + input, output_size=None, scale_factors=[3.66, 4.2] + ) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneShape()) +def UpSampleNearest2dVecNoneShape_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) + + +class UpSampleNearest2dVecNoneScales(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec( + input, + output_size=[18, 48], + scale_factors=None, + ) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneScales()) +def UpSampleNearest2dVecNoneScales_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) + + +class UpSampleNearest1dVecNoneShape(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest1d.vec( + input, output_size=None, scale_factors=[3.0] + ) + + +@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneShape()) +def UpSampleNearest1dVecNoneShape_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6).to(torch.float64)) + + +class UpSampleNearest1dVecNoneScales(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest1d.vec(input, [18], None) + + +@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneScales()) +def UpSampleNearest1dVecNoneScales_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6).to(torch.float64)) + + class Conv1dModule(torch.nn.Module): def __init__(self): super().__init__()