@@ -5776,6 +5776,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
57765776//  ConstantMaskOp
57775777// ===----------------------------------------------------------------------===//
57785778
5779+ void  ConstantMaskOp::build (OpBuilder &builder, OperationState &result,
5780+                            VectorType type, ConstantMaskKind kind) {
5781+   assert (kind == ConstantMaskKind::AllTrue ||
5782+          kind == ConstantMaskKind::AllFalse);
5783+   build (builder, result, type,
5784+         kind == ConstantMaskKind::AllTrue
5785+             ? type.getShape ()
5786+             : SmallVector<int64_t >(type.getRank (), 0 ));
5787+ }
5788+ 
57795789LogicalResult ConstantMaskOp::verify () {
57805790  auto  resultType = llvm::cast<VectorType>(getResult ().getType ());
57815791  //  Check the corner case of 0-D vectors first.
@@ -5858,6 +5868,21 @@ LogicalResult CreateMaskOp::verify() {
58585868  return  success ();
58595869}
58605870
5871+ std::optional<int64_t > vector::getConstantVscaleMultiplier (Value value) {
5872+   if  (value.getDefiningOp <vector::VectorScaleOp>())
5873+     return  1 ;
5874+   auto  mul = value.getDefiningOp <arith::MulIOp>();
5875+   if  (!mul)
5876+     return  {};
5877+   auto  lhs = mul.getLhs ();
5878+   auto  rhs = mul.getRhs ();
5879+   if  (lhs.getDefiningOp <vector::VectorScaleOp>())
5880+     return  getConstantIntValue (rhs);
5881+   if  (rhs.getDefiningOp <vector::VectorScaleOp>())
5882+     return  getConstantIntValue (lhs);
5883+   return  {};
5884+ }
5885+ 
58615886namespace  {
58625887
58635888/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5889,73 +5914,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
58895914
58905915  LogicalResult matchAndRewrite (CreateMaskOp createMaskOp,
58915916                                PatternRewriter &rewriter) const  override  {
5892-     VectorType retTy = createMaskOp.getResult ().getType ();
5893-     bool  isScalable = retTy.isScalable ();
5894- 
5895-     //  Check every mask operand
5896-     for  (auto  [opIdx, operand] : llvm::enumerate (createMaskOp.getOperands ())) {
5897-       if  (auto  cst = getConstantIntValue (operand)) {
5898-         //  Most basic case - this operand is a constant value. Note that for
5899-         //  scalable dimensions, CreateMaskOp can be folded only if the
5900-         //  corresponding operand is negative or zero.
5901-         if  (retTy.getScalableDims ()[opIdx] && *cst > 0 )
5902-           return  failure ();
5903- 
5904-         continue ;
5905-       }
5906- 
5907-       //  Non-constant operands are not allowed for non-scalable vectors.
5908-       if  (!isScalable)
5909-         return  failure ();
5910- 
5911-       //  For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5912-       //  true" mask, so can also be treated as constant.
5913-       auto  mul = operand.getDefiningOp <arith::MulIOp>();
5914-       if  (!mul)
5915-         return  failure ();
5916-       auto  mulLHS = mul.getRhs ();
5917-       auto  mulRHS = mul.getLhs ();
5918-       bool  isOneOpVscale =
5919-           (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp ()) ||
5920-            isa<vector::VectorScaleOp>(mulRHS.getDefiningOp ()));
5921- 
5922-       auto  isConstantValMatchingDim =
5923-           [=, dim = retTy.getShape ()[opIdx]](Value operand) {
5924-             auto  constantVal = getConstantIntValue (operand);
5925-             return  (constantVal.has_value () && constantVal.value () == dim);
5926-           };
5927- 
5928-       bool  isOneOpConstantMatchingDim =
5929-           isConstantValMatchingDim (mulLHS) || isConstantValMatchingDim (mulRHS);
5930- 
5931-       if  (!isOneOpVscale || !isOneOpConstantMatchingDim)
5932-         return  failure ();
5917+     VectorType maskType = createMaskOp.getVectorType ();
5918+     ArrayRef<int64_t > maskTypeDimSizes = maskType.getShape ();
5919+     ArrayRef<bool > maskTypeDimScalableFlags = maskType.getScalableDims ();
5920+ 
5921+     //  Special case: Rank zero shape.
5922+     constexpr  std::array<int64_t , 1 > rankZeroShape{1 };
5923+     constexpr  std::array<bool , 1 > rankZeroScalableDims{false };
5924+     if  (maskType.getRank () == 0 ) {
5925+       maskTypeDimSizes = rankZeroShape;
5926+       maskTypeDimScalableFlags = rankZeroScalableDims;
59335927    }
59345928
5935-     //  Gather constant mask dimension sizes.
5936-     SmallVector<int64_t , 4 > maskDimSizes;
5937-     maskDimSizes.reserve (createMaskOp->getNumOperands ());
5938-     for  (auto  [operand, maxDimSize] : llvm::zip_equal (
5939-              createMaskOp.getOperands (), createMaskOp.getType ().getShape ())) {
5940-       std::optional dimSize = getConstantIntValue (operand);
5941-       if  (!dimSize) {
5942-         //  Although not a constant, it is safe to assume that `operand` is
5943-         //  "vscale * maxDimSize".
5944-         maskDimSizes.push_back (maxDimSize);
5945-         continue ;
5946-       }
5947-       int64_t  dimSizeVal = std::min (dimSize.value (), maxDimSize);
5948-       //  If one of dim sizes is zero, set all dims to zero.
5949-       if  (dimSize <= 0 ) {
5950-         maskDimSizes.assign (createMaskOp.getType ().getRank (), 0 );
5951-         break ;
5929+     //  Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
5930+     //  collect the `constantDims` (for the ConstantMaskOp).
5931+     SmallVector<int64_t , 4 > constantDims;
5932+     for  (auto  [i, dimSize] : llvm::enumerate (createMaskOp.getOperands ())) {
5933+       if  (auto  intSize = getConstantIntValue (dimSize)) {
5934+         //  Constant value.
5935+         //  If the mask dim is non-scalable this can be any value.
5936+         //  If the mask dim is scalable only zero (all-false) is supported.
5937+         if  (maskTypeDimScalableFlags[i] && intSize >= 0 )
5938+           return  failure ();
5939+         constantDims.push_back (*intSize);
5940+       } else  if  (auto  vscaleMultiplier = getConstantVscaleMultiplier (dimSize)) {
5941+         //  Constant vscale multiple (e.g. 4 x vscale).
5942+         //  Must be all-true to fold to a ConstantMask.
5943+         if  (vscaleMultiplier < maskTypeDimSizes[i])
5944+           return  failure ();
5945+         constantDims.push_back (*vscaleMultiplier);
5946+       } else  {
5947+         return  failure ();
59525948      }
5953-       maskDimSizes.push_back (dimSizeVal);
59545949    }
59555950
5951+     //  Clamp values to constant_mask bounds.
5952+     for  (auto  [value, maskDimSize] : llvm::zip (constantDims, maskTypeDimSizes))
5953+       value = std::clamp<int64_t >(value, 0 , maskDimSize);
5954+ 
5955+     //  If one of dim sizes is zero, set all dims to zero.
5956+     if  (llvm::is_contained (constantDims, 0 ))
5957+       constantDims.assign (constantDims.size (), 0 );
5958+ 
59565959    //  Replace 'createMaskOp' with ConstantMaskOp.
5957-     rewriter.replaceOpWithNewOp <ConstantMaskOp>(createMaskOp, retTy ,
5958-                                                 maskDimSizes );
5960+     rewriter.replaceOpWithNewOp <ConstantMaskOp>(createMaskOp, maskType ,
5961+                                                 constantDims );
59595962    return  success ();
59605963  }
59615964};
0 commit comments