@@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
24642464// ShuffleOp
24652465// ===----------------------------------------------------------------------===//
24662466
2467- void ShuffleOp::build (OpBuilder &builder, OperationState &result, Value v1,
2468- Value v2, ArrayRef<int64_t > mask) {
2469- build (builder, result, v1, v2, getVectorSubscriptAttr (builder, mask));
2470- }
2471-
24722467LogicalResult ShuffleOp::verify () {
24732468 VectorType resultType = getResultVectorType ();
24742469 VectorType v1Type = getV1VectorType ();
@@ -2491,19 +2486,18 @@ LogicalResult ShuffleOp::verify() {
24912486 return emitOpError (" dimension mismatch" );
24922487 }
24932488 // Verify mask length.
2494- auto maskAttr = getMask (). getValue ();
2495- int64_t maskLength = maskAttr .size ();
2489+ ArrayRef< int64_t > mask = getMask ();
2490+ int64_t maskLength = mask .size ();
24962491 if (maskLength <= 0 )
24972492 return emitOpError (" invalid mask length" );
24982493 if (maskLength != resultType.getDimSize (0 ))
24992494 return emitOpError (" mask length mismatch" );
25002495 // Verify all indices.
25012496 int64_t indexSize = (v1Type.getRank () == 0 ? 1 : v1Type.getDimSize (0 )) +
25022497 (v2Type.getRank () == 0 ? 1 : v2Type.getDimSize (0 ));
2503- for (const auto &en : llvm::enumerate (maskAttr)) {
2504- auto attr = llvm::dyn_cast<IntegerAttr>(en.value ());
2505- if (!attr || attr.getInt () < 0 || attr.getInt () >= indexSize)
2506- return emitOpError (" mask index #" ) << (en.index () + 1 ) << " out of range" ;
2498+ for (auto [idx, maskPos] : llvm::enumerate (mask)) {
2499+ if (maskPos < 0 || maskPos >= indexSize)
2500+ return emitOpError (" mask index #" ) << (idx + 1 ) << " out of range" ;
25072501 }
25082502 return success ();
25092503}
@@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
25272521 return success ();
25282522}
25292523
2530- static bool isStepIndexArray (ArrayAttr idxArr, uint64_t begin, size_t width) {
2531- uint64_t expected = begin;
2532- return idxArr.size () == width &&
2533- llvm::all_of (idxArr.getAsValueRange <IntegerAttr>(),
2534- [&expected](auto attr) {
2535- return attr.getZExtValue () == expected++;
2536- });
2524+ template <typename T>
2525+ static bool isStepIndexArray (ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2526+ T expected = begin;
2527+ return idxArr.size () == width && llvm::all_of (idxArr, [&expected](T value) {
2528+ return value == expected++;
2529+ });
25372530}
25382531
25392532OpFoldResult vector::ShuffleOp::fold (FoldAdaptor adaptor) {
@@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
25682561 SmallVector<Attribute> results;
25692562 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues <Attribute>();
25702563 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues <Attribute>();
2571- for (const auto &index : this ->getMask ().getAsValueRange <IntegerAttr>()) {
2572- int64_t i = index.getZExtValue ();
2564+ for (int64_t i : this ->getMask ()) {
25732565 if (i >= lhsSize) {
25742566 results.push_back (rhsElements[i - lhsSize]);
25752567 } else {
@@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
25902582 LogicalResult matchAndRewrite (ShuffleOp shuffleOp,
25912583 PatternRewriter &rewriter) const override {
25922584 VectorType v1VectorType = shuffleOp.getV1VectorType ();
2593- ArrayAttr mask = shuffleOp.getMask ();
2585+ ArrayRef< int64_t > mask = shuffleOp.getMask ();
25942586 if (v1VectorType.getRank () > 0 )
25952587 return failure ();
25962588 if (mask.size () != 1 )
25972589 return failure ();
25982590 VectorType resType = VectorType::Builder (v1VectorType).setShape ({1 });
2599- if (llvm::cast<IntegerAttr>( mask[0 ]). getInt () == 0 )
2591+ if (mask[0 ] == 0 )
26002592 rewriter.replaceOpWithNewOp <vector::BroadcastOp>(shuffleOp, resType,
26012593 shuffleOp.getV1 ());
26022594 else
@@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
26512643 op, " ShuffleOp types don't match an interleave" );
26522644 }
26532645
2654- ArrayAttr shuffleMask = op.getMask ();
2646+ ArrayRef< int64_t > shuffleMask = op.getMask ();
26552647 int64_t resultVectorSize = resultType.getNumElements ();
26562648 for (int i = 0 , e = resultVectorSize / 2 ; i < e; ++i) {
2657- int64_t maskValueA = cast<IntegerAttr>( shuffleMask[i * 2 ]). getInt () ;
2658- int64_t maskValueB = cast<IntegerAttr>( shuffleMask[(i * 2 ) + 1 ]). getInt () ;
2649+ int64_t maskValueA = shuffleMask[i * 2 ];
2650+ int64_t maskValueB = shuffleMask[(i * 2 ) + 1 ];
26592651 if (maskValueA != i || maskValueB != (resultVectorSize / 2 ) + i)
26602652 return rewriter.notifyMatchFailure (op,
26612653 " ShuffleOp mask not interleaving" );
0 commit comments