@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
8888 // Inspect constant mask index. If the index exceeds the
8989 // dimension size, all bits are set. If the index is zero
9090 // or less, no bits are set.
91- ArrayAttr masks = m.getMaskDimSizes ();
91+ ArrayRef< int64_t > masks = m.getMaskDimSizes ();
9292 auto shape = m.getType ().getShape ();
9393 bool allTrue = true ;
9494 bool allFalse = true ;
9595 for (auto [maskIdx, dimSize] : llvm::zip_equal (masks, shape)) {
96- int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt ();
97- if (i < dimSize)
96+ if (maskIdx < dimSize)
9897 allTrue = false ;
99- if (i > 0 )
98+ if (maskIdx > 0 )
10099 allFalse = false ;
101100 }
102101 if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
35933592 if (extractStridedSliceOp.hasNonUnitStrides ())
35943593 return failure ();
35953594 // Gather constant mask dimension sizes.
3596- SmallVector<int64_t , 4 > maskDimSizes;
3597- populateFromInt64AttrArray (constantMaskOp.getMaskDimSizes (), maskDimSizes);
3595+ ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
35983596 // Gather strided slice offsets and sizes.
35993597 SmallVector<int64_t , 4 > sliceOffsets;
36003598 populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
36253623 // region.
36263624 rewriter.replaceOpWithNewOp <ConstantMaskOp>(
36273625 extractStridedSliceOp, extractStridedSliceOp.getResult ().getType (),
3628- vector::getVectorSubscriptAttr (rewriter, sliceMaskDimSizes) );
3626+ sliceMaskDimSizes);
36293627 return success ();
36303628 }
36313629};
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
54105408 }
54115409
54125410 if (constantMaskOp) {
5413- auto maskDimSizes = constantMaskOp.getMaskDimSizes (). getValue () ;
5411+ auto maskDimSizes = constantMaskOp.getMaskDimSizes ();
54145412 auto numMaskOperands = maskDimSizes.size ();
54155413
54165414 // Check every mask dim size to see whether it can be dropped
54175415 for (size_t i = numMaskOperands - 1 ; i >= numMaskOperands - numDimsToDrop;
54185416 --i) {
5419- if (cast<IntegerAttr>( maskDimSizes[i]). getValue () != 1 )
5417+ if (maskDimSizes[i] != 1 )
54205418 return failure ();
54215419 }
54225420
54235421 auto newMaskOperands = maskDimSizes.drop_back (numDimsToDrop);
5424- ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr (newMaskOperands);
5425-
54265422 rewriter.replaceOpWithNewOp <vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5427- newMaskOperandsAttr );
5423+ newMaskOperands );
54285424 return success ();
54295425 }
54305426
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
58045800
58055801 // ConstantMaskOp case.
58065802 auto maskDimSizes = constantMaskOp.getMaskDimSizes ();
5807- SmallVector<Attribute> newMaskDimSizes (maskDimSizes.getValue ());
5808- applyPermutationToVector (newMaskDimSizes, permutation);
5803+ auto newMaskDimSizes = applyPermutation (maskDimSizes, permutation);
58095804
58105805 rewriter.replaceOpWithNewOp <vector::ConstantMaskOp>(
5811- transpOp, transpOp.getResultVectorType (),
5812- ArrayAttr::get (transpOp.getContext (), newMaskDimSizes));
5806+ transpOp, transpOp.getResultVectorType (), newMaskDimSizes);
58135807 return success ();
58145808 }
58155809};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
58325826 if (resultType.getRank () == 0 ) {
58335827 if (getMaskDimSizes ().size () != 1 )
58345828 return emitError (" array attr must have length 1 for 0-D vectors" );
5835- auto dim = llvm::cast<IntegerAttr>( getMaskDimSizes ()[0 ]). getInt () ;
5829+ auto dim = getMaskDimSizes ()[0 ];
58365830 if (dim != 0 && dim != 1 )
58375831 return emitError (" mask dim size must be either 0 or 1 for 0-D vectors" );
58385832 return success ();
@@ -5846,17 +5840,15 @@ LogicalResult ConstantMaskOp::verify() {
58465840 // result dimension size.
58475841 auto resultShape = resultType.getShape ();
58485842 auto resultScalableDims = resultType.getScalableDims ();
5849- SmallVector<int64_t , 4 > maskDimSizes;
5850- for (const auto [index, intAttr] : llvm::enumerate (getMaskDimSizes ())) {
5851- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt ();
5843+ ArrayRef<int64_t > maskDimSizes = getMaskDimSizes ();
5844+ for (const auto [index, maskDimSize] : llvm::enumerate (maskDimSizes)) {
58525845 if (maskDimSize < 0 || maskDimSize > resultShape[index])
58535846 return emitOpError (
58545847 " array attr of size out of bounds of vector result dimension size" );
58555848 if (resultScalableDims[index] && maskDimSize != 0 &&
58565849 maskDimSize != resultShape[index])
58575850 return emitOpError (
58585851 " only supports 'none set' or 'all set' scalable dimensions" );
5859- maskDimSizes.push_back (maskDimSize);
58605852 }
58615853 // Verify that if one mask dim size is zero, they all should be zero (because
58625854 // the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
58735865 // Check the corner case of 0-D vectors first.
58745866 if (resultType.getRank () == 0 ) {
58755867 assert (getMaskDimSizes ().size () == 1 && " invalid sizes for zero rank mask" );
5876- return llvm::cast<IntegerAttr>( getMaskDimSizes ()[0 ]). getInt () == 1 ;
5868+ return getMaskDimSizes ()[0 ] == 1 ;
58775869 }
5878- for (const auto [resultSize, intAttr ] :
5870+ for (const auto [resultSize, maskDimSize ] :
58795871 llvm::zip_equal (resultType.getShape (), getMaskDimSizes ())) {
5880- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt ();
58815872 if (maskDimSize < resultSize)
58825873 return false ;
58835874 }
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
60075998 }
60085999
60096000 // Replace 'createMaskOp' with ConstantMaskOp.
6010- rewriter.replaceOpWithNewOp <ConstantMaskOp>(
6011- createMaskOp, retTy,
6012- vector::getVectorSubscriptAttr (rewriter, maskDimSizes));
6001+ rewriter.replaceOpWithNewOp <ConstantMaskOp>(createMaskOp, retTy,
6002+ maskDimSizes);
60136003 return success ();
60146004 }
60156005};
0 commit comments