diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h index cf1fd6e2d48ca..20162e6586600 100644 --- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h @@ -24,6 +24,7 @@ #define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H #include +#include #include namespace mlir::dataflow { @@ -101,10 +102,19 @@ struct RunLivenessAnalysis { RunLivenessAnalysis(Operation *op); const Liveness *getLiveness(Value val); + // This only remarks that Liveness results are stale. + void invalidate() { valid = false; } + /// Return the configuration of the solver used for this analysis. + const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); } + /// The function is called by analysis_impl::isInvalidated. + bool isInvalidated(AnalysisManager::PreservedAnalyses &) const { + return !valid; + } private: /// Stores the result of the liveness analysis that was run. DataFlowSolver solver; + bool valid{true}; }; } // end namespace mlir::dataflow diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index d705d8d4c7819..943c60bda9de6 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -325,5 +325,6 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { + assert(valid && "getLiveness called after invalidate"); return solver.lookupState(val); } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0e09774..0f34ebf263e27 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -46,6 +46,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/AnalysisManager.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/FoldUtils.h" @@ -869,10 +870,151 @@ struct RemoveDeadValues : public impl::RemoveDeadValuesBase { }; } // namespace +/// If the target of CallOp is a public function and at least one argument is +/// NonLive, privatize the function. Our strategy here is separation interface +/// and implementation. eg. +/// +/// public void foo(int unused){...} +/// => +/// public void foo(int unused) { // old function, interface +/// return __foo_privatized(unused); +/// } +/// +/// private void __foo_privatized(int unused) { // the new private function, or +/// implementation. +/// ... // the function body of the +/// original function. +/// } +/// +/// Returns true if any IR changes were made, false otherwise. +static bool processCallOp(CallOpInterface callOp, ModuleOp moduleOp, + RunLivenessAnalysis &la) { + Operation *callableOp = callOp.resolveCallable(); + auto funcOp = dyn_cast(callableOp); + if (!funcOp || !funcOp.isPublic()) + return false; + + LDBG() << "Processing callOp " << callOp << " target is a public function: " + << funcOp.getOperation()->getName(); + + // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. + SmallVector arguments(callOp.getArgOperands()); + BitVector nonLiveArgs = markLives(arguments, DenseSet(), la); + nonLiveArgs = nonLiveArgs.flip(); + + if (nonLiveArgs.count() > 0) { + OpBuilder rewriter(moduleOp.getContext()); + + // Clone function and create private version + FunctionOpInterface clonedFunc = cast(funcOp.clone()); + + // Set visibility = 'private' and a new name for the cloned function + SymbolTable::setSymbolVisibility(clonedFunc, + SymbolTable::Visibility::Private); + std::string newName = "__" + funcOp.getName().str() + "_privatized"; + clonedFunc.setName(newName); + + // Insert the cloned function into the module + rewriter.setInsertionPointAfter(funcOp); + rewriter.insert(clonedFunc); + + // Replace ALL callsites of the original function to call the cloned + // function directly + LogicalResult result = SymbolTable::replaceAllSymbolUses( + funcOp, clonedFunc.getNameAttr(), moduleOp); + + if (result.failed()) { + LDBG() << "Failed to replace all symbol uses for " << funcOp.getName(); + return false; + } + + LDBG() << "Redirected all callsites from " << funcOp.getName() << " to " + << newName; + + // Transform the original funcOp into a wrapper that calls the cloned + // function + Region &funcBody = funcOp.getFunctionBody(); + + // Clean the original function body + funcBody.dropAllReferences(); + funcBody.getBlocks().clear(); + + // Create a new entry block for the wrapper function + Block *wrapperBlock = rewriter.createBlock(&funcBody); + + // Add block arguments that match the function signature + for (Type argType : funcOp.getArgumentTypes()) { + wrapperBlock->addArgument(argType, funcOp.getLoc()); + } + + // Set insertion point to the new block + rewriter.setInsertionPointToStart(wrapperBlock); + + // Clone the original call operation and update its callee + auto clonedCallOp = cast(callOp->clone()); + // Update the callee symbol reference to point to the new private function + auto symbolRef = + SymbolRefAttr::get(funcOp.getContext(), clonedFunc.getName()); + clonedCallOp.setCalleeFromCallable(symbolRef); + // Set the call arguments to use the wrapper block's arguments + clonedCallOp->setOperands(wrapperBlock->getArguments()); + rewriter.insert(clonedCallOp); + + // Create return operation of the same type as the original function's + // return + Operation *returnOp = nullptr; + for (Block &block : clonedFunc.getFunctionBody()) { + if (block.getNumSuccessors() > 0) + continue; + + Operation *terminator = block.getTerminator(); + if (terminator && terminator->hasTrait()) { + returnOp = terminator; + break; // Use first return as template + } + } + + if (returnOp) { + Operation *newReturnOp = returnOp->clone(); + newReturnOp->setOperands(clonedCallOp->getResults()); + newReturnOp->setLoc(funcOp.getLoc()); + rewriter.insert(newReturnOp); + } + return true; // Changes were made + } + + return false; +} + void RemoveDeadValues::runOnOperation() { - auto &la = getAnalysis(); + AnalysisManager am = getAnalysisManager(); + RunLivenessAnalysis *la = &am.getAnalysis(); Operation *module = getOperation(); + // Only privatize public functions if liveness analysis is inter-procedural. + if (la->getSolverConfig().isInterprocedural()) { + bool changed = false; + module->walk([&](CallOpInterface callOp) { + if (processCallOp(callOp, cast(module), *la)) { + changed = true; + } + }); + + if (changed) { + LDBG() << "IR has changed, invalidate RunLivenessAnalysis only"; + auto &pa = getPassState().preservedAnalyses; + bool preserved = pa.isPreserved(); + la->invalidate(); + am.invalidate(pa); + la = &am.getAnalysis(); + // If RunLivenessAnalysis was previously preserved, preserved the updated + // results. + if (preserved) { + pa.preserve(); + } + } + } + // Tracks values eligible for erasure - complements liveness analysis to // identify "droppable" values. DenseSet deadVals; @@ -883,11 +1025,11 @@ void RemoveDeadValues::runOnOperation() { module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { - processFuncOp(funcOp, module, la, deadVals, finalCleanupList); + processFuncOp(funcOp, module, *la, deadVals, finalCleanupList); } else if (auto regionBranchOp = dyn_cast(op)) { - processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); + processRegionBranchOp(regionBranchOp, *la, deadVals, finalCleanupList); } else if (auto branchOp = dyn_cast(op)) { - processBranchOp(branchOp, la, deadVals, finalCleanupList); + processBranchOp(branchOp, *la, deadVals, finalCleanupList); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { // Nothing to do here because this is a terminator op and it should be // honored with respect to its parent @@ -895,7 +1037,7 @@ void RemoveDeadValues::runOnOperation() { // Nothing to do because this op is associated with a function op and gets // cleaned when the latter is cleaned. } else { - processSimpleOp(op, la, deadVals, finalCleanupList); + processSimpleOp(op, *la, deadVals, finalCleanupList); } }); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 56449469dc29f..bcac4638604fa 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -571,6 +571,54 @@ module @return_void_with_unused_argument { } } +// check that public functions with non-live arguments correctly. +module @public_function_with_nonlive_arguments { + // the function signature is immutable because it is public. + func.func public @public_fn_with_unused_argument(%unused: i32) -> () { + return + } + // CHECK-LABEL: func.func @main + // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> () + func.func @main() -> () { + %zero = arith.constant 0 : i32 + call @public_fn_with_unused_argument(%zero) : (i32) -> () + return + } + + // CHECK-LABEL: func.func @main2 + // CHECK: call @__public_fn_with_unused_argument_privatized() : () -> () + func.func @main2(%arg0: i1) { + %0 = scf.if %arg0 -> (i32) { + %c1_i32 = arith.constant 1 : i32 + scf.yield %c1_i32 : i32 + } else { + %c0_i32 = arith.constant 0 : i32 + scf.yield %c0_i32 : i32 + } + + call @public_fn_with_unused_argument(%0) : (i32) -> () + return + } + + func.func public @fn_return_multiple(%arg0: i32) -> (i32, i32, i32) { + %one = arith.constant 1 : i32 + %two = arith.constant 2 : i32 + %three = arith.constant 4 : i32 + + return %one, %two, %three: i32, i32, i32 + } + + // CHECK-LABEL: func.func @main3 + // CHECK: call @__fn_return_multiple_privatized() : () -> (i32, i32, i32) + func.func @main3(%arg: i32) -> () { + %one = arith.constant 1 : i32 + %scalar = arith.addi %arg, %one: i32 + + call @fn_return_multiple(%scalar) : (i32) -> (i32, i32, i32) + return + } +} + // ----- // CHECK-LABEL: module @dynamically_unreachable