@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
400400
401401 // Replace all results with the yielded values.
402402 auto yieldOp = cast<scf::YieldOp>(getBody ()->getTerminator ());
403- rewriter.replaceAllUsesWith (getResults (), yieldOp. getOperands ());
403+ rewriter.replaceAllUsesWith (getResults (), getYieldedValues ());
404404
405405 // Replace block arguments with lower bound (replacement for IV) and
406406 // iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
772772 LogicalResult matchAndRewrite (scf::ForOp forOp,
773773 PatternRewriter &rewriter) const final {
774774 bool canonicalize = false ;
775- Block &block = forOp.getRegion ().front ();
776- auto yieldOp = cast<scf::YieldOp>(block.getTerminator ());
777775
778776 // An internal flat vector of block transfer
779777 // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
780778 // transformed block argument mappings. This plays the role of a
781779 // IRMapping for the particular use case of calling into
782780 // `inlineBlockBefore`.
781+ int64_t numResults = forOp.getNumResults ();
783782 SmallVector<bool , 4 > keepMask;
784- keepMask.reserve (yieldOp. getNumOperands () );
783+ keepMask.reserve (numResults );
785784 SmallVector<Value, 4 > newBlockTransferArgs, newIterArgs, newYieldValues,
786785 newResultValues;
787- newBlockTransferArgs.reserve (1 + forOp. getInitArgs (). size () );
786+ newBlockTransferArgs.reserve (1 + numResults );
788787 newBlockTransferArgs.push_back (Value ()); // iv placeholder with null value
789788 newIterArgs.reserve (forOp.getInitArgs ().size ());
790- newYieldValues.reserve (yieldOp. getNumOperands () );
791- newResultValues.reserve (forOp. getNumResults () );
789+ newYieldValues.reserve (numResults );
790+ newResultValues.reserve (numResults );
792791 for (auto it : llvm::zip (forOp.getInitArgs (), // iter from outside
793792 forOp.getRegionIterArgs (), // iter inside region
794793 forOp.getResults (), // op results
795- yieldOp. getOperands () // iter yield
794+ forOp. getYieldedValues () // iter yield
796795 )) {
797796 // Forwarded is `true` when:
798797 // 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
946945 return failure ();
947946 // If the loop is empty, iterates at least once, and only returns values
948947 // defined outside of the loop, remove it and replace it with yield values.
949- auto yieldOp = cast<scf::YieldOp>(block.getTerminator ());
950- auto yieldOperands = yieldOp.getOperands ();
951- if (llvm::any_of (yieldOperands,
948+ if (llvm::any_of (op.getYieldedValues (),
952949 [&](Value v) { return !op.isDefinedOutsideOfLoop (v); }))
953950 return failure ();
954- rewriter.replaceOp (op, yieldOperands );
951+ rewriter.replaceOp (op, op. getYieldedValues () );
955952 return success ();
956953 }
957954};
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
12241221 return {};
12251222}
12261223
1224+ ValueRange ForOp::getYieldedValues () {
1225+ return cast<scf::YieldOp>(getBody ()->getTerminator ()).getResults ();
1226+ }
1227+
12271228Speculation::Speculatability ForOp::getSpeculatability () {
12281229 // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
12291230 // and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
32053206 return cast<YieldOp>(getAfterBody ()->getTerminator ());
32063207}
32073208
3209+ ValueRange WhileOp::getYieldedValues () { return getYieldOp ().getResults (); }
3210+
32083211Block::BlockArgListType WhileOp::getBeforeArguments () {
32093212 return getBeforeBody ()->getArguments ();
32103213}
0 commit comments