From 71529afd9ea2c3b90c092ddadf7cae530febdb07 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Thu, 25 Sep 2025 20:34:35 +0000 Subject: [PATCH 1/7] initial approach --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 63 +++++++++++++++++-- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index eb08786f7982..62d47fa64072 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -692,7 +692,6 @@ class ConvertAtenUnflattenIntOp if (outputSizes[i] == Torch::kUnknownSize) numDynamicReassocDims++; } - SmallVector reassocSizes; if (!getListConstructElements(op.getSizes(), reassocSizes) && numDynamicReassocDims > 1) @@ -700,7 +699,7 @@ class ConvertAtenUnflattenIntOp op, "Must be able to either infer expansion dims, or retrieve them " "from list construct"); - auto expandTy = getTypeConverter()->convertType(outputTensorType); + RankedTensorType expandTy = cast(getTypeConverter()->convertType(outputTensorType)); Value expand; // When there are less than two dynamic reassociation dims, this will lower // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape. @@ -717,10 +716,61 @@ class ConvertAtenUnflattenIntOp for (int i = dimInt + numSizes; i < outputRank; ++i) reassociations[i - numSizes + 1].push_back(i); } - expand = rewriter - .create( - loc, expandTy, adaptor.getSelf(), reassociations) - .getResult(); + + // When we have -1 in our sizes, we need to infer the output shape. + // Instead, we calculate this inferred dimensions directly. + SmallVector sizesInts; + bool haveConstSizes = matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizesInts)); + int64_t minusOneIdx = -1; + int64_t knownProduct = 1; + if (haveConstSizes) { + for (int64_t j = 0, e = sizesInts.size(); j < e; ++j) { + if (sizesInts[j] == -1) { + if (minusOneIdx != -1) + minusOneIdx = -2; // more than one -1 -> invalid sizes list + else + minusOneIdx = j; + } else { + knownProduct *= sizesInts[j]; + } + } + } + + bool folded = false; + if (haveConstSizes && minusOneIdx >= 0) { + OpFoldResult numerator; + ArrayRef inShape = inputTensorType.getSizes(); + if (inShape[dimInt] != Torch::kUnknownSize) { + numerator = rewriter.getIndexAttr(inShape[dimInt]); + } else { + SmallVector inputShapeIdx = getTensorSizes(rewriter, loc, adaptor.getSelf()); + numerator = OpFoldResult(inputShapeIdx[dimInt]); + } + + AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext()); + auto map = AffineMap::get(0, 1, s0.floorDiv(knownProduct), rewriter.getContext()); + OpFoldResult inferred = affine::makeComposedFoldedAffineApply(rewriter, loc, map, ArrayRef{numerator}); + + if (auto attr = inferred.dyn_cast()) { + int64_t inferredAttr = cast(attr).getInt(); // index attr + SmallVector inferShape(expandTy.getShape().begin(), expandTy.getShape().end()); + int64_t pos = dimInt + minusOneIdx; + inferShape[pos] = inferredAttr; + + auto inferTy = RankedTensorType::get(inferShape, expandTy.getElementType()); + Value inferExpand = rewriter.create(loc, inferTy, adaptor.getSelf(), reassociations); + + if (inferTy != expandTy) { + expand = rewriter.create(loc, expandTy, inferExpand).getResult(); + } else { + expand = inferExpand; + } + folded = true; + } + } + if (!folded) { + expand = rewriter.create(loc, expandTy, adaptor.getSelf(), reassociations).getResult(); + } } else { reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), reassocSizes); @@ -745,6 +795,7 @@ class ConvertAtenUnflattenIntOp shapeValue) .getResult(); } + rewriter.replaceOp(op, expand); return success(); } From a9c0f41e916adfb3fcf1c0a96bc71094a18157be Mon Sep 17 00:00:00 2001 From: raayandhar Date: Fri, 26 Sep 2025 19:07:43 +0000 Subject: [PATCH 2/7] different approach but still not great --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 104 ++++++++++-------- 1 file changed, 60 insertions(+), 44 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 62d47fa64072..a348a60008c1 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -717,59 +717,75 @@ class ConvertAtenUnflattenIntOp reassociations[i - numSizes + 1].push_back(i); } - // When we have -1 in our sizes, we need to infer the output shape. - // Instead, we calculate this inferred dimensions directly. - SmallVector sizesInts; - bool haveConstSizes = matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizesInts)); + SmallVector reassocSizeValues; + // Is there a function that already does this somewhere? + auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult { + int64_t constantSize; + if (matchPattern(sizeVal, m_TorchConstantInt(&constantSize))) { + return rewriter.getIndexAttr(constantSize); + } + SmallVector singleSizeVec = {sizeVal}; + Value converted = castIntToIndex(rewriter, loc, getTypeConvertedValues(rewriter, loc, getTypeConverter(), singleSizeVec)[0]); + return OpFoldResult(converted); + }; + int64_t minusOneIdx = -1; - int64_t knownProduct = 1; - if (haveConstSizes) { - for (int64_t j = 0, e = sizesInts.size(); j < e; ++j) { - if (sizesInts[j] == -1) { - if (minusOneIdx != -1) - minusOneIdx = -2; // more than one -1 -> invalid sizes list - else - minusOneIdx = j; - } else { - knownProduct *= sizesInts[j]; - } + OpFoldResult knownProduct = rewriter.getIndexAttr(1); + AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext()); + AffineExpr s1 = getAffineSymbolExpr(1, rewriter.getContext()); + auto mulMap = AffineMap::get(0, 2, s0 * s1, rewriter.getContext()); + + for (int64_t j = 0, e = reassocSizeValues.size(); j < e; ++j) { + int64_t constantSize; + // mlir::Value to int comparison... + if (matchPattern(reassocSizeValues[j], m_TorchConstantInt(&constantSize)) && constantSize == -1) { + minusOneIdx = j; + } else { + knownProduct = affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, {knownProduct, sizeToOFR(reassocSizeValues[j])}); } } - bool folded = false; - if (haveConstSizes && minusOneIdx >= 0) { - OpFoldResult numerator; - ArrayRef inShape = inputTensorType.getSizes(); - if (inShape[dimInt] != Torch::kUnknownSize) { - numerator = rewriter.getIndexAttr(inShape[dimInt]); + SmallVector outputShape; + SmallVector inputSizes = getTensorSizes(rewriter, loc, adaptor.getSelf()); + for (int64_t i = 0; i < inputRank; ++i) { + if (i == dimInt) { + OpFoldResult inputDimSize = (inputTensorType.getSizes()[dimInt] != Torch::kUnknownSize) ? + rewriter.getIndexAttr(inputTensorType.getSizes()[dimInt]) : OpFoldResult(inputSizes[dimInt]); + for (int64_t j = 0; j < numSizes; ++j) { + if (j == minusOneIdx) { + auto divMap = AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); + outputShape.push_back(affine::makeComposedFoldedAffineApply(rewriter, loc, divMap, {inputDimSize, knownProduct})); + } else { + outputShape.push_back(sizeToOFR(reassocSizeValues[j])); + } + } } else { - SmallVector inputShapeIdx = getTensorSizes(rewriter, loc, adaptor.getSelf()); - numerator = OpFoldResult(inputShapeIdx[dimInt]); + OpFoldResult inputDimSize = (inputTensorType.getSizes()[i] != Torch::kUnknownSize) ? + rewriter.getIndexAttr(inputTensorType.getSizes()[i]) : OpFoldResult(inputSizes[i]); + outputShape.push_back(inputDimSize); } + } - AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext()); - auto map = AffineMap::get(0, 1, s0.floorDiv(knownProduct), rewriter.getContext()); - OpFoldResult inferred = affine::makeComposedFoldedAffineApply(rewriter, loc, map, ArrayRef{numerator}); - - if (auto attr = inferred.dyn_cast()) { - int64_t inferredAttr = cast(attr).getInt(); // index attr - SmallVector inferShape(expandTy.getShape().begin(), expandTy.getShape().end()); - int64_t pos = dimInt + minusOneIdx; - inferShape[pos] = inferredAttr; - - auto inferTy = RankedTensorType::get(inferShape, expandTy.getElementType()); - Value inferExpand = rewriter.create(loc, inferTy, adaptor.getSelf(), reassociations); - - if (inferTy != expandTy) { - expand = rewriter.create(loc, expandTy, inferExpand).getResult(); - } else { - expand = inferExpand; - } - folded = true; + // Originally I was doing: + // expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, adaptor.getSelf(), reassociations, outputShape).getResult(); + // But with that I was running into: + // error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic + // %4491 = torch.aten.as_strided %4488, %4489, %4490, %int0_462 : !torch.vtensor<[2,4096,5120],f16>, !torch.list, !torch.list, !torch.int -> !torch.vtensor<[2,4096,2560],f16> + // /home/rdhar/expand-shape-bug/iree/iree-model-benchmark/sdxl/int8-model/base_ir/stable_diffusion_xl_base_1_0_scheduled_unet_bs1_64_1024x1024_i8.mlir:13071:13: note: see current operation: %17734 = "tensor.expand_shape"(%17730) <{reassociation = [[0, 1, 2]], static_output_shape = array}> : (tensor<2xi64>) -> tensor + // So there is this really ugly code to handle the types... but it kind of defeats all the code above. + SmallVector resultShape; + for (OpFoldResult ofr : outputShape) { + if (auto attr = ofr.dyn_cast()) { + resultShape.push_back(cast(attr).getInt()); + } else { + resultShape.push_back(ShapedType::kDynamic); } } - if (!folded) { - expand = rewriter.create(loc, expandTy, adaptor.getSelf(), reassociations).getResult(); + auto resultType = RankedTensorType::get(resultShape, expandTy.getElementType()); + expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType, adaptor.getSelf(), reassociations, outputShape).getResult(); + + if (resultType != expandTy) { + expand = rewriter.create(loc, expandTy, expand).getResult(); } } else { reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), From 514c0c92d5e9481baa957dbd1e181d9ca348209a Mon Sep 17 00:00:00 2001 From: raayandhar Date: Fri, 26 Sep 2025 22:41:39 +0000 Subject: [PATCH 3/7] fix unit test assertion error --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 77 +++++++++++++------ 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index a348a60008c1..ea7b6d5476c6 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -12,6 +12,7 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "PopulatePatterns.h" +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -699,7 +700,8 @@ class ConvertAtenUnflattenIntOp op, "Must be able to either infer expansion dims, or retrieve them " "from list construct"); - RankedTensorType expandTy = cast(getTypeConverter()->convertType(outputTensorType)); + RankedTensorType expandTy = cast( + getTypeConverter()->convertType(outputTensorType)); Value expand; // When there are less than two dynamic reassociation dims, this will lower // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape. @@ -717,7 +719,6 @@ class ConvertAtenUnflattenIntOp reassociations[i - numSizes + 1].push_back(i); } - SmallVector reassocSizeValues; // Is there a function that already does this somewhere? auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult { int64_t constantSize; @@ -725,7 +726,10 @@ class ConvertAtenUnflattenIntOp return rewriter.getIndexAttr(constantSize); } SmallVector singleSizeVec = {sizeVal}; - Value converted = castIntToIndex(rewriter, loc, getTypeConvertedValues(rewriter, loc, getTypeConverter(), singleSizeVec)[0]); + Value converted = castIntToIndex( + rewriter, loc, + getTypeConvertedValues(rewriter, loc, getTypeConverter(), + singleSizeVec)[0]); return OpFoldResult(converted); }; @@ -734,45 +738,63 @@ class ConvertAtenUnflattenIntOp AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext()); AffineExpr s1 = getAffineSymbolExpr(1, rewriter.getContext()); auto mulMap = AffineMap::get(0, 2, s0 * s1, rewriter.getContext()); - - for (int64_t j = 0, e = reassocSizeValues.size(); j < e; ++j) { + + for (int64_t j = 0, e = reassocSizes.size(); j < e; ++j) { int64_t constantSize; // mlir::Value to int comparison... - if (matchPattern(reassocSizeValues[j], m_TorchConstantInt(&constantSize)) && constantSize == -1) { + if (matchPattern(reassocSizes[j], m_TorchConstantInt(&constantSize)) && + constantSize == -1) { minusOneIdx = j; } else { - knownProduct = affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, {knownProduct, sizeToOFR(reassocSizeValues[j])}); + knownProduct = affine::makeComposedFoldedAffineApply( + rewriter, loc, mulMap, + {knownProduct, sizeToOFR(reassocSizes[j])}); } } SmallVector outputShape; - SmallVector inputSizes = getTensorSizes(rewriter, loc, adaptor.getSelf()); + SmallVector inputSizes = + getTensorSizes(rewriter, loc, adaptor.getSelf()); for (int64_t i = 0; i < inputRank; ++i) { if (i == dimInt) { - OpFoldResult inputDimSize = (inputTensorType.getSizes()[dimInt] != Torch::kUnknownSize) ? - rewriter.getIndexAttr(inputTensorType.getSizes()[dimInt]) : OpFoldResult(inputSizes[dimInt]); + OpFoldResult inputDimSize = + (inputTensorType.getSizes()[dimInt] != Torch::kUnknownSize) + ? rewriter.getIndexAttr(inputTensorType.getSizes()[dimInt]) + : OpFoldResult(inputSizes[dimInt]); for (int64_t j = 0; j < numSizes; ++j) { if (j == minusOneIdx) { - auto divMap = AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); - outputShape.push_back(affine::makeComposedFoldedAffineApply(rewriter, loc, divMap, {inputDimSize, knownProduct})); + auto divMap = + AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); + outputShape.push_back(affine::makeComposedFoldedAffineApply( + rewriter, loc, divMap, {inputDimSize, knownProduct})); } else { - outputShape.push_back(sizeToOFR(reassocSizeValues[j])); + outputShape.push_back(sizeToOFR(reassocSizes[j])); } } } else { - OpFoldResult inputDimSize = (inputTensorType.getSizes()[i] != Torch::kUnknownSize) ? - rewriter.getIndexAttr(inputTensorType.getSizes()[i]) : OpFoldResult(inputSizes[i]); + OpFoldResult inputDimSize = + (inputTensorType.getSizes()[i] != Torch::kUnknownSize) + ? rewriter.getIndexAttr(inputTensorType.getSizes()[i]) + : OpFoldResult(inputSizes[i]); outputShape.push_back(inputDimSize); } } // Originally I was doing: - // expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, adaptor.getSelf(), reassociations, outputShape).getResult(); - // But with that I was running into: - // error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic - // %4491 = torch.aten.as_strided %4488, %4489, %4490, %int0_462 : !torch.vtensor<[2,4096,5120],f16>, !torch.list, !torch.list, !torch.int -> !torch.vtensor<[2,4096,2560],f16> - // /home/rdhar/expand-shape-bug/iree/iree-model-benchmark/sdxl/int8-model/base_ir/stable_diffusion_xl_base_1_0_scheduled_unet_bs1_64_1024x1024_i8.mlir:13071:13: note: see current operation: %17734 = "tensor.expand_shape"(%17730) <{reassociation = [[0, 1, 2]], static_output_shape = array}> : (tensor<2xi64>) -> tensor - // So there is this really ugly code to handle the types... but it kind of defeats all the code above. + // expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, + // adaptor.getSelf(), reassociations, outputShape).getResult(); But with + // that I was running into: error: 'tensor.expand_shape' op expected + // dimension 0 of collapsed type to be dynamic since one or more of the + // corresponding dimensions in the expanded type is dynamic %4491 = + // torch.aten.as_strided %4488, %4489, %4490, %int0_462 : + // !torch.vtensor<[2,4096,5120],f16>, !torch.list, !torch.list, + // !torch.int -> !torch.vtensor<[2,4096,2560],f16> + // /home/rdhar/expand-shape-bug/iree/iree-model-benchmark/sdxl/int8-model/base_ir/stable_diffusion_xl_base_1_0_scheduled_unet_bs1_64_1024x1024_i8.mlir:13071:13: + // note: see current operation: %17734 = "tensor.expand_shape"(%17730) + // <{reassociation = [[0, 1, 2]], static_output_shape = array}> : (tensor<2xi64>) -> tensor So there is this really + // ugly code to handle the types... but it kind of defeats all the code + // above. SmallVector resultShape; for (OpFoldResult ofr : outputShape) { if (auto attr = ofr.dyn_cast()) { @@ -781,11 +803,16 @@ class ConvertAtenUnflattenIntOp resultShape.push_back(ShapedType::kDynamic); } } - auto resultType = RankedTensorType::get(resultShape, expandTy.getElementType()); - expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType, adaptor.getSelf(), reassociations, outputShape).getResult(); - + auto resultType = + RankedTensorType::get(resultShape, expandTy.getElementType()); + expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType, + adaptor.getSelf(), reassociations, + outputShape) + .getResult(); + if (resultType != expandTy) { - expand = rewriter.create(loc, expandTy, expand).getResult(); + expand = + rewriter.create(loc, expandTy, expand).getResult(); } } else { reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), From 98a90a378e1bd5c79ff88f4bff5e8c14ea580504 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Mon, 29 Sep 2025 22:53:04 +0000 Subject: [PATCH 4/7] address some comments and fix doc --- docs/development.md | 2 +- lib/Conversion/TorchToLinalg/DataMovement.cpp | 50 +++++++++---------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/docs/development.md b/docs/development.md index f1e72966f84d..bdb4854e4d66 100644 --- a/docs/development.md +++ b/docs/development.md @@ -187,7 +187,7 @@ sudo apt install clang ccache lld - **...run Python regression tests**, run: ```shell - cmake --build build --target check-torch-mlir-python + cmake --build build --target check-torch_mlir-python ``` TIP: add multiple target options to stack build phases diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index ea7b6d5476c6..3bb08bbf3cb2 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -672,7 +672,8 @@ class ConvertAtenUnflattenIntOp return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); } - int inputRank = inputTensorType.getSizes().size(); + auto inputTensorSizes = inputTensorType.getSizes(); + int inputRank = inputTensorSizes.size(); auto outputSizes = outputTensorType.getSizes(); int outputRank = outputSizes.size(); @@ -756,27 +757,28 @@ class ConvertAtenUnflattenIntOp SmallVector inputSizes = getTensorSizes(rewriter, loc, adaptor.getSelf()); for (int64_t i = 0; i < inputRank; ++i) { - if (i == dimInt) { + if (i != dimInt) { OpFoldResult inputDimSize = - (inputTensorType.getSizes()[dimInt] != Torch::kUnknownSize) - ? rewriter.getIndexAttr(inputTensorType.getSizes()[dimInt]) - : OpFoldResult(inputSizes[dimInt]); - for (int64_t j = 0; j < numSizes; ++j) { - if (j == minusOneIdx) { - auto divMap = - AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); - outputShape.push_back(affine::makeComposedFoldedAffineApply( - rewriter, loc, divMap, {inputDimSize, knownProduct})); - } else { - outputShape.push_back(sizeToOFR(reassocSizes[j])); - } - } - } else { - OpFoldResult inputDimSize = - (inputTensorType.getSizes()[i] != Torch::kUnknownSize) - ? rewriter.getIndexAttr(inputTensorType.getSizes()[i]) + (inputTensorSizes[i] != Torch::kUnknownSize) + ? rewriter.getIndexAttr(inputTensorSizes[i]) : OpFoldResult(inputSizes[i]); outputShape.push_back(inputDimSize); + continue; + } + + OpFoldResult inputDimSize = + (inputTensorSizes[dimInt] != Torch::kUnknownSize) + ? rewriter.getIndexAttr(inputTensorSizes[dimInt]) + : OpFoldResult(inputSizes[dimInt]); + for (int64_t j = 0; j < numSizes; ++j) { + if (j == minusOneIdx) { + auto divMap = + AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext()); + outputShape.push_back(affine::makeComposedFoldedAffineApply( + rewriter, loc, divMap, {inputDimSize, knownProduct})); + } else { + outputShape.push_back(sizeToOFR(reassocSizes[j])); + } } } @@ -795,14 +797,8 @@ class ConvertAtenUnflattenIntOp // 1>}> : (tensor<2xi64>) -> tensor So there is this really // ugly code to handle the types... but it kind of defeats all the code // above. - SmallVector resultShape; - for (OpFoldResult ofr : outputShape) { - if (auto attr = ofr.dyn_cast()) { - resultShape.push_back(cast(attr).getInt()); - } else { - resultShape.push_back(ShapedType::kDynamic); - } - } + SmallVector resultShape = + decomposeMixedValues(outputShape).first; auto resultType = RankedTensorType::get(resultShape, expandTy.getElementType()); expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType, From 0c19f9961f3cb8d0ed7724dbf5eadb4a6180a187 Mon Sep 17 00:00:00 2001 From: raayandhar Date: Mon, 29 Sep 2025 23:40:10 +0000 Subject: [PATCH 5/7] add e2e and lit tests --- .../test_suite/reshape_like.py | 40 +++++++++++++ test/Conversion/TorchToLinalg/unflatten.mlir | 59 +++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 test/Conversion/TorchToLinalg/unflatten.mlir diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index d1ddc42b39b1..1441eb1890f7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1281,6 +1281,46 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 12, 3)) +class UnflattenIntDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, 12], torch.float32, True), + ] + ) + def forward(self, inputs): + return torch.ops.aten.unflatten(inputs, 1, [3, 4]) + + +@register_test_case(module_factory=lambda: UnflattenIntDynamicModule()) +def UnflattenIntDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 12)) + + +class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, 20], torch.float32, True), + ] + ) + def forward(self, inputs): + return torch.ops.aten.unflatten(inputs, 1, [4, -1]) + + +@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule()) +def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 20)) + + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/unflatten.mlir b/test/Conversion/TorchToLinalg/unflatten.mlir new file mode 100644 index 000000000000..19ade6fd35a7 --- /dev/null +++ b/test/Conversion/TorchToLinalg/unflatten.mlir @@ -0,0 +1,59 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.unflatten.int$static +// CHECK: torch_c.to_builtin_tensor +// CHECK: tensor.expand_shape +// CHECK: torch_c.from_builtin_tensor +func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list -> !torch.vtensor<[2,2,3,4],f32> + return %1 : !torch.vtensor<[2,2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim +// CHECK: torch_c.to_builtin_tensor +// CHECK: tensor.expand_shape +// CHECK: torch_c.from_builtin_tensor +func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> { + %int-2 = torch.constant.int -2 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list -> !torch.vtensor<[2,2,3,4],f32> + return %1 : !torch.vtensor<[2,2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size +// CHECK: torch_c.to_builtin_tensor +// CHECK: tensor.expand_shape +// CHECK: torch_c.from_builtin_tensor +func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list -> !torch.vtensor<[3,2,6],f32> + return %1 : !torch.vtensor<[3,2,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input +// CHECK: torch_c.to_builtin_tensor +// CHECK: tensor.expand_shape +// CHECK: torch_c.from_builtin_tensor +func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,2,3],f32> + return %1 : !torch.vtensor<[?,2,3],f32> +} From 96088b8fc0afd218683446c0f60bf0bde3e5c23d Mon Sep 17 00:00:00 2001 From: raayandhar Date: Mon, 29 Sep 2025 23:49:30 +0000 Subject: [PATCH 6/7] remove comments --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 3bb08bbf3cb2..d91926276f00 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -720,7 +720,6 @@ class ConvertAtenUnflattenIntOp reassociations[i - numSizes + 1].push_back(i); } - // Is there a function that already does this somewhere? auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult { int64_t constantSize; if (matchPattern(sizeVal, m_TorchConstantInt(&constantSize))) { @@ -742,7 +741,6 @@ class ConvertAtenUnflattenIntOp for (int64_t j = 0, e = reassocSizes.size(); j < e; ++j) { int64_t constantSize; - // mlir::Value to int comparison... if (matchPattern(reassocSizes[j], m_TorchConstantInt(&constantSize)) && constantSize == -1) { minusOneIdx = j; @@ -782,21 +780,6 @@ class ConvertAtenUnflattenIntOp } } - // Originally I was doing: - // expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, - // adaptor.getSelf(), reassociations, outputShape).getResult(); But with - // that I was running into: error: 'tensor.expand_shape' op expected - // dimension 0 of collapsed type to be dynamic since one or more of the - // corresponding dimensions in the expanded type is dynamic %4491 = - // torch.aten.as_strided %4488, %4489, %4490, %int0_462 : - // !torch.vtensor<[2,4096,5120],f16>, !torch.list, !torch.list, - // !torch.int -> !torch.vtensor<[2,4096,2560],f16> - // /home/rdhar/expand-shape-bug/iree/iree-model-benchmark/sdxl/int8-model/base_ir/stable_diffusion_xl_base_1_0_scheduled_unet_bs1_64_1024x1024_i8.mlir:13071:13: - // note: see current operation: %17734 = "tensor.expand_shape"(%17730) - // <{reassociation = [[0, 1, 2]], static_output_shape = array}> : (tensor<2xi64>) -> tensor So there is this really - // ugly code to handle the types... but it kind of defeats all the code - // above. SmallVector resultShape = decomposeMixedValues(outputShape).first; auto resultType = From 14e1c199f56651df796b370be3c057797de8b18c Mon Sep 17 00:00:00 2001 From: raayandhar Date: Tue, 30 Sep 2025 18:11:58 +0000 Subject: [PATCH 7/7] add two dynamic dim test in unflatten.mlir --- test/Conversion/TorchToLinalg/unflatten.mlir | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/Conversion/TorchToLinalg/unflatten.mlir b/test/Conversion/TorchToLinalg/unflatten.mlir index 19ade6fd35a7..01049d4fac29 100644 --- a/test/Conversion/TorchToLinalg/unflatten.mlir +++ b/test/Conversion/TorchToLinalg/unflatten.mlir @@ -57,3 +57,18 @@ func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f3 %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,2,3],f32> return %1 : !torch.vtensor<[?,2,3],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims +// CHECK: torch_c.to_builtin_tensor +// CHECK: tensor.from_elements +// CHECK: tensor.reshape +// CHECK: torch_c.from_builtin_tensor +func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> { + %int1 = torch.constant.int 1 + %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int + %0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,?],f32> + return %1 : !torch.vtensor<[?,?,?],f32> +}