diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index d2c6ba557b9bb..75e1abead973f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { opToErase.push_back(read.getOperation()); } -/// Returns a copy of `shape` without unit dims. -static SmallVector getReducedShape(ArrayRef shape) { - SmallVector reducedShape; - llvm::copy_if(shape, std::back_inserter(reducedShape), - [](int64_t dimSize) { return dimSize != 1; }); - return reducedShape; -} - /// Converts OpFoldResults to int64_t shape without unit dims. static SmallVector getReducedShape(ArrayRef mixedSizes) { SmallVector reducedShape; @@ -340,7 +332,7 @@ static FailureOr createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, vector::CreateMaskOp op) { auto type = op.getType(); - auto reducedType = trimNonScalableUnitDims(type); + VectorType reducedType = trimNonScalableUnitDims(type); if (reducedType.getRank() == type.getRank()) return failure(); @@ -391,7 +383,7 @@ class TransferReadDropUnitDimsPattern return failure(); // Check if the reduced vector shape matches the reduced source shape. // Otherwise, this case is not supported yet. - auto reducedVectorType = trimNonScalableUnitDims(vectorType); + VectorType reducedVectorType = trimNonScalableUnitDims(vectorType); if (reducedRank != reducedVectorType.getRank()) return failure(); if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { @@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern Value source = transferWriteOp.getSource(); MemRefType sourceType = dyn_cast(source.getType()); // TODO: support tensor type. - if (!sourceType || !sourceType.hasStaticShape()) - return failure(); - if (sourceType.getNumElements() != vectorType.getNumElements()) + if (!sourceType) return failure(); // TODO: generalize this pattern, relax the requirements here. if (transferWriteOp.hasOutOfBoundsDim()) @@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern return failure(); // Check if the reduced vector shape matches the reduced destination shape. // Otherwise, this case is not supported yet. - int vectorReducedRank = getReducedRank(vectorType.getShape()); - if (reducedRank != vectorReducedRank) + VectorType reducedVectorType = trimNonScalableUnitDims(vectorType); + if (reducedRank != reducedVectorType.getRank()) return failure(); if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { return getConstantIntValue(v) != static_cast(0); })) return failure(); + + Value maskOp = transferWriteOp.getMask(); + if (maskOp) { + auto createMaskOp = maskOp.getDefiningOp(); + if (!createMaskOp) + return rewriter.notifyMatchFailure( + transferWriteOp, + "unsupported mask op, only 'vector.create_mask' is " + "currently supported"); + FailureOr rankReducedCreateMask = + createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp); + if (failed(rankReducedCreateMask)) + return failure(); + maskOp = *rankReducedCreateMask; + } Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); Value c0 = rewriter.create(loc, 0); SmallVector zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); - VectorType reducedVectorType = VectorType::get( - getReducedShape(vectorType.getShape()), vectorType.getElementType()); - + SmallVector inBounds(reducedVectorType.getRank(), true); auto shapeCast = rewriter.createOrFold( loc, reducedVectorType, vector); rewriter.replaceOpWithNewOp( - transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap); + transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros, + identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds)); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir index 735915d435653..d65708068862f 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2( // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref // CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref, vector<[1]x3x[16]xi8> +func.func @masked_transfer_write_and_vector_rank_reducing( + %arg : memref<1x1x3x1x16x1xf32>, + %vec : vector<1x3x1x16x1xf32>, + %mask_dim1 : index, + %mask_dim2 : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1> + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask : + vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32> + return +} +// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32> +// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>, +// CHECK-SAME: %[[MASKDIM1:.+]]: index, +// CHECK-SAME: %[[MASKDIM2:.+]]: index +// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32> +// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32> + +func.func @masked_transfer_write_dynamic_rank_reducing( + %arg : memref>, + %vec : vector<[16]x1xi8>, + %mask_dim0 : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %pad = arith.constant 0 : i8 + %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1> + vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} : + vector<[16]x1xi8>, memref> + return +} +// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref, +// CHECK-SAME: %[[MASK_DIM0:.+]]: index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1> +// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref to memref +// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref + /// Only masks operands of vector.create_mask are currently supported. func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1( %arg : memref>,