diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index a57add9f73197..17d74d62236b2 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2112,6 +2112,8 @@ def fir_DoLoopOp : region_Op<"do_loop", mlir::Operation::operand_range getIterOperands() { return getOperands().drop_front(getNumControlOperands()); } + mlir::OperandRange getInits() { return getIterOperands(); } + mlir::ValueRange getYieldedValues(); void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); } void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); } @@ -2263,6 +2265,8 @@ def fir_IterWhileOp : region_Op<"iterate_while", mlir::Operation::operand_range getIterOperands() { return getOperands().drop_front(getNumControlOperands()); } + mlir::OperandRange getInits() { return getIterOperands(); } + mlir::ValueRange getYieldedValues(); void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); } void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 962b87acd5a80..f71d0047af5b3 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1972,6 +1972,12 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { return {}; } +mlir::ValueRange fir::IterWhileOp::getYieldedValues() { + auto *term = getRegion().front().getTerminator(); + return getFinalValue() ? term->getOperands().drop_front() + : term->getOperands(); +} + //===----------------------------------------------------------------------===// // LenParamIndexOp //===----------------------------------------------------------------------===// @@ -2267,6 +2273,12 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { return {}; } +mlir::ValueRange fir::DoLoopOp::getYieldedValues() { + auto *term = getRegion().front().getTerminator(); + return getFinalValue() ? term->getOperands().drop_front() + : term->getOperands(); +} + //===----------------------------------------------------------------------===// // DTEntryOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index f90a7b23ec12e..36fdf390a7617 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -121,7 +121,8 @@ def AffineForOp : Affine_Op<"for", ImplicitAffineTerminator, ConditionallySpeculatable, RecursiveMemoryEffects, DeclareOpInterfaceMethods, + "getSingleUpperBound", "getYieldedValues", + "replaceWithAdditionalYields"]>, DeclareOpInterfaceMethods]> { let summary = "for operation"; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index e1a604a88715f..f2ea7dd868a37 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -122,8 +122,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, + "getSingleStep", "getSingleUpperBound", "getYieldedValues", + "promoteIfSingleIteration", "replaceWithAdditionalYields"]>, AllTypesMatch<["lowerBound", "upperBound", "step"]>, ConditionallySpeculatable, DeclareOpInterfaceMethods; Value getInductionVar() { return getBody()->getArgument(0); } + Block::BlockArgListType getRegionIterArgs() { return getBody()->getArguments().drop_front(getNumInductionVars()); } + /// Return the `index`-th region iteration argument. BlockArgument getRegionIterArg(unsigned index) { assert(index < getNumRegionIterArgs() && @@ -1086,6 +1088,11 @@ def WhileOp : SCF_Op<"while", ConditionOp getConditionOp(); YieldOp getYieldOp(); + + /// Return the values that are yielded from the "after" region (by the + /// scf.yield op). + ValueRange getYieldedValues(); + Block::BlockArgListType getBeforeArguments(); Block::BlockArgListType getAfterArguments(); Block *getBeforeBody() { return &getBefore().front(); } diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index 0eebb984e5897..7c7d378d0590a 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -24,6 +24,11 @@ class RewriterBase; /// arguments in `newBbArgs`. using NewYieldValuesFn = std::function( OpBuilder &b, Location loc, ArrayRef newBbArgs)>; + +namespace detail { +/// Verify invariants of the LoopLikeOpInterface. +LogicalResult verifyLoopLikeOpInterface(Operation *op); +} // namespace detail } // namespace mlir /// Include the generated interface declarations. diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index ded0a29292ff6..4d2a66dd3143d 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -20,6 +20,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { Contains helper functions to query properties and perform transformations of a loop. Operations that implement this interface will be considered by loop-invariant code motion. + + Loop-carried variables can be exposed through this interface. There are + 3 components to a loop-carried variable. + - The "region iter_arg" is the block argument of the entry block that + represents the loop-carried variable in each iteration. + - The "init value" is an operand of the loop op that serves as the initial + region iter_arg value for the first iteration (if any). + - The "yielded" value is the value that is forwarded from one iteration to + serve as the region iter_arg of the next iteration. + + If one of the respective interface methods is implemented, so must the other + two. The interface verifier ensures that the number of types of the region + iter_args, init values and yielded values match. }]; let cppNamespace = "::mlir"; @@ -141,6 +154,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { return ::mlir::Block::BlockArgListType(); }] >, + InterfaceMethod<[{ + Return the values that are yielded to the next iteration. + }], + /*retTy=*/"::mlir::ValueRange", + /*methodName=*/"getYieldedValues", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::ValueRange(); + }] + >, InterfaceMethod<[{ Append the specified additional "init" operands: replace this loop with a new loop that has the additional init operands. The loop body of @@ -192,6 +216,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { }); } }]; + + let verifyWithRegions = 1; + + let verify = [{ + return detail::verifyLoopLikeOpInterface($_op); + }]; } #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index c61bd566c7676..44e79294a21a0 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2215,6 +2215,10 @@ unsigned AffineForOp::getNumIterOperands() { return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); } +ValueRange AffineForOp::getYieldedValues() { + return cast(getBody()->getTerminator()).getOperands(); +} + void AffineForOp::print(OpAsmPrinter &p) { p << ' '; p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{}, diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 72bd2b409f5d5..8fef99bb37509 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp()); unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber(); - auto yieldOp = cast(forOp.getBody(0)->getTerminator()); - auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber) + auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber] .getDefiningOp(); if (!yieldingExtractSliceOp) return tensor::ExtractSliceOp(); @@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, SmallVector initArgs = forOp.getInitArgs(); initArgs[iterArgNumber] = hoistedPackedTensor; - SmallVector yieldOperands = yieldOp.getOperands(); + SmallVector yieldOperands = llvm::to_vector(forOp.getYieldedValues()); yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource(); int64_t numOriginalForOpResults = initArgs.size(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 8d8481421e18d..508227d6e7ce4 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { // Replace all results with the yielded values. auto yieldOp = cast(getBody()->getTerminator()); - rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands()); + rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); // Replace block arguments with lower bound (replacement for IV) and // iter_args. @@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern { LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const final { bool canonicalize = false; - Block &block = forOp.getRegion().front(); - auto yieldOp = cast(block.getTerminator()); // An internal flat vector of block transfer // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to // transformed block argument mappings. This plays the role of a // IRMapping for the particular use case of calling into // `inlineBlockBefore`. + int64_t numResults = forOp.getNumResults(); SmallVector keepMask; - keepMask.reserve(yieldOp.getNumOperands()); + keepMask.reserve(numResults); SmallVector newBlockTransferArgs, newIterArgs, newYieldValues, newResultValues; - newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size()); + newBlockTransferArgs.reserve(1 + numResults); newBlockTransferArgs.push_back(Value()); // iv placeholder with null value newIterArgs.reserve(forOp.getInitArgs().size()); - newYieldValues.reserve(yieldOp.getNumOperands()); - newResultValues.reserve(forOp.getNumResults()); + newYieldValues.reserve(numResults); + newResultValues.reserve(numResults); for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside forOp.getRegionIterArgs(), // iter inside region forOp.getResults(), // op results - yieldOp.getOperands() // iter yield + forOp.getYieldedValues() // iter yield )) { // Forwarded is `true` when: // 1) The region `iter` argument is yielded. @@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern { return failure(); // If the loop is empty, iterates at least once, and only returns values // defined outside of the loop, remove it and replace it with yield values. - auto yieldOp = cast(block.getTerminator()); - auto yieldOperands = yieldOp.getOperands(); - if (llvm::any_of(yieldOperands, + if (llvm::any_of(op.getYieldedValues(), [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) return failure(); - rewriter.replaceOp(op, yieldOperands); + rewriter.replaceOp(op, op.getYieldedValues()); return success(); } }; @@ -1224,6 +1221,10 @@ std::optional ForOp::getConstantStep() { return {}; } +ValueRange ForOp::getYieldedValues() { + return cast(getBody()->getTerminator()).getResults(); +} + Speculation::Speculatability ForOp::getSpeculatability() { // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start // and End. @@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() { return cast(getAfterBody()->getTerminator()); } +ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); } + Block::BlockArgListType WhileOp::getBeforeArguments() { return getBeforeBody()->getArguments(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 88c025c6b2b2e..79e5eefcb8603 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -605,9 +605,8 @@ struct ForOpInterface auto forOp = cast(op); OpOperand &forOperand = forOp.getOpOperandForResult(opResult); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - auto yieldOp = cast(forOp.getBody()->getTerminator()); bool equivalentYield = state.areEquivalentBufferizedValues( - bbArg, yieldOp->getOperand(opResult.getResultNumber())); + bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]); return equivalentYield ? BufferRelation::Equivalent : BufferRelation::Unknown; } diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 0cd19fbefa8ef..43e79d309c667 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -36,10 +36,9 @@ using namespace mlir::scf; /// type of the corresponding basic block argument of the loop. /// Note: This function handles only simple cases. Expand as needed. static bool isShapePreserving(ForOp forOp, int64_t arg) { - auto yieldOp = cast(forOp.getBody()->getTerminator()); - assert(arg < static_cast(yieldOp.getResults().size()) && + assert(arg < static_cast(forOp.getNumResults()) && "arg is out of bounds"); - Value value = yieldOp.getResults()[arg]; + Value value = forOp.getYieldedValues()[arg]; while (value) { if (value == forOp.getRegionIterArgs()[arg]) return true; diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 781a21bb3ecd3..15a816f4e4488 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -52,3 +52,40 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) { } return false; } + +LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { + // Note: These invariants are also verified by the RegionBranchOpInterface, + // but the LoopLikeOpInterface provides better error messages. + auto loopLikeOp = cast(op); + + // Verify number of inits/iter_args/yielded values. + if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size()) + return op->emitOpError("different number of inits and region iter_args: ") + << loopLikeOp.getInits().size() + << " != " << loopLikeOp.getRegionIterArgs().size(); + if (loopLikeOp.getRegionIterArgs().size() != + loopLikeOp.getYieldedValues().size()) + return op->emitOpError( + "different number of region iter_args and yielded values: ") + << loopLikeOp.getRegionIterArgs().size() + << " != " << loopLikeOp.getYieldedValues().size(); + + // Verify types of inits/iter_args/yielded values. + int64_t i = 0; + for (const auto it : + llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(), + loopLikeOp.getYieldedValues())) { + if (std::get<0>(it).getType() != std::get<1>(it).getType()) + op->emitOpError(std::to_string(i)) + << "-th init and " << i << "-th region iter_arg have different type: " + << std::get<0>(it).getType() << " != " << std::get<1>(it).getType(); + if (std::get<1>(it).getType() != std::get<2>(it).getType()) + op->emitOpError(std::to_string(i)) + << "-th region iter_arg and " << i + << "-th yielded value have different type: " + << std::get<1>(it).getType() << " != " << std::get<2>(it).getType(); + ++i; + } + + return success(); +} diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index f6044ad108292..1b2c3f563195c 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) { // ----- +func.func @too_many_iter_args(%arg0: index, %init: f32) { + // expected-error @below{{different number of inits and region iter_args: 1 != 2}} + %x = "scf.for"(%arg0, %arg0, %arg0, %init) ( + { + ^bb0(%i0 : index, %iter: f32, %iter2: f32): + scf.yield %iter, %iter : f32, f32 + } + ) : (index, index, index, f32) -> (f32) + return +} + +// ----- + +func.func @too_few_yielded_values(%arg0: index, %init: f32) { + // expected-error @below{{different number of region iter_args and yielded values: 2 != 1}} + %x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) ( + { + ^bb0(%i0 : index, %iter: f32, %iter2: f32): + scf.yield %iter : f32 + } + ) : (index, index, index, f32, f32) -> (f32, f32) + return +} + +// ----- + func.func @loop_if_not_i1(%arg0: index) { // expected-error@+1 {{operand #0 must be 1-bit signless integer}} "scf.if"(%arg0) ({}, {}) : (index) -> () @@ -422,7 +448,8 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) { %s0 = arith.constant 0.0 : f32 %t0 = arith.constant 1.0 : f32 - // expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}} + // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}} + // expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}} %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (f32, f32) { %sn = arith.addf %si, %si : f32 @@ -432,7 +459,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind return } - // ----- func.func @parallel_invalid_yield( diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 1d40615305c02..565d07669792f 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -50,9 +50,8 @@ struct TestSCFForUtilsPass auto newInitValues = forOp.getInitArgs(); if (newInitValues.empty()) return; - auto yieldOp = cast(forOp.getBody()->getTerminator()); - SmallVector oldYieldValues(yieldOp.getResults().begin(), - yieldOp.getResults().end()); + SmallVector oldYieldValues = + llvm::to_vector(forOp.getYieldedValues()); NewYieldValuesFn fn = [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { SmallVector newYieldValues;