@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
17201720 }
17211721};
17221722
1723+ // / A pattern to drop unit dims from vector.transpose.
1724+ // /
1725+ // / Example:
1726+ // /
1727+ // / BEFORE:
1728+ // / ```mlir
1729+ // / %transpose = vector.transpose %vector, [3, 0, 1, 2]
1730+ // / : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1731+ // / ```
1732+ // /
1733+ // / AFTER:
1734+ // / ```mlir
1735+ // / %dropDims = vector.shape_cast %vector
1736+ // / : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1737+ // / %transpose = vector.transpose %0, [1, 0]
1738+ // / : vector<4x[4]xf32> to vector<[4]x4xf32>
1739+ // / %restoreDims = vector.shape_cast %transpose
1740+ // / : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1741+ // / ```
1742+ struct DropUnitDimsFromTransposeOp final
1743+ : OpRewritePattern<vector::TransposeOp> {
1744+ using OpRewritePattern::OpRewritePattern;
1745+
1746+ LogicalResult matchAndRewrite (vector::TransposeOp op,
1747+ PatternRewriter &rewriter) const override {
1748+ VectorType sourceType = op.getSourceVectorType ();
1749+ VectorType sourceTypeWithoutUnitDims =
1750+ dropNonScalableUnitDimFromType (sourceType);
1751+
1752+ if (sourceType == sourceTypeWithoutUnitDims)
1753+ return failure ();
1754+
1755+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1756+ auto sourceDims = llvm::to_vector (vector::getDims (sourceType));
1757+ SmallVector<int64_t > droppedDimsBefore (sourceType.getRank ());
1758+ int64_t droppedDims = 0 ;
1759+ for (auto [i, dim] : llvm::enumerate (sourceDims)) {
1760+ droppedDimsBefore[i] = droppedDims;
1761+ if (dim == std::make_tuple (1 , false ))
1762+ ++droppedDims;
1763+ }
1764+
1765+ // Drop unit dims from transpose permutation.
1766+ ArrayRef<int64_t > perm = op.getPermutation ();
1767+ SmallVector<int64_t > newPerm;
1768+ for (int64_t idx : perm) {
1769+ if (sourceDims[idx] == std::make_tuple (1 , false ))
1770+ continue ;
1771+ newPerm.push_back (idx - droppedDimsBefore[idx]);
1772+ }
1773+
1774+ Location loc = op.getLoc ();
1775+ // Drop the unit dims via shape_cast.
1776+ auto dropDimsShapeCast = rewriter.create <vector::ShapeCastOp>(
1777+ loc, sourceTypeWithoutUnitDims, op.getVector ());
1778+ // Create the new transpose.
1779+ auto tranposeWithoutUnitDims =
1780+ rewriter.create <vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
1781+ // Restore the unit dims via shape cast.
1782+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(
1783+ op, op.getResultVectorType (), tranposeWithoutUnitDims);
1784+
1785+ return failure ();
1786+ }
1787+ };
1788+
17231789// / Pattern to eliminate redundant zero-constants added to reduction operands.
17241790// / It's enough for there to be one initial zero value, so we can eliminate the
17251791// / extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
19241990
19251991void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
19261992 RewritePatternSet &patterns, PatternBenefit benefit) {
1927- patterns.add <DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1928- patterns.getContext (), benefit);
1993+ patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
1994+ ShapeCastOpFolder>( patterns.getContext (), benefit);
19291995}
19301996
19311997void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments